diff --git a/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2340959505ae7c7512a6be9006fdad9fcc899fdd Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/converter.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/converter.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c885a4ae94a4713d027cabebe397430cecde7c1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/converter.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/error.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/error.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a1e75c6771ae28a02da8ff6f7e952a186407087 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/error.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0ac7c05032d3430a9aea20e450b51348ebda4f3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/pass_base.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/pass_base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6455d3b4ac71a44b405397cfb380895e911fcda4 Binary 
files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/pass_base.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/tools.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/tools.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e5a225a8bbb90df44032570df9466963ee160fa Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/tools.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0920f166ce5cef7f01afebc35613f3b3e71d0cac Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/verifier.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/verifier.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9a45ccf6c6039f8899a0f7f62a5b0947dd993e0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/verifier.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/wrappers.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/wrappers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a16295b8701377da66413563e64a866dbb2ea33 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/wrappers.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..5d404b7a0c87ac9f038d64b9cb51c31a7a76ce0f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/case.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/case.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af9f4b75dcb15e923b248eebe3c16c41f3205a93 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/case.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/gen_example.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/gen_example.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87b41c7fcdf28a31dff8cafe77961299d49d190d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/gen_example.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/logging.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/logging.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9dd25105b2973b5751da77909e5928a21a9dcb60 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/logging.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..077c7f93142d34c39c98ebe22b7afcfd48bfb16e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/__init__.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a7d62f28b03d5f28ef8c7e213527d035971283e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39af0611a5a56e62c3f3be91dfceb0bb3ceb062e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db9775acd5a9f5752da444aa8ab3d1d670a82d6f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c619d972e1d23c4c26617b41b021cb32c5dff0d0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d36d0286f918c780e3611b628354ef3a28055b8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/optional_input.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/optional_input.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6033b179c861236e69cf0863cb40842f4825836 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/optional_input.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/scalar_output.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/scalar_output.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1845f4440dd5dc7797a876e1e073048e70b0e1ba Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/scalar_output.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_for_loop.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_for_loop.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b644d1933a896133ffcc4ed9f28c43b593c5238 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_for_loop.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_if.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_if.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59540dcea59e938bf67d4cc72b142313690fb42c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_if.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/type_reflection_method.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/type_reflection_method.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0752268ca3ba9e769ec78da0043630065171a6d6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/type_reflection_method.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/passes/__init__.py b/.venv/lib/python3.11/site-packages/torch/_export/passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aa9ce2ac03c23600c86ff02e38a2a4bfeefef9e2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/passes/__init__.py @@ -0,0 +1 @@ +from .replace_view_ops_with_view_copy_ops_pass import ReplaceViewOpsWithViewCopyOpsPass diff --git a/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aaab29ff6d9a413c295609f2275f9d0ed281d9e5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/_node_metadata_hook.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/_node_metadata_hook.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09ad29abedea5f34e2d0fb5028f6f17d22f5531f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/_node_metadata_hook.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40ca13276f1bfc549268eea239515b5fa57fb20f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbbcc63464b7b412b2b9a349906895d967128e53 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/constant_folding.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/constant_folding.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79f71d6d219009a7327e5810d888799e4b2d8ce4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/constant_folding.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0931e83c6cf57ef35ceb27d5ccad28353cd88a5a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/lift_constants_pass.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/lift_constants_pass.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d5eb7291228e44382cd908e10292ba152f8eacd Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/lift_constants_pass.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..852def15ef5219e1ad54536a590aa03186bfaf1b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_autocast_with_hop_pass.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_autocast_with_hop_pass.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4ba888c4cf597d65f572d7aad9020f024754cc4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_autocast_with_hop_pass.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_quantized_ops_with_standard_ops_pass.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_quantized_ops_with_standard_ops_pass.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae01d460879f3a41a03e5d9b6d4732b7c4f9d665 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_quantized_ops_with_standard_ops_pass.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c7864c580e9e7017befaba8a21b5349a9f962b7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c39ac71f3827d974642ac0624e74ad9441af20d5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_with_hop_pass_util.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_with_hop_pass_util.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b0f17766a1b10dbea9e7a327890e6d90a0842d4 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_with_hop_pass_util.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/passes/_node_metadata_hook.py b/.venv/lib/python3.11/site-packages/torch/_export/passes/_node_metadata_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..3dd87b546da8df08522f1c237bab44e9668b4b47 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/passes/_node_metadata_hook.py @@ -0,0 +1,80 @@ +# mypy: allow-untyped-defs +import contextlib + +import torch +from torch.fx.graph_module import GraphModule + + +_EMPTY_NN_MODULE_STACK_KEY = "_empty_nn_module_stack_from_metadata_hook" + + +def _node_metadata_hook(node: torch.fx.Node, stack_trace: str) -> None: + """ + Hook for adding the appropriate metadata to nodes that are created during a + pass using graph.create_node. An example of how to use it: + + ``` + with _set_node_metadata_hook(gm, + functools.partial(_node_metadata_hook, stack_trace="file") + ): + pass(gm) + ``` + + This hook should not work for all generic cases -- specifically it assumes + that nodes being added are only call_function nodes, and copies over the + first argument node's nn_module_stack. 
+ """ + assert node.op == "call_function" and callable(node.target) + + arg_meta = [arg.meta for arg in node.args if isinstance(arg, torch.fx.Node)] + assert len(arg_meta) >= 1 + arg_meta = arg_meta[0] + + if ( + isinstance(node.target, torch._ops.OpOverload) + and len(node.target._schema.returns) == 0 + ): + node.meta["val"] = None + else: + fake_args = [ + arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg + for arg in node.args + ] + fake_res = node.target(*fake_args) + node.meta["val"] = fake_res + + node.meta["stack_trace"] = stack_trace + node.meta["nn_module_stack"] = arg_meta.get( + "nn_module_stack", + { + _EMPTY_NN_MODULE_STACK_KEY: ( + _EMPTY_NN_MODULE_STACK_KEY, + _EMPTY_NN_MODULE_STACK_KEY, + ) + }, + ) + node.meta["torch_fn"] = ( + f"{node.target.__name__}_0", + f"{node.target.__class__.__name__}.{node.target.__name__}", + ) + + +@contextlib.contextmanager +def _set_node_metadata_hook(gm: torch.fx.GraphModule, f): + """ + Takes a callable which will be called after we create a new node. The + callable takes the newly created node as input and returns None. + """ + assert callable(f), "node_metadata_hook must be a callable." 
+ + # Add the hook to all submodules + for m in gm.modules(): + if isinstance(m, GraphModule): + m._register_create_node_hook(f) + try: + yield + finally: + # Restore hook for all submodules + for m in gm.modules(): + if isinstance(m, GraphModule): + m._unregister_create_node_hook(f) diff --git a/.venv/lib/python3.11/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py b/.venv/lib/python3.11/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..e8ed5931a74fc21f71d35a5cf23983846b2be449 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py @@ -0,0 +1,227 @@ +# mypy: allow-untyped-defs +import math +import operator +import traceback +from functools import partial +from typing import Callable, Dict, List, NamedTuple, Set + +import sympy + +import torch +import torch.fx +from torch.utils._sympy.value_ranges import ValueRanges +from torch.utils._sympy.numbers import int_oo +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols +from torch.fx.passes.infra.pass_base import PassBase, PassResult + +__all__ = ["InputDim"] + + +class InputDim(NamedTuple): + input_name: str + dim: int + + +def _convert_to_int(val): + # Convert simple sympy Integers into concrete int + if val in (sympy.oo, int_oo): + return math.inf + if val in (-sympy.oo, -int_oo): + return -math.inf + if isinstance(val, sympy.Integer): + return int(val) + raise RuntimeError( + "Export constraints cannot be non-integer expressions" + ) + + +def _convert_range_to_int(range: ValueRanges): + assert isinstance(range, ValueRanges) + min_val = _convert_to_int(range.lower) + max_val = _convert_to_int(range.upper) + return min_val, max_val + + +class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase): + def __init__( + self, + range_constraints: Dict[sympy.Symbol, ValueRanges], + ): + super().__init__() + 
self.range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints + self._asserts_generated_unbacked_symbols: Set[sympy.Symbol] = set() + self.counter = 0 + + def _assert_range_constraint(self, node, lower, upper, assert_msg): + last_node = node + if lower > -math.inf: + last_node = self._insert_assert_async(last_node, operator.ge, node, lower, assert_msg) + + if upper < math.inf: + last_node = self._insert_assert_async(last_node, operator.le, node, upper, assert_msg) + + def _insert_assert_async(self, last_node, op, lower, upper, assert_msg): + """ + Inserts assert_async call_function nodes in the graph. This function is + called **during** the interpreter-based pass. + """ + self.counter += 1 + graph = last_node.graph + with graph.inserting_after(last_node): + cmp = graph.call_function(op, (lower, upper), {}) + with graph.inserting_after(cmp): + cmp_tensor = graph.call_function(torch.ops.aten.scalar_tensor.default, (cmp,), {}) + with graph.inserting_after(cmp_tensor): + assert_async = graph.call_function( + torch.ops.aten._assert_async.msg, + (cmp_tensor, assert_msg), + {}, + ) + return assert_async + + def call(self, graph_module) -> PassResult: + self.existing_inline_assertions = _get_existing_inline_assertions( + graph_module, self.range_constraints + ) + + for module in graph_module.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in module.graph.nodes: + if node.op != "call_function": + continue + if "val" not in node.meta: + continue + + val = node.meta["val"] + # In general, we may have to deal the case such as: ret[1].shape[0]. + # We need first find out what symbols require assertion, then we need to follow the path + # from ret to the symbol, construct the proxies along the way and construct the messages + # piece-wise at the same time. 
+ # + # We use post-order traversal to collect all the proxies callbacks needed, construct + # the error message callbacks, and at the top-level traversal tree we execute all the callbacks. + # We need the callbacks because, in order to call the function to create a proxy for shape[0], we + # need the proxy for shape, which further requires the proxy for ret[1], etc. + + def add_assertions(val): + call_backs: List[Callable] = [] + messages: List[str] = [] + if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)): + symbol = val.node.expr + if symbol in self.existing_inline_assertions: + return call_backs, messages + if isinstance(symbol, sympy.Symbol) and free_unbacked_symbols(symbol): + if symbol in self._asserts_generated_unbacked_symbols: + return call_backs, messages + # We only care about unbacked symints for these inline + # constraints, which are prefixed with 'u' + constraint = self.range_constraints[symbol] + min_val, max_val = _convert_range_to_int(constraint) + assert_msg = f" is outside of inline constraint [{min_val}, {max_val}]." 
+ call_backs.append( + partial(self._assert_range_constraint, lower=min_val, upper=max_val) + ) + messages.append(assert_msg) + self._asserts_generated_unbacked_symbols.add(symbol) + + elif isinstance(val, torch.Tensor): + for i, sym in enumerate(val.shape): + cbs, msgs = add_assertions(sym) + for cb, msg in zip(cbs, msgs): + def sym_size_cb(node, assert_msg, dim): + with node.graph.inserting_after(node): + dim_node = module.graph.call_function( + torch.ops.aten.sym_size.int, + (node, dim), + {}, + ) + cb(node=dim_node, assert_msg=assert_msg) + call_backs.append(partial(sym_size_cb, dim=i)) + messages.append(f".shape[{i}]" + msg) + return call_backs, messages + + callbacks, messages = add_assertions(val) + for cb, msg in zip(callbacks, messages): + cb(node=node, assert_msg=f"{node}" + msg) + + module.recompile() + + # Sometimes this pass would return a wrong graph where we have mismatched + # node names in signature. Before we fix it, let's just skip it. + if self.counter == 0 and type(self) is _AddRuntimeAssertionsForInlineConstraintsPass: + return PassResult(graph_module, False) + + # Populate the stack trace with dummy vals to respect IR + for node in graph_module.graph.nodes: + if not node.meta.get("stack_trace", None) and node.op not in ["placeholder", "output"]: + node.meta["stack_trace"] = "".join(traceback.format_stack(limit=1)) + return PassResult(graph_module, True) + + +def _get_existing_inline_assertions( + graph_module: torch.fx.GraphModule, + range_constraints: Dict[sympy.Symbol, ValueRanges], +) -> Dict[sympy.Symbol, ValueRanges]: + existing_inline_assertions: Dict[sympy.Symbol, ValueRanges] = {} + + for module in graph_module.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + + # Find all the existing inline assertions. 
They will look something like: + # %_local_scalar_dense = call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%arg1_1,), kwargs = {}) + # %ge = call_function[target=operator.ge](args = (%_local_scalar_dense, 0), kwargs = {}) + # %_assert_scalar = call_function[target=torch.ops.aten._assert_scalar.default](args = (%scalar_tensor, "..."), kwargs = {}) + for node in module.graph.nodes: + if node.target != torch.ops.aten._assert_scalar.default: + continue + + compare_arg = node.args[0] + if not ( + isinstance(compare_arg, torch.fx.Node) and + compare_arg.op == "call_function" and + compare_arg.target in (operator.le, operator.ge) and + len(compare_arg.args) == 2 + ): + continue + + compare_op = compare_arg.target + lhs, rhs = compare_arg.args + + def maybe_get_symint(x): + if ( + isinstance(x, torch.fx.Node) and + "val" in x.meta and + isinstance(x.meta["val"], torch.SymInt) + ): + return x.meta["val"].node.expr + return x + + lhs = maybe_get_symint(lhs) + rhs = maybe_get_symint(rhs) + + if compare_op == operator.ge: + lhs, rhs = rhs, lhs + + if isinstance(lhs, sympy.Symbol) and isinstance(rhs, int): + symint = lhs + scalar = rhs + elif isinstance(rhs, sympy.Symbol) and isinstance(lhs, int): + symint = rhs + scalar = lhs + else: + continue + + if symint not in range_constraints: + raise RuntimeError(f"Unable to find symint {symint} in {range_constraints}") + + previous_range = existing_inline_assertions.get(symint, ValueRanges(-math.inf, math.inf)) + + if symint is lhs: + bounds = ValueRanges(-math.inf, scalar) + else: + bounds = ValueRanges(scalar, math.inf) + existing_inline_assertions[symint] = previous_range & bounds + + return existing_inline_assertions diff --git a/.venv/lib/python3.11/site-packages/torch/_export/passes/collect_tracepoints_pass.py b/.venv/lib/python3.11/site-packages/torch/_export/passes/collect_tracepoints_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..c89d2216632fa5dfde608ec5f4b857195bcb19ad 
--- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/passes/collect_tracepoints_pass.py @@ -0,0 +1,102 @@ +# mypy: allow-untyped-defs +import operator + +import torch +from torch.export.exported_program import ConstantArgument, TensorArgument +from torch.fx.passes.infra.pass_base import PassBase, PassResult + + +__all__ = ["CollectTracepointsPass"] + + +class CollectTracepointsPass(PassBase): + """ + Performs constant folding and constant propagation. + """ + + def __init__(self, specs, sig) -> None: + super().__init__() + self.specs = specs + self.sig = sig + + def call(self, gm): + def get_arg_spec(arg): + if isinstance(arg, torch.fx.Node): + if isinstance(arg.meta.get("val"), torch.Tensor): + return TensorArgument(name=arg.name) + else: + raise AssertionError( + "Symint input is not implemented yet for submodule call signature." + ) + else: + return ConstantArgument(name="", value=arg) + + for module in gm.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + nn_module_stack = None + for node in module.graph.nodes: + if node.op != "call_function": + continue + if node.target == torch.ops.higher_order._export_tracepoint: + kind = node.kwargs["kind"] + if kind == "module_call_outputs": + nn_module_stack = node.meta["nn_module_stack"] + elif kind == "module_call_inputs": + nn_module_stack = None + else: + raise AssertionError(f"Unknown tracepoint kind: {kind}") + elif node.meta["nn_module_stack"] == nn_module_stack: + node.meta["nn_module_stack"].popitem() + else: + nn_module_stack = None + nn_module_stack = None + for node in reversed(module.graph.nodes): + if node.op != "call_function": + continue + if node.target == torch.ops.higher_order._export_tracepoint: + kind = node.kwargs["kind"] + if kind == "module_call_inputs": + nn_module_stack = node.meta["nn_module_stack"] + elif kind == "module_call_outputs": + nn_module_stack = None + else: + raise AssertionError(f"Unknown tracepoint kind: {kind}") + elif 
node.meta["nn_module_stack"] == nn_module_stack: + node.meta["nn_module_stack"].popitem() + else: + nn_module_stack = None + for module in gm.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in module.graph.nodes: + if node.op != "call_function": + continue + if node.target == torch.ops.higher_order._export_tracepoint: + for i, arg in enumerate(node.args): + kind = node.kwargs["kind"] + if kind == "module_call_inputs": + self.specs[node.kwargs["path"]].inputs.append( + get_arg_spec(arg) + ) + elif kind == "module_call_outputs": + self.specs[node.kwargs["path"]].outputs.append( + get_arg_spec(arg) + ) + else: + raise AssertionError(f"Unknown tracepoint kind: {kind}") + if isinstance(arg, torch.fx.Node): + for user in node.users: + assert user.op == "call_function" + assert user.target == operator.getitem + assert isinstance(user.args[1], int) + if user.args[1] == i: + user.replace_all_uses_with(arg) + self.sig.replace_all_uses(user.name, arg.name) + break + users = list(node.users) + for user in users: + assert len(user.users) == 0 + gm.graph.erase_node(user) + gm.graph.erase_node(node) + return PassResult(gm, True) diff --git a/.venv/lib/python3.11/site-packages/torch/_export/passes/constant_folding.py b/.venv/lib/python3.11/site-packages/torch/_export/passes/constant_folding.py new file mode 100644 index 0000000000000000000000000000000000000000..b1491ca5d4794647f3cc348dc4bcf4c59134031a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/passes/constant_folding.py @@ -0,0 +1,299 @@ +# mypy: allow-untyped-defs +import collections +from collections import defaultdict +from typing import Any, Callable, Dict, Optional + +import torch +import torch.utils._pytree as pytree + + +aten = torch.ops.aten + +# We would like to split modules into two subgraphs for runtime weight updates to work correctly. 
# The use case and more information could be found at:
# https://docs.google.com/document/d/1inZC-8KarJ6gKB7G9egmYLx1V_dKX_apxon0w4zPC0Q/edit?usp=sharing
META_TAG = "MODULE_TYPE"
MODULE_TAG = "_MAIN_MODULE"
CONST_MODULE_TAG = "_CONST_MODULE"


def replace_node_with_constant(gm, node, constant, name=None):
    """Replace ``node`` in ``gm`` with a get_attr to ``constant``.

    If ``name`` is not given, a fresh ``_frozen_param{i}`` attribute name is
    allocated (a per-module counter is kept on ``gm``).
    """
    g = gm.graph

    if name:
        qualname = name
    else:
        if not hasattr(gm, "_frozen_param_count"):
            gm._frozen_param_count = 0
        i = gm._frozen_param_count

        # Probe until an unused attribute name is found.
        while True:
            qualname = f"_frozen_param{i}"
            if not hasattr(gm, qualname):
                break
            i += 1

        gm._frozen_param_count = i + 1

    with g.inserting_before(node):
        new_input_node = g.create_node("get_attr", qualname, (), {})
        node.replace_all_uses_with(new_input_node)
        new_input_node.meta.update(node.meta)
        g.erase_node(node)

    # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
    gm.register_buffer(qualname, constant)
    setattr(gm, qualname, constant)


class ConstantFolder(torch.fx.Interpreter):
    """Interpreter that evaluates the graph and records which nodes produced
    constant tensors (``self.node_replacements``) so callers can fold them.
    """

    def __init__(
        self,
        gm,
        skip_constructors=False,
    ):
        super().__init__(gm)
        # node -> concrete tensor this node can be replaced with
        self.node_replacements: Dict[torch.fx.Node, Any] = {}
        # node -> how many of its uses were folded away
        self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
        # Sentinel meaning "value not statically known".
        self.unknown_value = object()
        self.skip_constructors: bool = skip_constructors

        # overwrite this to deallocate env values if their only remaining use
        # is the output
        self.user_to_last_uses = self.node_to_last_non_output_use()

    def is_impure(self, node: torch.fx.node.Node):
        # "Impure" here means: do NOT fold past this node (keep the dequant in
        # the graph so downstream passes can fuse it).
        if (
            node.target == torch.ops.prims.convert_element_type.default
            and node.args[0].op == "get_attr"  # type: ignore[union-attr]
            and node.args[0].meta["val"].dtype == torch.int8  # type: ignore[union-attr]
            and node.args[1] == torch.bfloat16
        ):
            # For int8_weight -> dq -> bf16_weight
            return True
        if node.target in [
            torch.ops.quantized_decomposed.dequantize_per_channel.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
        ]:
            # For the pattern fp32_weight -> q -> dq
            # We only fold fp32_weight -> q
            # into int8_weight and leave dq in graph to be fused
            return True
        return False

    def node_to_last_non_output_use(self):
        # Map each node to the inputs whose last non-output use is that node,
        # so env entries can be freed as early as possible.
        last_non_output_use = collections.defaultdict(list)
        seen_uses = set()
        output_node = next(iter(reversed(self.module.graph.nodes)))

        for node in reversed(self.module.graph.nodes):
            if node.target == "output":
                continue

            def add_use(inp):
                if inp in seen_uses:
                    return

                seen_uses.add(inp)
                last_non_output_use[node].append(inp)

            # In-place is fine since we don't mutate
            pytree.tree_map_only_(torch.fx.Node, add_use, (node.args, node.kwargs))

            # if this node is only used in output, we want to gc it right away
            if len(node.users) == 1 and output_node in node.users:
                last_non_output_use[node].append(node)

        return last_non_output_use

    def run_node(self, node):
        if node.target == "output":
            # because we remove nodes from env on last non output use,
            # re-define them now or we'll get error in interpreter
            def set_env(arg):
                self.env[arg] = self.unknown_value

            # In-place is fine since we don't mutate
            pytree.tree_map_only_(torch.fx.Node, set_env, node.args)
            return super().run_node(node)

        args, kwargs = self.fetch_args_kwargs_from_env(node)
        flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)

        # We need to do this weird thing because in cases where flattened_inputs
        # contains a ScriptObject, equality checking results in a type error if
        # the types are different.
        if any(
            type(self.unknown_value) == type(input_) and self.unknown_value == input_
            for input_ in flattened_inputs
        ):
            return self.unknown_value

        # TODO - fix errors with this
        if (
            node.op == "call_function"
            and node.target == aten._efficientzerotensor.default
        ):
            return self.unknown_value

        # TODO - constant folding triton kernel returns the inputs -- fix this
        if (
            node.op == "call_function"
            and node.name == "triton_kernel_wrapper_functional_proxy"
        ):
            return self.unknown_value

        # skip constructors, since inductor generates optimal code for them already
        # and turning into tensor would result in an additional global memory read
        # TODO - more complicated strategy
        if (
            self.skip_constructors
            and node.op != "get_attr"
            and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
        ):
            return self.unknown_value

        # All mutations should either be removed or on inputs which we did not make constant
        if (
            isinstance(node.target, torch._ops.OpOverload)
            and torch.Tag.nondeterministic_seeded in node.target.tags
        ):
            return self.unknown_value

        out = super().run_node(node)

        if node.op != "get_attr" and isinstance(out, torch.Tensor):
            if out.device.type == "meta":
                return out

            if not self.insertable_tensor_check(out):
                return out

            if self.is_impure(node):
                return self.unknown_value

            self.add_node_replacement(node, out)

            flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)

            for n in flattened_node_inps:
                if not isinstance(n, torch.fx.Node):
                    continue

                self.replaced_uses[n] += 1

            # An input whose every use has been folded no longer needs its own
            # replacement entry.
            for to_delete in self.user_to_last_uses.get(node, []):
                if self.replaced_uses[to_delete] == len(to_delete.users):
                    self.node_replacements.pop(to_delete, None)

        return out

    def insertable_tensor_check(self, tensor: torch.Tensor) -> bool:
        # Hook for subclasses; base implementation accepts every tensor.
        return True

    def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
        self.node_replacements[node] = tensor

    def run(self):
        # Placeholders are runtime inputs and therefore never constant.
        env = {}
        for n in self.module.graph.find_nodes(op="placeholder"):
            env[n] = self.unknown_value
        return super().run(initial_env=env)


def constant_fold(gm, constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None):
    """Fold all statically-computable nodes of ``gm`` into get_attr constants.

    ``constraint_fn``, if given, can veto folding of individual nodes.
    """
    with torch.utils._python_dispatch._disable_current_modes():
        cf = ConstantFolder(gm, skip_constructors=True)
        cf.run()

        for node, constant in cf.node_replacements.items():
            if constraint_fn is not None and not constraint_fn(node):
                continue
            replace_node_with_constant(gm, node, constant)

        erased_params = []
        # Get all attr users by looking up the graph instead from node.users, because in this case
        # _tensor_constant0 and _tensor_constant0_1 are actually referring to the same tensor.

        # opcode         name                 target            args                        kwargs
        # -------------  -------------------  ----------------  --------------------------  --------
        # placeholder    arg0_1               arg0              ()                          {}
        # get_attr       _tensor_constant0    state             ()                          {}
        # call_function  add                  aten.add.Tensor   (arg0_1, _tensor_constant0) {}
        # get_attr       _tensor_constant0_1  state             ()                          {}
        # call_function  add_                 aten.add_.Tensor  (_tensor_constant0_1, 1)    {}
        # output         output               output            ([add],)                    {}

        get_attr_node_users = defaultdict(list)
        for node in gm.graph.nodes:
            if node.op == "get_attr":
                get_attr_node_users[node.target].extend(node.users.keys())
        for node in gm.graph.find_nodes(op="get_attr"):
            if node.op == "get_attr" and len(get_attr_node_users[node.target]) == 0:
                if hasattr(gm, node.target):
                    delattr(gm, node.target)
                erased_params.append(node)
        for node in erased_params:
            gm.graph.erase_node(node)

        gm.graph.eliminate_dead_code()
        gm.graph.lint()
        gm.recompile()


def constant_graph_tag(gm: torch.fx.GraphModule):
    """Tag every node's meta with whether it belongs to the foldable
    (CONST_MODULE_TAG) or runtime (MODULE_TAG) part of the graph.
    """
    with torch.utils._python_dispatch._disable_current_modes():
        cf = ConstantFolder(gm, skip_constructors=True)
        cf.run()

        for node in gm.graph.nodes:
            if (
                node.op == "get_attr"
                or node in cf.node_replacements
                or node in cf.replaced_uses
            ):
                node.meta[META_TAG] = CONST_MODULE_TAG
            else:
                node.meta[META_TAG] = MODULE_TAG


def run_and_get_constant_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Construct a GraphModule which corresponds to the part which could be
    constant folded in provided gm.
    """

    constant_graph_tag(gm)
    # We rewrite the tags, if it's a constant being directly consumed, without
    # any folding opportunity, we keep it in main gm.
    for node in gm.graph.find_nodes(op="get_attr"):
        used_to_fold = False
        for u in node.users:
            if u.meta[META_TAG] == CONST_MODULE_TAG:
                used_to_fold = True
                break
        if not used_to_fold:
            node.meta[META_TAG] = MODULE_TAG

    new_graph = torch.fx.Graph()

    node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
    output_nodes = []
    for node in gm.graph.nodes:
        if node.meta[META_TAG] == MODULE_TAG:
            continue

        new_node = new_graph.node_copy(node, lambda x: node_remapping[x])
        node_remapping[node] = new_node

        # Any constant-side node consumed by the runtime side becomes an
        # output of the constant graph.
        for user in node.users:
            if user.meta[META_TAG] == MODULE_TAG:
                output_nodes.append(new_node)
                break

    new_graph.output(tuple(output_nodes))
    new_graph.lint()
    new_gm = torch.fx.GraphModule(gm, new_graph)

    return new_gm


# --- begin vendored file: torch/_export/passes/functionalize_side_effectful_ops_pass.py ---
import copy
from typing import Dict, Optional, Tuple, List

import torch
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse, PassResult, Argument
from torch._export.pass_infra.node_metadata import NodeMetadata
from torch._export.pass_infra.proxy_value import ProxyValue
from torch._ops import OpOverload

aten = torch.ops.aten
aten = torch.ops.aten

# Side-effectful ops and their functional counterparts (which thread an
# explicit dep_token instead of relying on graph ordering).
_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: Dict[OpOverload, OpOverload] = {
    aten.sym_constrain_range.default: aten._functional_sym_constrain_range,
    aten._assert_async.msg: aten._functional_assert_async.msg,
}


class _FunctionalizeSideEffectfulOpsPass(_ExportPassBaseDeprecatedDoNotUse):
    """
    Functionalize ops with side effect in graph module by replacing the op with
    functional version of it. A new dependency token (`dep_token`) will be
    created and propagated through functional ops to output.
    For example:
    ```
    def f(x):
        sym_constrain_range(x.shape[0], min=1, max=3)
        return x.add(3)
    ```
    Will be transformed to:
    ```
    def f(x):
        dep_token0 = _make_dep_token()
        dep_token1 = _functional_sym_constrain_range(
            x.shape[0], min=1, max=3, dep_token=dep_token0
        )

        return x.add(3), dep_token1
    ```
    """

    def __init__(self) -> None:
        super().__init__()
        # Most recent dep token; chained through every functionalized op.
        self._dep_token: Optional[ProxyValue] = None
        # Counter used to give each dep token a deterministic name.
        self._next_dep_token_index: Optional[int] = None

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # Early return if no non-functional assertions.
        if not any(
            n.target in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS
            for n in graph_module.graph.nodes
        ):
            return PassResult(graph_module=graph_module, modified=False)

        # Work on a copy; the base class re-traces via call_operator/output.
        gm = copy.deepcopy(graph_module)
        self._dep_token = None
        self._next_dep_token_index = None
        return super().call(gm)

    def call_operator(
        self,
        op: OpOverload,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op not in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS:
            return super().call_operator(op, args, kwargs, meta)

        # Lazily create the initial dep token on the first side-effectful op.
        if self._dep_token is None:
            self._dep_token = super().call_operator(
                aten._make_dep_token,
                args=(),
                kwargs={},
                meta=self._create_dummy_node_metadata(),
            )
            self._dep_token.node.name = "dep_token0"
            self._next_dep_token_index = 1

        # Replace with the functional overload, threading the dep token.
        self._dep_token = super().call_operator(
            _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS[op],
            args=args,
            kwargs={**kwargs, "dep_token": self._dep_token},
            meta=meta,
        )
        assert self._next_dep_token_index is not None
        self._dep_token.node.name = f"dep_token{self._next_dep_token_index}"
        self._next_dep_token_index += 1

        return self._dep_token

    def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue:
        # The final dep token is appended to the outputs so the chain is live.
        assert self._dep_token is not None

        return super().output(results=(*results, self._dep_token), meta=meta)  # type: ignore[arg-type]


# --- begin vendored file: torch/_export/passes/lift_constants_pass.py ---
# mypy: allow-untyped-defs
import collections
import warnings
from typing import Any, Dict, List, Union

import torch
from torch._export.verifier import SpecViolationError
torch._guards import detect_fake_mode +from torch._library.fake_class_registry import FakeScriptObject +from torch._subclasses.fake_tensor import unset_fake_temporarily +from torch.export.exported_program import ( + ArgumentSpec, + CustomObjArgument, + ExportGraphSignature, + InputKind, + InputSpec, + TensorArgument, +) + + +class ConstantAttrMap(collections.abc.MutableMapping): + """A mapping class that understands how to use module constants (tensors, + ScriptObjects, FakeScriptObjects) as keys. We store tensors and FakeScriptObjects normally, + but ScriptObjects are stored by hash, because different torch.ScriptObjects can point to + the same underlying value (but we guarantee that they will `hash()` to the same value + if that's the case). + """ + + def __init__(self) -> None: + # Underlying dict that we use to implement this mapping. + self._constant_attrs: Dict[ + Union[int, torch.Tensor, FakeScriptObject], List[Any] + ] = {} + # Map from the hash(ScriptObject) to the ScriptObject itself. Used for + # APIs like `__iter__` that should look like they're returning the + # original ScriptObjects. + self._script_object_map: Dict[int, torch.ScriptObject] = {} + + def __getitem__( + self, key: Union[torch.Tensor, torch.ScriptObject, FakeScriptObject] + ) -> Any: + real_key = hash(key) if isinstance(key, torch.ScriptObject) else key + assert isinstance(real_key, (int, torch.Tensor, FakeScriptObject)) + return self._constant_attrs[real_key] + + def __setitem__(self, key: Union[torch.Tensor, torch.ScriptObject], value): + # we shouldn't actually call this, should go to add() instead to handle aliasing + raise NotImplementedError( + """Directly setting values for ConstantAttrMap is not supported, please use add(key, value) instead. 
+The same key can be mapped to multiple values, for handling constant aliasing.""" + ) + + def add( + self, key: Union[torch.Tensor, torch.ScriptObject, FakeScriptObject], value: Any + ) -> None: + if isinstance(key, torch.ScriptObject): + if hash(key) not in self._constant_attrs: + self._constant_attrs[hash(key)] = [] + self._constant_attrs[hash(key)].append(value) + self._script_object_map[hash(key)] = key + elif isinstance(key, (torch.Tensor, FakeScriptObject)): + if key not in self._constant_attrs: + self._constant_attrs[key] = [] + self._constant_attrs[key].append(value) + else: + raise TypeError( + f"Expected key to be a tensor or ScriptObject, got {type(key)}" + ) + + def __delitem__(self, key): + real_key = hash(key) if isinstance(key, torch.ScriptObject) else key + + del self._constant_attrs[real_key] + + def __iter__(self): + for key in self._constant_attrs: + if isinstance(key, int): + yield self._script_object_map[key] + else: + yield key + + def __len__(self): + return len(self._constant_attrs) + + def __contains__(self, key: object) -> bool: + real_key = hash(key) if isinstance(key, torch.ScriptObject) else key + return real_key in self._constant_attrs + + +def get_constant_fqn(node: torch.fx.Node, constant_name: str) -> str: + # The FQN of the constant tensor in the state dict should + # correspond to the module where the constant tensor was + # originally used. 
+ if len(node.meta["nn_module_stack"]) == 0: + return constant_name + parent_fqn = list(node.meta["nn_module_stack"].values())[-1][0] + if len(parent_fqn) > 0: + return f"{parent_fqn}.{constant_name}" + else: + return constant_name + + +def _get_first_fqn( + const_attrs: ConstantAttrMap, + key: Union[torch.Tensor, torch.ScriptObject, FakeScriptObject], +) -> Any: + fqns = const_attrs.get(key) + return fqns[0] if fqns else None + + +def lift_constants_pass( + gm: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, + constant_attrs: ConstantAttrMap, +) -> Dict[str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject]]: + """ + Takes a graph module, graph signature, and modifies them implace to lift any + constants (tensors or custom classes) as inputs to the graph. Returns a + dictionary of names to constants. + + Arguments: + gm (torch.fx.GraphModule): The graph module containing the graph and constants to lift. + graph_signature (ExportGraphSignature): This graph signature will be + mutated to add additional CONSTANT_TENSOR and CUSTOM_OBJ inputs. + constant_attrs (ConstantAttr): A mapping from a constant value to its + fully-qualified path in `gm`. This is used to maintain consistent + location of constants between the original module and the exported + version. + + Returns: + A dictionary of fqn => constant value. 
+ """ + all_constants: Dict[ + str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject] + ] = {} + + inputs = graph_signature.input_specs + num_custom_obj = sum( + input_specs.kind == InputKind.CUSTOM_OBJ for input_specs in inputs + ) + num_tensor_constants = sum( + input_specs.kind == InputKind.CONSTANT_TENSOR for input_specs in inputs + ) + + fake_mode = detect_fake_mode( + tuple(node.meta["val"] for node in gm.graph.nodes if node.op == "placeholder") + ) + + first_user_input_loc, first_user_input = 0, None + for node in gm.graph.nodes: + if node.op == "placeholder" and node.name in graph_signature.user_inputs: + first_user_input = node + break + first_user_input_loc += 1 + + lifted_objs = ConstantAttrMap() + for node in gm.graph.nodes: + if node.op == "get_attr": + constant_val = getattr(gm, node.target) + if constant_val in lifted_objs: + # We already lifted this constant elsewhere. Just rewrite uses + # of this get_attr to point to the already-existing placeholder + # node. + const_placeholder_node = _get_first_fqn(lifted_objs, constant_val) + node.replace_all_uses_with(const_placeholder_node) + gm.graph.erase_node(node) + continue + + # For ScriptObject, Tensor and FakeScriptObject constants: + # First check if the constant was an attribute on some module by + # consulting `constant_attrs` map. If it is, use the fqn that keeps + # its location consistent with the eager module. + # + # If it's not in the `constant_attrs` map, that means it's an inline + # constant (e.g. x + torch.tensor(0)), and thus did not have a + # specific location in the eager module. In that case, just generate + # some name and attach it to the module in which it was used. 
+ if isinstance(constant_val, (torch.ScriptObject, FakeScriptObject)): + constant_kind = InputKind.CUSTOM_OBJ + constant_fqn = _get_first_fqn(constant_attrs, constant_val) + if constant_fqn is not None: + constant_name = constant_fqn.replace(".", "_") + else: + constant_name = f"lifted_custom_{num_custom_obj}" + constant_fqn = get_constant_fqn(node, constant_name) + num_custom_obj += 1 + elif isinstance(constant_val, torch.Tensor): + # Remove the parameterness of constant_val + if isinstance(constant_val, torch.nn.Parameter): + warnings.warn( + f"{node.target} created when tracing {node.meta['stack_trace']} is a parameter. But" + f"it's not registered with register_parameter(). export will treat it as a constant tensor" + ) + # We get the real data out of the parameter by disabling the surrounding fake mode. + with unset_fake_temporarily(): + constant_val = constant_val.data + constant_kind = InputKind.CONSTANT_TENSOR + constant_fqn = _get_first_fqn(constant_attrs, constant_val) + if constant_fqn is not None: + constant_name = constant_fqn.replace(".", "_") + else: + constant_name = f"lifted_tensor_{num_tensor_constants}" + constant_fqn = get_constant_fqn(node, constant_name) + num_tensor_constants += 1 + elif isinstance(constant_val, torch.fx.GraphModule): + continue + elif "LoweredBackendModule" in type(constant_val).__name__: + continue + else: + raise SpecViolationError( + f"getattr node {node} referencing unsupported type {type(constant_val)}" + ) + + with gm.graph.inserting_before(first_user_input): + # Insert the constant node before the first user input + const_placeholder_node = gm.graph.placeholder(constant_name) + # match target name with its node name in case there is name collision + # and suffix is added to node name in fx + const_placeholder_node.target = const_placeholder_node.name + + for k, v in node.meta.items(): + const_placeholder_node.meta[k] = v + + # Once the FQN has been used, remove nn_module_stack, stack_trace + 
const_placeholder_node.meta.pop("nn_module_stack") + const_placeholder_node.meta.pop("stack_trace", None) + + input_spec_arg: ArgumentSpec + if isinstance(constant_val, torch.Tensor): + if fake_mode is not None: + const_placeholder_node.meta["val"] = fake_mode.from_tensor( + constant_val, static_shapes=True + ) + const_placeholder_node.meta["val"].constant = constant_val + else: + const_placeholder_node.meta["val"] = constant_val + input_spec_arg = TensorArgument(name=const_placeholder_node.name) + elif isinstance(constant_val, torch._C.ScriptObject): + class_fqn = constant_val._type().qualified_name() # type: ignore[attr-defined] + const_placeholder_node.meta["val"] = CustomObjArgument( + constant_fqn, class_fqn + ) + input_spec_arg = CustomObjArgument( + name=const_placeholder_node.name, class_fqn=class_fqn + ) + elif isinstance(constant_val, FakeScriptObject): + class_fqn = constant_val.script_class_name + const_placeholder_node.meta["val"] = CustomObjArgument( + constant_fqn, class_fqn, constant_val + ) + input_spec_arg = CustomObjArgument( + name=const_placeholder_node.name, + class_fqn=class_fqn, + fake_val=constant_val, + ) + else: + raise SpecViolationError( + f"tried to lift unsupported type {type(constant_val)} from node {node.format_node()}" + ) + + lifted_objs.add(constant_val, const_placeholder_node) + node.replace_all_uses_with(const_placeholder_node) + gm.graph.erase_node(node) + + # Add the constant as a buffer to the graph signature + graph_signature.input_specs.insert( + first_user_input_loc, + InputSpec( + kind=constant_kind, + arg=input_spec_arg, + target=constant_fqn, + ), + ) + if constant_val in constant_attrs: + for fqn in constant_attrs[constant_val]: + all_constants[fqn] = constant_val + else: + all_constants[constant_fqn] = constant_val + first_user_input_loc += 1 + + return all_constants + + +def rewrite_script_object_meta( + gm: torch.fx.GraphModule, +) -> Dict[str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject],]: + """When 
tracing, we produce a graph with FakeScriptObject in the + meta["val"]. + + For now, we rewrie meta["val"] to be a placeholder CustomObjArgument + """ + constants: Dict[ + str, + Union[ + torch.Tensor, + torch.ScriptObject, + FakeScriptObject, + ], + ] = {} + for node in gm.graph.nodes: + if "val" not in node.meta: + continue + + old_meta = node.meta["val"] + + if isinstance(old_meta, torch.ScriptObject): + class_fqn = old_meta._type().qualified_name() # type: ignore[attr-defined] + new_meta = CustomObjArgument(node.name, class_fqn) + constants[node.name] = old_meta + node.meta["val"] = new_meta + + elif isinstance(old_meta, FakeScriptObject): + class_fqn = old_meta.script_class_name # type: ignore[attr-defined] + new_meta = CustomObjArgument(node.name, class_fqn, old_meta) + constants[node.name] = old_meta + node.meta["val"] = new_meta + + return constants diff --git a/.venv/lib/python3.11/site-packages/torch/_export/passes/remove_runtime_assertions.py b/.venv/lib/python3.11/site-packages/torch/_export/passes/remove_runtime_assertions.py new file mode 100644 index 0000000000000000000000000000000000000000..a80b62d2765a87b0e20dc7614c6a353c86225d81 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/passes/remove_runtime_assertions.py @@ -0,0 +1,27 @@ +# mypy: allow-untyped-defs +import torch +from torch.fx.passes.infra.pass_base import PassBase, PassResult + + +class _RemoveRuntimeAssertionsPass(PassBase): + """ + Remove runtime assertions inserted by the + _AddRuntimeAssertionsForInlineConstraintsPass. 
class _RemoveRuntimeAssertionsPass(PassBase):
    """
    Remove runtime assertions inserted by the
    _AddRuntimeAssertionsForInlineConstraintsPass.
    """

    def call(self, graph_module) -> PassResult:
        modified = False
        for module in graph_module.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue
            for node in module.graph.nodes:
                if node.target == torch.ops.aten._assert_async.msg:
                    assert_async_node = node
                    # Only unused assertion nodes can be erased.
                    if len(assert_async_node.users) > 0:
                        continue
                    module.graph.erase_node(assert_async_node)
                    # the upstream scalar_tensor <- {le, ge} <- sym_size
                    # linear chain of nodes is removed by the
                    # downstream dead code elimination
                    modified = True
        return PassResult(graph_module, modified)


# --- begin vendored file: torch/_export/passes/replace_autocast_with_hop_pass.py ---
# mypy: allow-untyped-defs
from typing import List

import torch
from torch._higher_order_ops.wrap import wrap_with_autocast

from ..utils import node_inline_, nodes_filter, nodes_first, sequential_split
from .replace_with_hop_pass_util import (
    _replace_with_hop_helper,
    _replace_with_hop_pass_helper,
    _sequential_split_and_maybe_inline_subgraphs_helper,
)


def _is_autocast_node(node: torch.fx.Node):
    # True for either the enter or the exit autocast call.
    return (
        node
        and node.op == "call_function"
        and node.target
        in [
            torch.amp.autocast_mode._enter_autocast,
            torch.amp.autocast_mode._exit_autocast,
        ]
    )


def _is_enter_autocast_node(node: torch.fx.Node):
    return (
        node
        and node.op == "call_function"
        and node.target == torch.amp.autocast_mode._enter_autocast
    )


def _is_exit_autocast_node(node: torch.fx.Node):
    return (
        node
        and node.op == "call_function"
        and node.target == torch.amp.autocast_mode._exit_autocast
    )


def _is_autocast_sub_mod(node: torch.fx.Node):
    """
    Check if the first non-placeholder node is `torch.amp.autocast_mode._enter_autocast`.
    """
    if node.op == "call_module":
        assert isinstance(node.target, str)
        subgm = getattr(node.graph.owning_module, node.target)
        first_non_ph = nodes_first(
            subgm.graph.nodes, lambda node: node.op != "placeholder"
        )
        if (
            first_non_ph
            and first_non_ph.op == "call_function"
            and first_non_ph.target == torch.amp.autocast_mode._enter_autocast
        ):
            # TODO: check if current auto-cast type is the same as the args of
            # _enter_autocast. If so, return False, i.e. do not create a submodule.
            return True
    return False


def _check_valid_autocast_block(enter_autocast_node, exit_autocast_node):
    # The exit node must take its matching enter node as argument.
    assert _is_enter_autocast_node(enter_autocast_node)
    assert _is_exit_autocast_node(exit_autocast_node)
    assert exit_autocast_node.args[0] == enter_autocast_node


def _replace_with_hop(node: torch.fx.Node):
    """Replace a call_module node wrapping an autocast region with the
    wrap_with_autocast higher-order op, erasing the enter/exit nodes."""
    assert node.op == "call_module"
    graph: torch.fx.Graph = node.graph
    gm: torch.fx.GraphModule = graph.owning_module
    assert isinstance(node.target, str)
    sub_gm = getattr(gm, node.target)
    sub_graph = sub_gm.graph
    autocast_nodes = nodes_filter(sub_graph.nodes, _is_autocast_node)
    if len(autocast_nodes) > 0:
        assert len(autocast_nodes) > 1  # need at least an enter node and an exit node
        enter_autocast_node = autocast_nodes[0]
        exit_autocast_node = autocast_nodes[-1]
        _check_valid_autocast_block(enter_autocast_node, exit_autocast_node)

        _replace_with_hop_helper(
            node, enter_autocast_node, _is_autocast_node, wrap_with_autocast
        )
        sub_graph.erase_node(exit_autocast_node)
        sub_graph.erase_node(enter_autocast_node)


def _split_autocast(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    split_autocast creates a new graph module that splits the input graph module into multiple submodules
    based on the `_enter_autocast` and `_exit_autocast` nodes. It doesn't mutate the input graph module.

    Nodes between the **outer-most** `_enter_autocast` and `_exit_autocast(_enter_autocast)` are split
    into a submodule. Nested autocast regions are not split.
    `_enter_autocast` and `_exit_autocast(_enter_autocast)` nodes are in the submodule as well.

    Below is an example of splitting. A, B, C, D, E are blocks of non-autocast nodes in the original graph
    module. Nodes marked with the same number are grouped into the same submodule.
    A   # 0
    enter_autocast  # 1
    B   # 1
    exit_autocast   # 1
    C   # 2
    enter_autocast  # 3
    D   # 3
    exit_autocast   # 3
    E   # 4
    """
    # Stack of currently-open enter nodes; only the outermost pair starts a
    # new split boundary.
    enter_autocast_node_stack: List[torch.fx.Node] = []
    first_node_after_outer_most_exit: bool = False

    def node_call_back(node: torch.fx.Node):
        # Returning True tells sequential_split to start a new submodule here.
        nonlocal enter_autocast_node_stack, first_node_after_outer_most_exit
        if first_node_after_outer_most_exit or (
            len(enter_autocast_node_stack) == 0 and _is_enter_autocast_node(node)
        ):
            assert len(enter_autocast_node_stack) == 0
            first_node_after_outer_most_exit = False
        if _is_enter_autocast_node(node):
            enter_autocast_node_stack.append(node)
            return True
        if _is_exit_autocast_node(node):
            assert len(enter_autocast_node_stack) > 0
            last_enter_autocast_node = enter_autocast_node_stack.pop()
            assert node.args[0] == last_enter_autocast_node
            if len(enter_autocast_node_stack) == 0:
                # next node should be in the next submodule since
                # autocast block ends
                first_node_after_outer_most_exit = True
        return False

    return sequential_split(gm, node_call_back)


def _sequential_split_and_maybe_inline_subgraphs(
    gm: torch.fx.GraphModule, graph_signature
):
    """
    Helper function for replace_autocast_with_hop_pass().
    Split the graph module into multiple subgraphs based on the autocast nodes.
    For each subgraph, decides whether to construct a HOO subgraph, or inline the calls
    back into the parent graph module.
    Nodes between `_enter_autocast` and `_exit_autocast(_enter_autocast)` are considered
    as a subgraph.
    """
    need_replacing = any(_is_autocast_node(node) for node in gm.graph.nodes)
    if not need_replacing:
        return gm, graph_signature

    # split_autocast returns a new graph module that could have different output
    # args names. We need to fix the graph signature in `_sequential_split_and_maybe_inline_subgraphs_helper`.
    new_gm = _split_autocast(gm)

    def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
        if _is_autocast_sub_mod(node):
            _replace_with_hop(node)
        else:
            assert node.op == "call_module"
            assert isinstance(node.target, str)
            node_inline_(node)

    return _sequential_split_and_maybe_inline_subgraphs_helper(
        new_gm, graph_signature, _maybe_inline_or_replace_with_hop
    )


def replace_autocast_with_hop_pass(gm: torch.fx.GraphModule, graph_signature):
    """
    Split gm into sub-graph-modules using `sequential_split_and_maybe_inline_subgraphs`, and
    then recursively call itself on each of the submodules.
    """
    return _replace_with_hop_pass_helper(
        gm,
        graph_signature,
        _sequential_split_and_maybe_inline_subgraphs,
    )


# --- begin vendored file: torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py ---
# mypy: allow-untyped-defs
import logging
import operator
from typing import List, Optional, Tuple, Union

import torch
import torch.export._trace
from torch._ops import OpOverload
from torch.ao.quantization.fx._decomposed import (
    dequantize_per_channel,
    dequantize_per_tensor,
    quantize_per_tensor,
)
from torch.ao.quantization.utils import calculate_qmin_qmax
from torch.fx.graph_module import _assign_attr
logging.getLogger(__name__)

# Those values will need to be carried over multiple operators.
# They are module-level because the legacy quantized IR only specifies them once
# (at the first aten.quantize_per_tensor call) and later ops must reuse them.
_INPUT_Q_DTYPE: Optional[Union[torch.dtype, torch.fx.Node]] = None
_SCALE: Optional[Union[float, torch.fx.Node]] = None
_ZERO_POINT: Optional[Union[float, torch.fx.Node]] = None


def int_to_valid_dtype(val: Union[int, torch.dtype]) -> torch.dtype:
    """Map a TorchScript dtype enum value to a standard (non-quantized) torch.dtype.

    Quantized dtypes are normalized to their plain integer counterparts
    (quint8 -> uint8, qint8 -> int8). A ``torch.dtype`` passed in directly is
    returned unchanged.
    """
    from torch._export.converter import _TORCH_ENUM_TO_DTYPE  # No circular import.

    if isinstance(val, torch.dtype):
        return val
    dtype = _TORCH_ENUM_TO_DTYPE[val]
    if dtype == torch.quint8:
        return torch.uint8
    elif dtype == torch.qint8:
        return torch.int8
    return dtype


def fx_enum_to_dtype(gm: torch.fx.GraphModule, val: int) -> torch.fx.Node:
    """Insert a call_function node that resolves a dtype enum to a torch.dtype at runtime."""
    return gm.graph.call_function(int_to_valid_dtype, (val,))


def insert_quantized_node(
    gm: torch.fx.GraphModule,
    val_node: torch.fx.Node,
    scale_node: Union[float, torch.fx.Node],
    zero_point_node: Union[float, torch.fx.Node],
    qmin_node: Union[float, int, torch.fx.Node],
    qmax_node: Union[float, int, torch.fx.Node],
    dtype_node: Union[torch.dtype, torch.fx.Node],
    qscheme: Optional[torch.qscheme],
) -> torch.fx.Node:
    """Insert a quantize_per_tensor call_function node after the current insert point.

    NOTE(review): `qscheme` is currently unused — only per-tensor quantize nodes
    are emitted here regardless of the scheme; confirm whether per-channel
    insertion is intentionally unsupported.
    """
    return gm.graph.call_function(
        quantize_per_tensor,
        (
            val_node,
            scale_node,
            zero_point_node,
            qmin_node,
            qmax_node,
            dtype_node,
        ),
    )


def get_dequantized(
    val: torch.Tensor,
    scale: Union[float, torch.Tensor],
    zero_point: Union[float, torch.Tensor],
    qmin: Union[float, int],
    qmax: Union[float, int],
    dtype: torch.dtype,
    axis: Optional[int],
    qscheme: Optional[torch.qscheme],
) -> torch.Tensor:
    """Eagerly dequantize `val` (an int-repr tensor) using the given parameters.

    Dispatches on `qscheme`: per-tensor-affine and per-channel-affine are
    supported; any other scheme raises RuntimeError.
    """
    if qscheme is torch.per_tensor_affine:
        return dequantize_per_tensor(
            val,
            scale,
            zero_point,
            qmin,
            qmax,
            dtype,
        )
    elif qscheme is torch.per_channel_affine:
        return dequantize_per_channel(
            val,
            scale,
            zero_point,
            axis,
            qmin,
            qmax,
            dtype,
        )
    else:
        raise RuntimeError(f"Unsupported dequantization scheme: {qscheme}")


def insert_dequantized_node(
    gm: torch.fx.GraphModule,
    val_node: 
torch.fx.Node,
    scale_node: Union[float, torch.fx.Node],
    zero_point_node: Union[float, torch.fx.Node],
    qmin_node: Union[float, int, torch.fx.Node],
    qmax_node: Union[float, int, torch.fx.Node],
    dtype_node: Union[torch.dtype, torch.fx.Node],
    axis_node: Optional[Union[int, torch.fx.Node]],
    qscheme: Optional[torch.qscheme],
) -> torch.fx.Node:
    """Insert a dequantize call_function node (graph-level twin of get_dequantized).

    Per-tensor-affine and per-channel-affine schemes are supported; `axis_node`
    is only consumed by the per-channel path. Other schemes raise RuntimeError.
    """
    if qscheme is torch.per_tensor_affine:
        return gm.graph.call_function(
            dequantize_per_tensor,
            (
                val_node,
                scale_node,
                zero_point_node,
                qmin_node,
                qmax_node,
                dtype_node,
            ),
        )
    elif qscheme is torch.per_channel_affine:
        return gm.graph.call_function(
            dequantize_per_channel,
            (
                val_node,
                scale_node,
                zero_point_node,
                axis_node,
                qmin_node,
                qmax_node,
                dtype_node,
            ),
        )
    else:
        raise RuntimeError(f"Unsupported dequantization scheme: {qscheme}")


def get_qmin_qmax(dtype: torch.dtype) -> Tuple[Union[int, float], Union[int, float]]:
    """Return the default (qmin, qmax) range for `dtype` (no custom range, no reduce_range)."""
    return calculate_qmin_qmax(None, None, False, dtype, False)  # type: ignore[arg-type]


def insert_qmin_qmax_node(
    gm: torch.fx.GraphModule, dtype_node: Union[torch.dtype, torch.fx.Node]
) -> Tuple[torch.fx.Node, torch.fx.Node]:
    """Graph-level twin of get_qmin_qmax: compute (qmin, qmax) at runtime.

    Needed because `dtype_node` may itself be an fx.Node, so the range cannot be
    computed at pass time; the tuple result is unpacked via two getitem nodes.
    """
    q_min_max_node = gm.graph.call_function(
        calculate_qmin_qmax, (None, None, False, dtype_node, False)
    )
    qmin_node = gm.graph.call_function(operator.getitem, (q_min_max_node, 0))
    qmax_node = gm.graph.call_function(operator.getitem, (q_min_max_node, 1))
    return qmin_node, qmax_node


def get_script_object(
    gm: torch.nn.Module, node: torch.fx.Node
) -> torch._C.ScriptObject:
    """Resolve a get_attr node to the ScriptObject it references on `gm`.

    `node.target` is a dotted attribute path, walked one component at a time.
    Asserts (rather than raises) if the node is not a get_attr or the resolved
    attribute is not a ScriptObject.
    """
    assert isinstance(node, torch.fx.Node)
    assert node.op == "get_attr"
    attr_name = node.target
    assert isinstance(attr_name, str)

    mod = gm
    for attr in attr_name.split("."):
        mod = getattr(mod, attr)
    assert isinstance(mod, torch._C.ScriptObject)
    return mod


def insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject(
    gm: torch.fx.GraphModule,
    param_node: torch.fx.Node,
) -> Tuple[torch.fx.Node, 
Optional[torch.fx.Node]]: + """Directly inline tensor from a get_attr fx node.""" + mod = get_script_object(gm, param_node) + w_qtensor, b_qtensor = mod.unpack() # type: ignore[attr-defined] + w_attr_name, b_attr_name = ( + f"dequantized_{param_node.target}_w", + f"dequantized_{param_node.target}_b", + ) + return insert_weight_and_bias_get_attr_node( + gm, w_qtensor, b_qtensor, w_attr_name, b_attr_name + ) + + +def insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor( + gm: torch.fx.GraphModule, + get_attr_to_weight_node: torch.fx.Node, + get_attr_to_bias_node: Optional[torch.fx.Node], +) -> Tuple[torch.fx.Node, Optional[torch.fx.Node]]: + assert isinstance(get_attr_to_weight_node.target, str) + w_qtensor = getattr(gm, get_attr_to_weight_node.target) + w_attr_name = f"dequantized_{get_attr_to_weight_node.target}_w" + + if get_attr_to_bias_node is not None: + assert isinstance(get_attr_to_bias_node.target, str) + b_qtensor = getattr(gm, get_attr_to_bias_node.target) + b_attr_name = f"dequantized_{get_attr_to_bias_node.target}_b" + else: + b_qtensor, b_attr_name = None, "" + + return insert_weight_and_bias_get_attr_node( + gm, w_qtensor, b_qtensor, w_attr_name, b_attr_name + ) + + +def insert_weight_and_bias_get_attr_node( + gm: torch.fx.GraphModule, + w_qtensor: torch.Tensor, + b_qtensor: Optional[torch.Tensor], + w_attr_name: str, + b_attr_name: str, +) -> Tuple[torch.fx.Node, Optional[torch.fx.Node]]: + w_tensor = get_tensor_from_qtensor(w_qtensor) + _assign_attr(w_tensor, gm, w_attr_name) + w_tensor_attr = gm.graph.get_attr(w_attr_name) + + if b_qtensor is not None: + b_tensor = get_tensor_from_qtensor(b_qtensor, dequant=False) + _assign_attr(b_tensor, gm, b_attr_name) + b_tensor_attr = gm.graph.get_attr(b_attr_name) + else: + b_tensor_attr = None + + return w_tensor_attr, b_tensor_attr + + +def get_tensor_from_qtensor( + qtensor: torch.Tensor, dequant: bool = True +) -> torch.Tensor: + # Manual conversion because qint8 is not used anymore. 
    # int_repr() strips the quantized wrapper and yields the raw integer storage.
    if qtensor.dtype in [torch.qint8, torch.quint8]:
        tensor = qtensor.int_repr()
    else:
        tensor = qtensor

    # Weights need dequantization with scaling and zero_point adjustment, but
    # bias does not need that.
    if dequant:
        qscheme = qtensor.qscheme()
        if qscheme == torch.per_channel_affine:
            scale, zero_point, axis = (
                qtensor.q_per_channel_scales(),
                qtensor.q_per_channel_zero_points(),
                qtensor.q_per_channel_axis(),
            )
        else:
            # NOTE(review): any non-per-channel scheme falls through to the
            # per-tensor accessors here; get_dequantized will raise for schemes
            # other than per_tensor_affine — confirm that is the intended behavior.
            scale, zero_point, axis = (
                qtensor.q_scale(),  # type: ignore[assignment]
                qtensor.q_zero_point(),  # type: ignore[assignment]
                None,
            )
        dtype = tensor.dtype
        qmin, qmax = get_qmin_qmax(dtype)
        return get_dequantized(
            tensor, scale, zero_point, qmin, qmax, dtype, axis, qscheme
        )
    return tensor


def insert_fused_activation_node(
    gm: torch.fx.GraphModule, opname: str, fx_node: torch.fx.Node
) -> torch.fx.Node:
    """Append an aten.relu node after `fx_node` when `opname` names a fused *_relu op."""
    if opname in ["conv1d_relu", "conv2d_relu", "linear_relu", "add_relu", "mul_relu"]:
        fx_node = gm.graph.call_function(torch.ops.aten.relu, (fx_node,))
    return fx_node


def _conv1d_op_with_squeeze(
    inp: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    groups: int,
) -> torch.Tensor:
    """Emulate conv1d via conv2d, mirroring the legacy quantized-conv1d kernel."""
    # In quantized version, conv1d is emulated using conv2d with squeeze and unsqueeze
    # operations before and after the conv2d operation to match the dimension of weights.
+ # Reference: https://github.com/pytorch/pytorch/blob/eca0cb0fbe84bb0a34fa94afe261bceecd52c436/aten/src/ATen/native/quantized/cpu/qconv.cpp#L1827 # noqa: B950 + s_inp = torch.ops.aten.unsqueeze(inp, 2) + conv1d_res = torch.ops.aten.conv2d( + s_inp, + weight, + bias, + stride, + padding, + dilation, + groups, + ) + uns_conv1d_res = torch.ops.aten.squeeze(conv1d_res, 2) + return uns_conv1d_res + + +def _transform_conv_with_packedparam(gm: torch.fx.GraphModule, node: torch.fx.Node): + """Conv specfic transformation function.""" + assert isinstance(node.target, torch._ops.OpOverload) + opname = node.target._opname + scale_node, zero_point_node = node.args[2], node.args[3] + + op_f = ( + torch.ops.aten.conv2d + if opname in ["conv2d", "conv2d_relu"] + else _conv1d_op_with_squeeze + ) + + inp_node, param_node = node.args[0], node.args[1] + assert isinstance(inp_node, torch.fx.Node) + assert isinstance(param_node, torch.fx.Node) + + if param_node.op == "call_function": + # Using Conv2dPrepackParam from conv_prepack. + # We directly skip the packing call and inline weights and bias. + w_node, b_node = param_node.args[0], param_node.args[1] + assert isinstance(w_node, torch.fx.Node) + assert b_node is None or isinstance(b_node, torch.fx.Node) + ( + param_0, + param_1, + ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor( + gm, w_node, b_node + ) + op_res_node = gm.graph.call_function( + op_f, (inp_node, param_0, param_1, *param_node.args[2:]) + ) + else: + # Using ConvPrepackedParam. 
+ param = get_script_object(gm, param_node) + ( + param_0, + param_1, + ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject( + gm, param_node + ) # type: ignore[assignment] + op_res_node = gm.graph.call_function( + op_f, + ( + inp_node, + param_0, + param_1, + param.stride(), # type: ignore[attr-defined] + param.padding(), # type: ignore[attr-defined] + param.dilation(), # type: ignore[attr-defined] + param.groups(), # type: ignore[attr-defined] + ), + ) + return op_res_node, scale_node, zero_point_node + + +def _transform_linear_with_packedparam(gm: torch.fx.GraphModule, node: torch.fx.Node): + """Linear specfic transformation function.""" + scale_node, zero_point_node = node.args[2], node.args[3] + + inp_node, param_node = node.args[0], node.args[1] + assert isinstance(inp_node, torch.fx.Node) + assert isinstance(param_node, torch.fx.Node) + + if param_node.op == "call_function": + # Using LinearPrepackParam from linear_prepack. + # We directly skip the packing call and inline weights and bias. + w_node, b_node = param_node.args[0], param_node.args[1] + assert isinstance(w_node, torch.fx.Node) + assert b_node is None or isinstance(b_node, torch.fx.Node) + ( + param_0, + param_1, + ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor( + gm, w_node, b_node + ) + op_res_node = gm.graph.call_function( + torch.ops.aten.linear, (inp_node, param_0, param_1, *param_node.args[2:]) + ) + else: + # Using LinearPackedParams. 
        (
            param_0,
            param_1,
        ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject(
            gm, param_node
        )  # type: ignore[assignment]
        op_res_node = gm.graph.call_function(
            torch.ops.aten.linear, (inp_node, param_0, param_1)
        )
    return op_res_node, scale_node, zero_point_node


def _transform_op_where_last_two_arguments_are_scale_and_zero_point(
    gm: torch.fx.GraphModule, node: torch.fx.Node
):
    """
    This transformation function can be used for functions whose last two
    parameters are scale and zero point. Additionally, the function's parameters
    do not need any unpacking.
    """
    # Fused *_relu variants map to the same base aten op; the relu is appended
    # separately by insert_fused_activation_node.
    to_standard_op = {
        "mul": torch.ops.aten.mul,
        "mul_relu": torch.ops.aten.mul,
        "add": torch.ops.aten.add,
        "add_relu": torch.ops.aten.add,
        "softmax": torch.ops.aten.softmax,
        "cat": torch.ops.aten.cat,
        "hardswish": torch.ops.aten.hardswish,
    }

    assert isinstance(node.target, torch._ops.OpOverload)
    opname, args = node.target._opname, node.args
    # Strip the trailing (scale, zero_point) pair; forward the rest unchanged.
    scale_node, zero_point_node = args[-2], args[-1]
    op_res_node = gm.graph.call_function(to_standard_op[opname], tuple(args[:-2]))
    return op_res_node, scale_node, zero_point_node


def _transform_scalar_arithmetic(gm: torch.fx.GraphModule, node: torch.fx.Node):
    """Transform scalar overload for basic arithmetic.

    Scalar overloads carry no scale/zero_point arguments, so the module-level
    _SCALE/_ZERO_POINT recorded from the previous quantized op are reused.
    """
    to_standard_op = {
        "mul": torch.ops.aten.mul.Scalar,
        "add": torch.ops.aten.add.Scalar,
    }
    assert isinstance(node.target, torch._ops.OpOverload)
    opname, args = node.target._opname, node.args
    op_res_node = gm.graph.call_function(to_standard_op[opname], args)
    return op_res_node, _SCALE, _ZERO_POINT


def _transform_prepacked_op(gm: torch.fx.GraphModule, node: torch.fx.Node):
    """
    Transformation for functions under prepacked namespace, where they share
    the same handling logic: the [...]OpContext contains all parameters.
+ """ + assert isinstance(node.target, torch._ops.OpOverload) + opname, args = node.target._opname, node.args + op_f = None + if opname == "conv2d_clamp_run": + op_f = torch.ops.aten.conv2d + elif opname == "linear_clamp_run": + op_f = torch.ops.aten.linear + else: + raise RuntimeError(f"Invalid operator {opname}") + + assert isinstance(args[1], torch.fx.Node) + so = get_script_object(gm, args[1]) + + func_args = [] + func_args += [args[0]] + func_args += so.unpack()[:2] # type: ignore[attr-defined] + if opname == "conv2d_clamp_run": + func_args += torch.ops.prepacked.unpack_prepacked_sizes_conv2d(so)[2:] + + op_res_node = gm.graph.call_function(op_f, tuple(func_args)) + return op_res_node + + +def _transform_batch_norm(gm: torch.fx.GraphModule, node: torch.fx.Node): + args = node.args + scale_node, zero_point_node = args[-2], args[-1] + op_res_node = gm.graph.call_function( + torch.ops.aten.native_batch_norm, (*args[:-3], False, 0.1, args[-3]) + ) + op_res_node = gm.graph.call_function(operator.getitem, (op_res_node, 0)) + return op_res_node, scale_node, zero_point_node + + +def fx_transform_quantized_op_to_standard_op( + gm: torch.fx.GraphModule, node: torch.fx.Node +) -> torch.fx.Node: + global _SCALE, _ZERO_POINT, _INPUT_Q_DTYPE + + assert isinstance(node.target, torch._ops.OpOverload) + opname, overload = node.target._opname, node.target._overloadname + + key = f"{opname}.{overload}" + opname_to_transform_f = { + "conv1d.new": _transform_conv_with_packedparam, + "conv1d_relu.new": _transform_conv_with_packedparam, + "conv1d.default": _transform_conv_with_packedparam, + "conv1d_relu.default": _transform_conv_with_packedparam, + "conv2d.new": _transform_conv_with_packedparam, + "conv2d_relu.new": _transform_conv_with_packedparam, + "conv2d.default": _transform_conv_with_packedparam, + "conv2d_relu.default": _transform_conv_with_packedparam, + "linear.default": _transform_linear_with_packedparam, + "linear_relu.default": _transform_linear_with_packedparam, + 
"add.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "add_relu.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "mul.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "mul_relu.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "softmax.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "cat.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "hardswish.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "batch_norm2d.default": _transform_batch_norm, + "mul.Scalar": _transform_scalar_arithmetic, + "add.Scalar": _transform_scalar_arithmetic, + } + + if f"{key}" not in opname_to_transform_f: + raise RuntimeError(f"Unsupported quantized op during transformation: {key}") + + op_res_node, scale_node, zero_point_node = opname_to_transform_f[f"{key}"](gm, node) + + # Add fused activation layer. + op_res_node = insert_fused_activation_node(gm, opname, op_res_node) + _SCALE, _ZERO_POINT = scale_node, zero_point_node + + assert _INPUT_Q_DTYPE is not None + qmin_node, qmax_node = insert_qmin_qmax_node(gm, _INPUT_Q_DTYPE) + q_fx_node = insert_quantized_node( + gm, + op_res_node, + scale_node, + zero_point_node, + qmin_node, + qmax_node, + _INPUT_Q_DTYPE, + torch.per_tensor_affine, + ) + dq_fx_node = insert_dequantized_node( + gm, + q_fx_node, + scale_node, + zero_point_node, + qmin_node, + qmax_node, + _INPUT_Q_DTYPE, + None, + torch.per_tensor_affine, + ) + return dq_fx_node + + +def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule): + """ + Replace legacy quantized ops (aten.quantize_per_tensor, quantized.conv) with + PT2 ops (quantize_decomposed.quantize_per_tensor, aten.conv). 

    Before: x || -> aten.q || -> quantized.conv2d || -> quantized.linear || -> aten.dq || -> y

    After: x || -> qd.q -> qd.dq || -> aten.conv2d -> qd.q -> qd.dq || aten.linear -> qd.q -> qd.dq || -> y

    (qd == quantized_decomposed library, q = quantize, dq = dequantize)
                  ^
                  |
        getattr(w), getattr(b) from Conv2dParamPrepack

    During each iteration, the transformation emits the transformed operator, its quantized output,
    and its dequantized value together. We do this because dequantization needs to use the
    scale and zero point parameters from the quantization to recover the approximate original value. After each
    iteration, the new dequantization node will be used as the input to the next node (e.g., dq2 -> linear).

    For operators like conv2d and linear, their weights and bias are packed in a quantized format in the ScriptObject.
    During the transformation, we unpack those objects, get their dequantized tensor, populate those
    as attributes to the module, and use getattr to access them.

    One exception in the transformation is conv_prepack and linear_prepack. Those calls pack
    weight and bias constant tensors into a ScriptObject, which is then used by subsequent conv2d or linear calls.
    During transformation, we directly skip transforming conv_prepack or linear_prepack. We check whether the ScriptObject passed to
    quantized::conv2d or linear comes from conv_prepack or linear_prepack. If it does, we then inline those parameters
    into the operator by converting them to a getattr fx.node.

    For prepacked::conv2d_clamp_run and prepacked::linear_clamp_run, we directly convert them to aten.conv2d and aten.linear
    without the need of doing de/quantization.

    Three global variables defined are _INPUT_Q_DTYPE, _SCALE, _ZERO_POINT. _INPUT_Q_DTYPE determines the de/quantization
    data type, which is the same across the entire program, but it only shows up in the very first quantization
    call. 
_SCALE and _ZERO_POINT are used only when operators do not have those specified. E.g., mul.Scalar. + """ + + global _INPUT_Q_DTYPE + + quantized = False + + last_quantized_node = None + for node in gm.graph.nodes: + if isinstance(node.target, OpOverload): + with gm.graph.inserting_before(node): + namespace, opname = node.target.namespace, node.target._opname + if namespace == "quantized" and opname not in [ + "conv_prepack", + "linear_prepack", + ]: + quantized = True + fx_node = fx_transform_quantized_op_to_standard_op(gm, node) + node.replace_all_uses_with(fx_node) + last_quantized_node = fx_node + elif namespace == "prepacked": + quantized = True + fx_node = _transform_prepacked_op(gm, node) + node.replace_all_uses_with(fx_node) + last_quantized_node = fx_node + elif namespace == "aten" and opname == "quantize_per_tensor": + inp_node, scale_node, zero_point_node, dtype_node = node.args + dtype_node = fx_enum_to_dtype(gm, dtype_node) + _INPUT_Q_DTYPE = dtype_node + qmin_node, qmax_node = insert_qmin_qmax_node(gm, dtype_node) + q_fx_node = insert_quantized_node( + gm, + inp_node, + scale_node, + zero_point_node, + qmin_node, + qmax_node, + dtype_node, + torch.per_tensor_affine, + ) + dq_fx_node = insert_dequantized_node( + gm, + q_fx_node, + scale_node, + zero_point_node, + qmin_node, + qmax_node, + dtype_node, + None, + torch.per_tensor_affine, + ) + node.replace_all_uses_with(dq_fx_node) + last_quantized_node = dq_fx_node + elif namespace == "aten" and opname == "dequantize": + assert last_quantized_node is not None + node.replace_all_uses_with(last_quantized_node) + else: + last_quantized_node = node + + # Post-processing again to remove legacy ScriptObjects and quantizated tensors + # stored as attributes or in the buffer. This is used to clean up the GraphModule + # to not trigger tracing errors like missing __obj_flatten__ functions. 
+ def _clean_attr(mod: torch.nn.Module): + for submod in mod.modules(): + attr_names_to_clean = set() + for k, v in submod.__dict__.items(): + if isinstance(v, torch.ScriptObject): + attr_names_to_clean.add(k) + if k == "_buffers": + buffer_name_to_clean = set() + for b_name, b_value in v.items(): + if isinstance(b_value, torch.Tensor) and b_value.dtype in [ + torch.qint8, + torch.quint8, + ]: + buffer_name_to_clean.add(b_name) + for b_name in buffer_name_to_clean: + v.pop(b_name, None) + for attr_name in attr_names_to_clean: + delattr(submod, attr_name) + + if quantized: + """ + TODO: SetAttr + quantized ops will result incorrect program. This flag is used to temporarily + bypass test cases. + + The deadcode elimination pass is needed to remove legacy quantized ops. Otherwise, retracing + will throw errors. However, the current way of SetAttr does inplace update to attributes, so + this pass regard them as dead code and remove them. Below is an example of GraphModule before + and after the dead code elimination pass. 
+ + class GraphModule(torch.nn.Module): + def forward(self, x_1): + # No stacktrace found for following nodes + data = self.data; data = None + data_1 = self.data + add_tensor = torch.ops.aten.add.Tensor(data_1, x_1, alpha = 1); data_1 = None + data_2 = self.data + copy_ = torch_Tensor_copy_(data_2, add_tensor); data_2 = add_tensor = copy_ = None + data_3 = self.data + add_tensor_1 = torch.ops.aten.add.Tensor(x_1, data_3, alpha = 1); x_1 = data_3 = None + return add_tensor_1 + + class GraphModule(torch.nn.Module): + def forward(self, x_1): + # No stacktrace found for following nodes + data_3 = self.data + add_tensor_1 = torch.ops.aten.add.Tensor(x_1, data_3, alpha = 1); x_1 = data_3 = None + return add_tensor_1 + """ + gm.graph.eliminate_dead_code() + _clean_attr(gm) diff --git a/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py b/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..96104a83ce8b67b62d5aac1fca11cea395ccf2d5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py @@ -0,0 +1,110 @@ +# mypy: allow-untyped-defs + +import torch +from torch._higher_order_ops.wrap import wrap_with_set_grad_enabled + +from ..utils import node_inline_, nodes_filter, nodes_first, nodes_map, sequential_split +from .replace_with_hop_pass_util import ( + _replace_with_hop_helper, + _replace_with_hop_pass_helper, + _sequential_split_and_maybe_inline_subgraphs_helper, +) + + +def _is_set_grad_enabled_node(node: torch.fx.Node): + return ( + node + and node.op == "call_function" + and node.target == torch._C._set_grad_enabled + ) + + +def _is_set_grad_enabled_sub_mod(node: torch.fx.Node, omit_if_same_with_ambient=False): + if node.op == "call_module": + assert isinstance(node.target, str) + subgm = getattr(node.graph.owning_module, node.target) + first_non_ph = nodes_first( + 
            subgm.graph.nodes, lambda node: node.op != "placeholder"
        )
        # A submodule is a set_grad region iff its first real node is the
        # _set_grad_enabled call that sequential_split placed at its head.
        if (
            first_non_ph
            and first_non_ph.op == "call_function"
            and first_non_ph.target == torch._C._set_grad_enabled
        ):
            return (
                # Optionally treat a region whose grad mode equals the ambient
                # grad mode as "not a sub mod" so it gets inlined instead.
                first_non_ph.args[0] != torch.is_grad_enabled()
                if omit_if_same_with_ambient
                else True
            )
    return False


def _replace_with_hop(node: torch.fx.Node):
    """Replace a call_module node whose submodule starts with _set_grad_enabled
    with a wrap_with_set_grad_enabled higher-order op, erasing the set_grad node
    from the subgraph afterwards."""
    assert node.op == "call_module"
    graph: torch.fx.Graph = node.graph
    gm: torch.fx.GraphModule = graph.owning_module
    assert isinstance(node.target, str)
    sub_gm = getattr(gm, node.target)
    sub_graph = sub_gm.graph
    set_grad_nodes = nodes_filter(sub_graph.nodes, _is_set_grad_enabled_node)
    if len(set_grad_nodes) > 0:
        # sequential_split guarantees at most one set_grad node per submodule.
        assert len(set_grad_nodes) == 1
        set_grad_node = set_grad_nodes[0]
        _replace_with_hop_helper(
            node, set_grad_node, _is_set_grad_enabled_node, wrap_with_set_grad_enabled
        )
        sub_graph.erase_node(set_grad_node)


def _remove_set_grad_and_inline(node: torch.fx.Node):
    """Drop all _set_grad_enabled calls from the submodule, then inline its body
    back into the parent graph (used when the region matches the ambient grad mode)."""
    assert node.op == "call_module"
    graph: torch.fx.Graph = node.graph
    gm: torch.fx.GraphModule = graph.owning_module
    assert isinstance(node.target, str)
    sub_gm = getattr(gm, node.target)
    sub_graph = sub_gm.graph
    nodes_map(
        sub_graph.nodes,
        lambda n: sub_graph.erase_node(n) if _is_set_grad_enabled_node(n) else n,
    )
    node_inline_(node)


def _sequential_split_and_maybe_inline_subgraphs(
    gm: torch.fx.GraphModule, graph_signature
):
    """
    Helper function for replace_set_grad_with_hop_pass().
    Split the graph module into multiple subgraphs based on the set_grad_enabled nodes.
    For each subgraph, decides whether to construct a HOO subgraph, or inline the calls
    back into the parent graph module.
    """
    need_replacing = any(_is_set_grad_enabled_node(node) for node in gm.graph.nodes)
    if not need_replacing:
        return gm, graph_signature

    # sequential_split returns a new graph module that could have different output
    # args names. We need to fix the graph signature.
+ new_gm = sequential_split(gm, _is_set_grad_enabled_node) + + def _maybe_inline_or_replace_with_hop(node: torch.fx.Node): + if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True): + _replace_with_hop(node) + else: + _remove_set_grad_and_inline(node) + + return _sequential_split_and_maybe_inline_subgraphs_helper( + new_gm, graph_signature, _maybe_inline_or_replace_with_hop + ) + + +def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule, graph_signature): + """ + Split gm into sub-graph-modules using `sequential_split_and_maybe_inline_subgraphs`, and + then recursively call itself on each of the submodules. + """ + return _replace_with_hop_pass_helper( + gm, + graph_signature, + _sequential_split_and_maybe_inline_subgraphs, + ) diff --git a/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py b/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..6723ac5f86a6cbf703e886318ee44d5ebfc2e13f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py @@ -0,0 +1,65 @@ +# mypy: allow-untyped-defs +from typing import Dict, Optional +import torch +from torch._ops import OpOverload, HigherOrderOperator +from torch._export.error import InternalError +from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse + + +__all__ = ["ReplaceViewOpsWithViewCopyOpsPass"] + + +_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: Dict[OpOverload, OpOverload] = { + torch.ops.aten._unsafe_view.default: torch.ops.aten.view_copy.default, +} + + +def is_view_op(schema: torch._C.FunctionSchema) -> bool: + if len(schema.arguments) == 0: + return False + alias_info = schema.arguments[0].alias_info + return (alias_info is not None) and (not alias_info.is_write) + + +def get_view_copy_of_view_op(schema: torch._C.FunctionSchema) -> Optional[OpOverload]: + if 
is_view_op(schema) and schema.name.startswith("aten::"):
        # "aten::view" -> "view"; missing overload name means the default overload.
        view_op_name = schema.name.split("::")[1]
        view_op_overload = (
            schema.overload_name
            if schema.overload_name != ""
            else "default"
        )
        # By convention the functional variant is the op name with a "_copy" suffix
        # (e.g. aten.view -> aten.view_copy), with matching overloads.
        view_copy_op_name = view_op_name + "_copy"
        if not hasattr(torch.ops.aten, view_copy_op_name):
            raise InternalError(f"{schema.name} is missing a view_copy variant")

        view_copy_op_overload_packet = getattr(torch.ops.aten, view_copy_op_name)

        if not hasattr(view_copy_op_overload_packet, view_op_overload):
            raise InternalError(f"{schema.name} is missing a view_copy variant")

        return getattr(view_copy_op_overload_packet, view_op_overload)

    return None


class ReplaceViewOpsWithViewCopyOpsPass(_ExportPassBaseDeprecatedDoNotUse):
    """
    Our backend expects pure functional operators. For efficiency
    purposes, we keep view ops around while functionalizing the exported
    program. This pass replaces view ops with view copy ops for backends that
    need AOT memory planning.
    """
    def call_operator(self, op, args, kwargs, meta):
        # Explicitly-mapped non-functional ops take precedence over schema-based lookup.
        if op in _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS:
            return super().call_operator(
                (_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS[op]), args, kwargs, meta
            )

        # Higher-order ops have no FunctionSchema to inspect; pass them through.
        if isinstance(op, HigherOrderOperator):
            return super().call_operator(op, args, kwargs, meta)

        if view_copy_op := get_view_copy_of_view_op(op._schema):
            return super().call_operator(view_copy_op, args, kwargs, meta)

        return super().call_operator(op, args, kwargs, meta)
diff --git a/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_with_hop_pass_util.py b/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_with_hop_pass_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2ca55025bd0a8e0abbe59b28012131bddc1e51f
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_with_hop_pass_util.py
@@ -0,0 +1,178 @@
# mypy: allow-untyped-defs

import contextlib
import copy
import operator
from typing 
def _replace_with_hop_helper(
    node: torch.fx.Node,
    enter_block_node: torch.fx.Node,
    node_filter: Callable,
    wrap_hoo: HigherOrderOperator,
):
    """
    Replace a ``call_module`` node with a call to the higher-order operator
    ``wrap_hoo``, wiring the submodule in via a ``get_attr`` node.

    The submodule referenced by ``node.target`` becomes an operand of
    ``wrap_hoo``; metadata (``nn_module_stack``, ``torch_fn``, ``val``) is
    copied from ``enter_block_node`` / the submodule outputs onto the new
    call. If the submodule returns nothing, ``node`` is simply erased.

    NOTE(review): ``node_filter`` is accepted but never used in this body —
    confirm whether callers rely on it before removing.
    """
    graph: torch.fx.Graph = node.graph
    gm: torch.fx.GraphModule = graph.owning_module
    assert isinstance(node.target, str)
    sub_gm = getattr(gm, node.target)

    def set_hoo_node_meta(call_func_node):
        # Propagate module-stack/function metadata from the block-entry node,
        # and derive "val" from the submodule's output node(s).
        call_func_node.meta["nn_module_stack"] = copy.copy(
            enter_block_node.meta.get("nn_module_stack", {})
        )
        call_func_node.meta["torch_fn"] = (
            f"{wrap_hoo.__name__}",
            f"{wrap_hoo.__class__.__name__}.{wrap_hoo.__name__}",
        )
        # `output_args` is bound in the enclosing scope before this closure
        # is ever invoked.
        if isinstance(output_args, (tuple, list)):
            call_func_node.meta["val"] = tuple(arg.meta["val"] for arg in output_args)
        elif isinstance(output_args, torch.fx.Node):
            call_func_node.meta["val"] = (output_args.meta["val"],)

    with graph.inserting_before(node):
        get_attr_node = graph.get_attr(node.target)
        get_attr_node.meta["nn_module_stack"] = copy.copy(
            enter_block_node.meta.get("nn_module_stack", {})
        )
        output_node = next(iter(reversed(sub_gm.graph.nodes)), None)
        # Split_module pass intentially doesn't add output node
        # if the graph doesn't return anything.
        # TODO (tmanlaibaatar) Figure out if this is right behaviour
        # for split_module
        if isinstance(output_node, torch.fx.Node) and output_node.op != "output":
            output_node = None
        if output_node is not None:
            assert len(output_node.args) == 1
            output_args = output_node.args[0]
            enter_block_node_args = enter_block_node.args
            if isinstance(output_args, (tuple, list)):
                call_func_node = graph.call_function(
                    wrap_hoo,
                    (*enter_block_node_args, get_attr_node, *node.args),
                    {},
                )
                # Create the metadata
                set_hoo_node_meta(call_func_node)
                node_replace_(node, call_func_node)

                # Rename the name of getitem nodes to the actual name of its contents
                # for passing verifier and better readability, also propagate metadata
                for get_item_node in call_func_node.users.keys():
                    idx: int = get_item_node.args[1]  # type: ignore[assignment]
                    output_node = output_args[idx]
                    get_item_node._rename(output_node.name)
                    get_item_node.meta = output_node.meta

            elif isinstance(output_args, torch.fx.Node):
                call_func_node = graph.create_node(
                    "call_function",
                    wrap_hoo,
                    (*enter_block_node_args, get_attr_node, *node.args),
                    {},
                    output_args.name,
                )
                # Modify the subgraph to output a singleton list.
                output_node.args = ((output_args,),)
                # Add in an extra `getitem(wrap_hoo, 0)` node to the toplevel graph.
                get_item_node = graph.create_node(
                    "call_function",
                    operator.getitem,
                    (call_func_node, 0),
                    {},
                )
                # Create the metadata
                get_item_node.meta = output_args.meta
                set_hoo_node_meta(call_func_node)
                node_replace_(node, get_item_node)
            else:
                # Fixed typos in the error message ("repalce" / "doesnt'").
                raise NotImplementedError(
                    f"replace_with_hop_pass doesn't support output type {type(output_args)}"
                )
        else:
            # TODO (shangdiy): remove this line, since the export graph can be non-functional
            node.graph.erase_node(node)


def _sequential_split_and_maybe_inline_subgraphs_helper(
    new_gm: torch.fx.GraphModule,
    graph_signature,
    maybe_inline_or_replace_with_hop: Callable[[torch.fx.Node], None],
):
    """
    Helper function for replacing graph nodes with higher order nodes.
    For each subgraph in `new_gm`, decides whether to construct a HOO subgraph, or inline the calls
    back into the parent graph module, depending on `maybe_inline_or_replace_with_hop`.

    Returns the (recompiled) `new_gm` and the adjusted signature copy
    (``None`` when `graph_signature` is ``None``).
    """
    # new_gm is a new graph module that could have different output args names.
    # We need to fix the graph signature.
    replace_ctx = contextlib.nullcontext()
    new_signature = None
    if graph_signature is not None:
        # Cannot deep copy a real ScriptObject, which is referenced
        # in the FakeScriptObject. Copy should be good enough to guard
        # against accidental mutation to original graph_signature.
        new_signature = copy.copy(graph_signature)
        new_gm_out_node = next(reversed(new_gm.graph.find_nodes(op="output")))
        assert new_gm_out_node.op == "output" and len(new_gm_out_node.args[0]) == len(
            new_signature.output_specs
        )
        # Re-sync output spec names with the (possibly renamed) output nodes.
        for arg_node, out_spec in zip(
            new_gm_out_node.args[0], new_signature.output_specs
        ):
            if arg_node is None:
                assert out_spec.arg.value is None
            elif (
                isinstance(arg_node, torch.fx.Node)
                and out_spec.arg.name != arg_node.name
            ):
                out_spec.arg.name = arg_node.name

        replace_ctx = new_gm._set_replace_hook(new_signature.get_replace_hook())  # type: ignore[assignment]

    with replace_ctx:
        nodes_map(
            list(new_gm.graph.nodes),
            lambda node: (
                maybe_inline_or_replace_with_hop(node)
                if node.op == "call_module"
                else node
            ),
        )
    new_gm.recompile()
    return new_gm, new_signature


def _replace_with_hop_pass_helper(
    gm: torch.fx.GraphModule,
    graph_signature,
    sequential_split_and_maybe_inline_subgraphs: Callable,
):
    """
    Split gm into sub-graph-modules using `sequential_split_and_maybe_inline_subgraphs`, and
    then recursively call itself on each of the submodules.

    Submodule recursion passes ``None`` as the signature (only the toplevel
    has one); returns the transformed graph module and the new signature.
    """
    new_gm, new_signature = sequential_split_and_maybe_inline_subgraphs(
        gm, graph_signature
    )
    # recursively call
    for node in new_gm.graph.nodes:
        if node.op == "get_attr":
            subgm = getattr(new_gm, node.target)
            if not isinstance(subgm, torch.fx.GraphModule):
                continue
            new_subgm, _ = _replace_with_hop_pass_helper(
                subgm,
                None,
                sequential_split_and_maybe_inline_subgraphs,
            )
            setattr(new_gm, node.target, new_subgm)

    new_gm.recompile()
    new_gm.graph.lint()
    return new_gm, new_signature
+ +from dataclasses import dataclass, field +from enum import IntEnum +from typing import Dict, List, Optional, Tuple + +from torch._export.serde.union import _Union + +# NOTE: Please update this value if any modifications are made to the schema +SCHEMA_VERSION = (7, 3) +TREESPEC_VERSION = 1 + + +class ScalarType(IntEnum): + UNKNOWN = 0 + BYTE = 1 + CHAR = 2 + SHORT = 3 + INT = 4 + LONG = 5 + HALF = 6 + FLOAT = 7 + DOUBLE = 8 + COMPLEXHALF = 9 + COMPLEXFLOAT = 10 + COMPLEXDOUBLE = 11 + BOOL = 12 + BFLOAT16 = 13 + + +class Layout(IntEnum): + Unknown = 0 + SparseCoo = 1 + SparseCsr = 2 + SparseCsc = 3 + SparseBsr = 4 + SparseBsc = 5 + _mkldnn = 6 + Strided = 7 + + +class MemoryFormat(IntEnum): + Unknown = 0 + ContiguousFormat = 1 + ChannelsLast = 2 + ChannelsLast3d = 3 + PreserveFormat = 4 + + +@dataclass +class Device: + type: str + index: Optional[int] = None + + +@dataclass(repr=False) +class SymExprHint(_Union): + as_int: int + as_float: float + as_bool: bool + + +# This is for storing the symbolic expressions behind symints/symfloats/symbools +# For example, we can get something like +# SymExpr(expr_str="s0 + s1", hint=SymExprHint(as_int=4) +# if we also have the hint that s0 and s1 are both 2. +@dataclass +class SymExpr: + expr_str: str + hint: Optional[SymExprHint] = None + + +@dataclass(repr=False) +class SymInt(_Union): + as_expr: SymExpr + as_int: int + + +@dataclass(repr=False) +class SymBool(_Union): + as_expr: SymExpr + as_bool: bool + + +@dataclass +class TensorMeta: + dtype: ScalarType + sizes: List[SymInt] + requires_grad: bool + device: Device + strides: List[SymInt] + storage_offset: SymInt + layout: Layout + + +# In most cases we will use the "as_name" field to store arguments which are +# SymInts. +# The "as_int" field is used in the case where we have a list containing a mix +# of SymInt and ints (ex. [1, s0, ...]). 
We will serialize this type of list to +# be List[SymIntArgument] and map the SymInts to the "as_name" field, and ints +# to the "as_int" field. +@dataclass(repr=False) +class SymIntArgument(_Union): + as_name: str + as_int: int + + +# In most cases we will use the "as_name" field to store arguments which are +# SymBools. +# The "as_bool" field is used in the case where we have a list containing a mix +# of SymBool and bools (ex. [True, i0, ...]). We will serialize this type of list to +# be List[SymboolArgument] and map the SymBools to the "as_name" field, and bools +# to the "as_bool" field. +@dataclass(repr=False) +class SymBoolArgument(_Union): + as_name: str + as_bool: bool + + +@dataclass +class TensorArgument: + name: str + + +@dataclass +class TokenArgument: + name: str + + +# This is use for storing the contents of a list which contain optional tensors +# (Tensor?[], ex. [Tensor, None, ...]), where the list will be serialized to the +# type List[OptionalTensorArgument], with tensor values seiralized to the +# "as_tensor" field, and None values serialized to the "as_none" field. 
+@dataclass(repr=False) +class OptionalTensorArgument(_Union): + as_tensor: TensorArgument + as_none: Tuple[()] + + +@dataclass +class GraphArgument: + name: str + graph: 'Graph' + + +@dataclass +class CustomObjArgument: + name: str + class_fqn: str + + +# This is actually a union type +@dataclass(repr=False) +class Argument(_Union): + as_none: Tuple[()] + as_tensor: TensorArgument + as_tensors: List[TensorArgument] + as_int: int + as_ints: List[int] + as_float: float + as_floats: List[float] + as_string: str + as_strings: List[str] + as_sym_int: SymIntArgument + as_sym_ints: List[SymIntArgument] + as_scalar_type: ScalarType + as_memory_format: MemoryFormat + as_layout: Layout + as_device: Device + as_bool: bool + as_bools: List[bool] + as_sym_bool: SymBoolArgument + as_sym_bools: List[SymBoolArgument] + as_graph: GraphArgument + as_optional_tensors: List[OptionalTensorArgument] + as_custom_obj: CustomObjArgument + as_operator: str + + +@dataclass +class NamedArgument: + # Argument name from the operator schema + name: str + arg: Argument + + +@dataclass +class Node: + target: str + inputs: List[NamedArgument] + outputs: List[Argument] + metadata: Dict[str, str] + + +@dataclass +class Graph: + inputs: List[Argument] + outputs: List[Argument] + nodes: List[Node] + tensor_values: Dict[str, TensorMeta] + sym_int_values: Dict[str, SymInt] + sym_bool_values: Dict[str, SymBool] + # This is for deserializing the submodule graphs from higher order ops + # (ex. cond, map) where single tensor returns will just return a single + # tensor, rather than following export schema and returning a singleton + # list. 
+ is_single_tensor_return: bool = False + custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict) + + +@dataclass +class UserInputSpec: + # Actually, only tensors and SymInts are allowed here + arg: Argument + + +@dataclass(repr=False) +class ConstantValue(_Union): + as_none: Tuple[()] + as_int: int + as_float: float + as_string: str + as_bool: bool + + +@dataclass +class ConstantInputSpec: + name: str + value: ConstantValue + + +@dataclass +class InputToParameterSpec: + arg: TensorArgument + parameter_name: str + + +@dataclass +class InputToBufferSpec: + arg: TensorArgument + buffer_name: str + persistent: bool + + + +@dataclass +class InputToTensorConstantSpec: + arg: TensorArgument + tensor_constant_name: str + + +@dataclass +class InputToCustomObjSpec: + arg: CustomObjArgument + custom_obj_name: str + + +@dataclass +class InputTokenSpec: + arg: TokenArgument + + +@dataclass(repr=False) +class InputSpec(_Union): + user_input: UserInputSpec + parameter: InputToParameterSpec + buffer: InputToBufferSpec + tensor_constant: InputToTensorConstantSpec + custom_obj: InputToCustomObjSpec + token: InputTokenSpec + constant_input: ConstantInputSpec + + +@dataclass +class UserOutputSpec: + arg: Argument + + +@dataclass +class LossOutputSpec: + arg: TensorArgument + + +@dataclass +class BufferMutationSpec: + arg: TensorArgument + buffer_name: str + + +@dataclass +class GradientToParameterSpec: + arg: TensorArgument + parameter_name: str + + +@dataclass +class GradientToUserInputSpec: + arg: TensorArgument + user_input_name: str + + +@dataclass +class UserInputMutationSpec: + arg: TensorArgument + user_input_name: str + + +@dataclass +class OutputTokenSpec: + arg: TokenArgument + + +@dataclass(repr=False) +class OutputSpec(_Union): + user_output: UserOutputSpec + loss_output: LossOutputSpec + buffer_mutation: BufferMutationSpec + gradient_to_parameter: GradientToParameterSpec + gradient_to_user_input: GradientToUserInputSpec + user_input_mutation: 
UserInputMutationSpec + token: OutputTokenSpec + + +@dataclass +class GraphSignature: + input_specs: List[InputSpec] + output_specs: List[OutputSpec] + + +@dataclass +class RangeConstraint: + min_val: int + max_val: int + + +@dataclass +class ModuleCallSignature: + inputs: List[Argument] + outputs: List[Argument] + + # These are serialized by calling pytree.treespec_loads + # And deserialized by calling pytree.treespec_dumps + in_spec: str + out_spec: str + + +@dataclass +class ModuleCallEntry: + fqn: str + signature: Optional[ModuleCallSignature] = None + + +@dataclass +class GraphModule: + graph: Graph + signature: GraphSignature + # This is used for unflattening, by tracking the calling structure of all of + # the modules in order to unflatten the modules back to the eager calling + # conventions. + module_call_graph: List[ModuleCallEntry] + metadata: Dict[str, str] = field(default_factory=dict) + + +# Invariant: Every time a change is made to the schema, one of the versions +# should be upadted. +@dataclass +class SchemaVersion: + major: int # Major version number is bumped every time a breaking change is made. + minor: int # Minor version number is bumped when a compatible change is made. + + +@dataclass +class ExportedProgram: + graph_module: GraphModule + # Key is the opset namespace (ex. 
aten), and value is the version number + opset_version: Dict[str, int] + range_constraints: Dict[str, RangeConstraint] + schema_version: SchemaVersion + verifiers: List[str] = field(default_factory=list) + torch_version: str = "<=2.4" diff --git a/.venv/lib/python3.11/site-packages/torch/_export/serde/schema.yaml b/.venv/lib/python3.11/site-packages/torch/_export/serde/schema.yaml new file mode 100644 index 0000000000000000000000000000000000000000..25a9a295ad0b97967885cbb715cdfbb553dfb4eb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/serde/schema.yaml @@ -0,0 +1,437 @@ +# @generated by update_schema.py +# checksum<<923abf371a1f8802cacb037d409d28273867777a98f6542fba28616c2b92b639>> +Argument: + kind: union + fields: + as_none: + type: Tuple[()] + as_tensor: + type: TensorArgument + as_tensors: + type: List[TensorArgument] + as_int: + type: int + as_ints: + type: List[int] + as_float: + type: float + as_floats: + type: List[float] + as_string: + type: str + as_strings: + type: List[str] + as_sym_int: + type: SymIntArgument + as_sym_ints: + type: List[SymIntArgument] + as_scalar_type: + type: ScalarType + as_memory_format: + type: MemoryFormat + as_layout: + type: Layout + as_device: + type: Device + as_bool: + type: bool + as_bools: + type: List[bool] + as_sym_bool: + type: SymBoolArgument + as_sym_bools: + type: List[SymBoolArgument] + as_graph: + type: GraphArgument + as_optional_tensors: + type: List[OptionalTensorArgument] + as_custom_obj: + type: CustomObjArgument + as_operator: + type: str +BufferMutationSpec: + kind: struct + fields: + arg: + type: TensorArgument + buffer_name: + type: str +ConstantInputSpec: + kind: struct + fields: + name: + type: str + value: + type: ConstantValue +ConstantValue: + kind: union + fields: + as_none: + type: Tuple[()] + as_int: + type: int + as_float: + type: float + as_string: + type: str + as_bool: + type: bool +CustomObjArgument: + kind: struct + fields: + name: + type: str + class_fqn: + type: str 
+Device: + kind: struct + fields: + type: + type: str + index: + type: Optional[int] + default: None +ExportedProgram: + kind: struct + fields: + graph_module: + type: GraphModule + opset_version: + type: Dict[str, int] + range_constraints: + type: Dict[str, RangeConstraint] + schema_version: + type: SchemaVersion + verifiers: + type: List[str] + default: '[]' + torch_version: + type: str + default: <=2.4 +GradientToParameterSpec: + kind: struct + fields: + arg: + type: TensorArgument + parameter_name: + type: str +GradientToUserInputSpec: + kind: struct + fields: + arg: + type: TensorArgument + user_input_name: + type: str +Graph: + kind: struct + fields: + inputs: + type: List[Argument] + outputs: + type: List[Argument] + nodes: + type: List[Node] + tensor_values: + type: Dict[str, TensorMeta] + sym_int_values: + type: Dict[str, SymInt] + sym_bool_values: + type: Dict[str, SymBool] + is_single_tensor_return: + type: bool + default: 'False' + custom_obj_values: + type: Dict[str, CustomObjArgument] + default: '{}' +GraphArgument: + kind: struct + fields: + name: + type: str + graph: + type: Graph +GraphModule: + kind: struct + fields: + graph: + type: Graph + signature: + type: GraphSignature + module_call_graph: + type: List[ModuleCallEntry] + metadata: + type: Dict[str, str] + default: '{}' +GraphSignature: + kind: struct + fields: + input_specs: + type: List[InputSpec] + output_specs: + type: List[OutputSpec] +InputSpec: + kind: union + fields: + user_input: + type: UserInputSpec + parameter: + type: InputToParameterSpec + buffer: + type: InputToBufferSpec + tensor_constant: + type: InputToTensorConstantSpec + custom_obj: + type: InputToCustomObjSpec + token: + type: InputTokenSpec + constant_input: + type: ConstantInputSpec +InputToBufferSpec: + kind: struct + fields: + arg: + type: TensorArgument + buffer_name: + type: str + persistent: + type: bool +InputToCustomObjSpec: + kind: struct + fields: + arg: + type: CustomObjArgument + custom_obj_name: + type: str 
+InputToParameterSpec: + kind: struct + fields: + arg: + type: TensorArgument + parameter_name: + type: str +InputToTensorConstantSpec: + kind: struct + fields: + arg: + type: TensorArgument + tensor_constant_name: + type: str +InputTokenSpec: + kind: struct + fields: + arg: + type: TokenArgument +Layout: + kind: enum + fields: + Unknown: 0 + SparseCoo: 1 + SparseCsr: 2 + SparseCsc: 3 + SparseBsr: 4 + SparseBsc: 5 + _mkldnn: 6 + Strided: 7 +LossOutputSpec: + kind: struct + fields: + arg: + type: TensorArgument +MemoryFormat: + kind: enum + fields: + Unknown: 0 + ContiguousFormat: 1 + ChannelsLast: 2 + ChannelsLast3d: 3 + PreserveFormat: 4 +ModuleCallEntry: + kind: struct + fields: + fqn: + type: str + signature: + type: Optional[ModuleCallSignature] + default: None +ModuleCallSignature: + kind: struct + fields: + inputs: + type: List[Argument] + outputs: + type: List[Argument] + in_spec: + type: str + out_spec: + type: str +NamedArgument: + kind: struct + fields: + name: + type: str + arg: + type: Argument +Node: + kind: struct + fields: + target: + type: str + inputs: + type: List[NamedArgument] + outputs: + type: List[Argument] + metadata: + type: Dict[str, str] +OptionalTensorArgument: + kind: union + fields: + as_tensor: + type: TensorArgument + as_none: + type: Tuple[()] +OutputSpec: + kind: union + fields: + user_output: + type: UserOutputSpec + loss_output: + type: LossOutputSpec + buffer_mutation: + type: BufferMutationSpec + gradient_to_parameter: + type: GradientToParameterSpec + gradient_to_user_input: + type: GradientToUserInputSpec + user_input_mutation: + type: UserInputMutationSpec + token: + type: OutputTokenSpec +OutputTokenSpec: + kind: struct + fields: + arg: + type: TokenArgument +RangeConstraint: + kind: struct + fields: + min_val: + type: int + max_val: + type: int +ScalarType: + kind: enum + fields: + UNKNOWN: 0 + BYTE: 1 + CHAR: 2 + SHORT: 3 + INT: 4 + LONG: 5 + HALF: 6 + FLOAT: 7 + DOUBLE: 8 + COMPLEXHALF: 9 + COMPLEXFLOAT: 10 + 
COMPLEXDOUBLE: 11 + BOOL: 12 + BFLOAT16: 13 +SchemaVersion: + kind: struct + fields: + major: + type: int + minor: + type: int +SymBool: + kind: union + fields: + as_expr: + type: SymExpr + as_bool: + type: bool +SymBoolArgument: + kind: union + fields: + as_name: + type: str + as_bool: + type: bool +SymExpr: + kind: struct + fields: + expr_str: + type: str + hint: + type: Optional[SymExprHint] + default: None +SymExprHint: + kind: union + fields: + as_int: + type: int + as_float: + type: float + as_bool: + type: bool +SymInt: + kind: union + fields: + as_expr: + type: SymExpr + as_int: + type: int +SymIntArgument: + kind: union + fields: + as_name: + type: str + as_int: + type: int +TensorArgument: + kind: struct + fields: + name: + type: str +TensorMeta: + kind: struct + fields: + dtype: + type: ScalarType + sizes: + type: List[SymInt] + requires_grad: + type: bool + device: + type: Device + strides: + type: List[SymInt] + storage_offset: + type: SymInt + layout: + type: Layout +TokenArgument: + kind: struct + fields: + name: + type: str +UserInputMutationSpec: + kind: struct + fields: + arg: + type: TensorArgument + user_input_name: + type: str +UserInputSpec: + kind: struct + fields: + arg: + type: Argument +UserOutputSpec: + kind: struct + fields: + arg: + type: Argument +SCHEMA_VERSION: +- 7 +- 3 +TREESPEC_VERSION: 1 diff --git a/.venv/lib/python3.11/site-packages/torch/_export/serde/schema_check.py b/.venv/lib/python3.11/site-packages/torch/_export/serde/schema_check.py new file mode 100644 index 0000000000000000000000000000000000000000..b22b9778819e73635aa6d37a254aa4b643abd5f5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/serde/schema_check.py @@ -0,0 +1,286 @@ +# mypy: allow-untyped-defs +import dataclasses +import hashlib +import re +import typing +from enum import IntEnum +from typing import Any, Dict, Optional, Union + +from torch._export.serde import schema +from torch._export.serde.union import _Union + + +class 
class SchemaUpdateError(Exception):
    """Raised when a staged schema change violates update rules."""
    pass


def _check(x, msg):
    # Assert-like helper that raises SchemaUpdateError instead of AssertionError.
    if not x:
        raise SchemaUpdateError(msg)


def _staged_schema():
    """
    Build the dict form of the current `schema.py` contents: a mapping from
    type name to {"kind": enum|struct|union, "fields": ...}, plus the
    SCHEMA_VERSION / TREESPEC_VERSION entries.
    """
    ret: Dict[str, Any] = {}
    defs = {}

    def _handle_aggregate(ty):
        def dump_type(t):
            # Render a typing annotation into the schema's string form.
            if isinstance(t, type):
                return t.__name__
            elif isinstance(t, str):
                # Forward reference: must name another schema definition.
                assert t in defs
                return t
            elif o := typing.get_origin(t):
                # Lemme know if there's a better way to do this.
                if o == list:
                    head = "List"
                elif o == dict:
                    head = "Dict"
                elif o == tuple:
                    if typing.get_args(t) == ():
                        return "Tuple[()]"
                    head = "Tuple"
                elif o == Union:
                    args = typing.get_args(t)
                    # Only Optional[X] (i.e. Union[X, None]) is supported.
                    assert len(args) == 2 and args[1] == type(None)
                    return f"Optional[{dump_type(args[0])}]"
                else:
                    raise AssertionError(f"Type {t} is not supported in export schema.")
                return (
                    f"{head}[{', '.join([dump_type(x) for x in typing.get_args(t)])}]"
                )
            elif t == ():
                return "()"
            else:
                raise AssertionError(f"Type {t} is not supported in export schema.")

        def dump_field(f):
            t = dump_type(f.type)
            ret = {"type": t}

            value = dataclasses.MISSING
            if f.default is not dataclasses.MISSING:
                value = f.default
            elif f.default_factory is not dataclasses.MISSING:
                value = f.default_factory()

            # Optional fields must default to None so old readers can skip them.
            if t.startswith("Optional[") and value is not None:
                raise AssertionError(
                    f"Optional field {ty.__name__}.{f.name} must have default value to be None."
                )

            if value is not dataclasses.MISSING:
                default = str(value)
                ret["default"] = default
            return ret

        return {f.name: dump_field(f) for f in dataclasses.fields(ty)}

    def _handle_int_enum(name, ty):
        ret[name] = {"kind": "enum", "fields": {x.name: x.value for x in ty}}

    def _handle_struct(name, ty):
        ret[name] = {"kind": "struct", "fields": _handle_aggregate(ty)}

    def _handle_union(name, ty):
        ret[name] = {"kind": "union", "fields": _handle_aggregate(ty)}

    # Collect every public name defined directly in the schema module.
    for name in dir(schema):
        if name.startswith("_"):
            continue

        value = getattr(schema, name)

        if hasattr(value, "__module__") and value.__module__ != schema.__name__:
            continue

        defs[name] = value

    for name, value in defs.items():
        if isinstance(value, type):
            if issubclass(value, IntEnum):
                _handle_int_enum(name, value)
            elif dataclasses.is_dataclass(value):
                if issubclass(value, _Union):
                    _handle_union(name, value)
                else:
                    _handle_struct(name, value)
            else:
                raise AssertionError(f"Unknown schema type {name}: {value}")
        elif isinstance(value, (int, tuple)):
            assert name in ("SCHEMA_VERSION", "TREESPEC_VERSION")
        else:
            raise AssertionError(f"Unknown variable {name}: {value}")

    ret["SCHEMA_VERSION"] = list(defs["SCHEMA_VERSION"])
    assert all(x > 0 for x in ret["SCHEMA_VERSION"])
    ret["TREESPEC_VERSION"] = defs["TREESPEC_VERSION"]
    assert ret["TREESPEC_VERSION"] > 0
    return ret


def _diff_schema(dst, src):
    """
    Diff two schema dicts, returning (additions, subtractions): entries/fields
    present only in `src` / only in `dst`. Raises SchemaUpdateError (via
    `_check`) on incompatible changes such as kind or type mismatches.
    """
    additions = {key: src[key] for key in src.keys() - dst.keys()}
    subtractions = {key: dst[key] for key in dst.keys() - src.keys()}

    common_keys = src.keys() & dst.keys()

    # Version entries are compared elsewhere, not field-by-field.
    versions = {"SCHEMA_VERSION", "TREESPEC_VERSION"}
    common_keys -= versions

    for key in common_keys:
        src_kind = src[key]["kind"]
        src_fields = src[key]["fields"]
        dst_kind = dst[key]["kind"]
        dst_fields = dst[key]["fields"]
        _check(
            src_kind == dst_kind,
            f"Type {key} changed kind from {dst_kind} to {src_kind}",
        )
        assert isinstance(src_fields, dict) and isinstance(dst_fields, dict)
        added_fields = {
            key: src_fields[key] for key in src_fields.keys() - dst_fields.keys()
        }
        subtracted_fields = {
            key: dst_fields[key] for key in dst_fields.keys() - src_fields.keys()
        }
        common_fields = src_fields.keys() & dst_fields.keys()

        for field in common_fields:
            src_field = src_fields[field]
            dst_field = dst_fields[field]
            if src_kind == "struct":
                _check(
                    src_field["type"] == dst_field["type"],
                    f"Type of the field {key}.{field} changed from {dst_field['type']} to {src_field['type']}",
                )
                # A default gained/lost on an existing field is tracked as an
                # addition/subtraction of just that default.
                if "default" in src_field and "default" not in dst_field:
                    added_fields[field] = {}
                    added_fields[field]["default"] = src_field["default"]
                if "default" not in src_field and "default" in dst_field:
                    subtracted_fields[field] = {}
                    subtracted_fields[field]["default"] = dst_field["default"]
            elif src_kind == "enum":
                _check(
                    src_field == dst_field,
                    f"Value of the enum field {key}.{field} changed from {dst_field} to {src_field}",
                )
            elif src_kind == "union":
                _check(
                    src_field["type"] == dst_field["type"],
                    f"Type of the field {key}.{field} changed from {dst_field['type']} to {src_field['type']}",
                )
            else:
                raise AssertionError(f"Unknown kind {src_kind}: {key}")
        if len(added_fields) > 0:
            assert key not in additions
            additions[key] = {}
            additions[key]["fields"] = added_fields
        if len(subtracted_fields) > 0:
            assert key not in subtractions
            subtractions[key] = {}
            subtractions[key]["fields"] = subtracted_fields

    return additions, subtractions


def _hash_schema(s):
    # Stable content hash of the schema dict's repr.
    return hashlib.sha256(repr(s).encode("utf-8")).hexdigest()


@dataclasses.dataclass
class _Commit:
    # Staged schema update: new schema (`result`), on-disk baseline (`base`),
    # their diff, and checksums for both sides.
    result: Dict[str, Any]
    checksum_result: str
    path: str
    additions: Dict[str, Any]
    subtractions: Dict[str, Any]
    base: Dict[str, Any]
    checksum_base: Optional[str]


def update_schema():
    """
    Stage an update of schema.yaml from the current schema.py, returning a
    `_Commit` describing the change (no files are written here).
    """
    import importlib.resources

    if importlib.resources.is_resource(__package__, "schema.yaml"):
        content = importlib.resources.read_text(__package__, "schema.yaml")
        match = re.search("checksum<<([A-Fa-f0-9]{64})>>", content)
        _check(match is not None, "checksum not found in schema.yaml")
        assert match is not None
        checksum_base = match.group(1)
        from yaml import load, Loader

        dst = load(content, Loader=Loader)
        assert isinstance(dst, dict)
    else:
        # First-time generation: no baseline yaml exists yet.
        checksum_base = None
        dst = {"SCHEMA_VERSION": None, "TREESPEC_VERSION": None}

    src = _staged_schema()
    additions, subtractions = _diff_schema(dst, src)
    return _Commit(
        result=src,
        checksum_result=_hash_schema(src),
        path=__package__.replace(".", "/") + "/schema.yaml",
        additions=additions,
        subtractions=subtractions,
        base=dst,
        checksum_base=checksum_base,
    )


def check(commit: _Commit, force_unsafe: bool = False):
    """
    Decide the version bump a staged commit requires.

    Returns (next_version, reason): `next_version` is None when no bump is
    needed, `[major+1, 1]` for breaking changes (field removed, or struct
    field added without a default), `[major, minor+1]` for compatible ones.
    With `force_unsafe`, the staged version is accepted as-is.
    """
    next_version = None
    reason = ""
    # Step 1: Detect major schema updates.
    if len(commit.additions) > 0:
        for k, v in commit.additions.items():
            if k not in commit.base:
                continue
            kind = commit.result[k]["kind"]
            fields = v["fields"]
            for f, d in fields.items():
                if "default" not in d and kind == "struct":
                    reason += (
                        f"Field {k}.{f} is added to schema.py without a default value as an incompatible change "
                        + "which requires major version bump.\n"
                    )
                    next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1]

    if len(commit.subtractions) > 0:
        for k, v in commit.subtractions.items():
            if k not in commit.result:
                continue
            for f in v["fields"]:
                # BUGFIX: accumulate with += (was `reason =`, which dropped
                # all but the last removed field from the report).
                reason += f"Field {k}.{f} is removed from schema.py as an incompatible change which requires major version bump.\n"
                next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1]

    if force_unsafe:
        reason += "--force-unsafe is used."
        next_version = commit.result["SCHEMA_VERSION"]
    else:
        # Step 2: Detect minor schema updates.
        if next_version is None and len(commit.additions) > 0:
            for k, v in commit.additions.items():
                for f in v["fields"]:
                    reason += (
                        f"Field {k}.{f} is added to schema.py as a compatible change "
                        + "which still requires minor version bump.\n"
                    )
            next_version = [
                commit.base["SCHEMA_VERSION"][0],
                commit.base["SCHEMA_VERSION"][1] + 1,
            ]
        if next_version is None and len(commit.subtractions) > 0:
            for k, v in commit.subtractions.items():
                for f in v["fields"]:
                    reason += (
                        f"Field {k}.{f} is removed from schema.py as a compatible change "
                        + "which still requires minor version bump.\n"
                    )
            next_version = [
                commit.base["SCHEMA_VERSION"][0],
                commit.base["SCHEMA_VERSION"][1] + 1,
            ]

    return next_version, reason
pytree +from torch.utils._pytree import treespec_dumps, treespec_loads +from torch.utils._sympy.numbers import int_oo +from torch.utils._sympy.value_ranges import ValueRanges + +from .schema import ( # type: ignore[attr-defined] + Argument, + BufferMutationSpec, + ConstantInputSpec, + ConstantValue, + CustomObjArgument, + Device, + ExportedProgram, + GradientToParameterSpec, + GradientToUserInputSpec, + Graph, + GraphArgument, + GraphModule, + GraphSignature, + InputSpec, + InputToBufferSpec, + InputToCustomObjSpec, + InputTokenSpec, + InputToParameterSpec, + InputToTensorConstantSpec, + Layout, + LossOutputSpec, + MemoryFormat, + ModuleCallEntry, + ModuleCallSignature, + NamedArgument, + Node, + OptionalTensorArgument, + OutputSpec, + OutputTokenSpec, + RangeConstraint, + ScalarType, + SCHEMA_VERSION, + SymBool, + SymBoolArgument, + SymExpr, + SymExprHint, + SymInt, + SymIntArgument, + TensorArgument, + TensorMeta, + TokenArgument, + TREESPEC_VERSION, + UserInputMutationSpec, + UserInputSpec, + UserOutputSpec, +) +from .union import _Union +from ..utils import remove_proxy_from_state_dict + +__all__ = [ + "serialize", + "GraphModuleSerializer", + "ExportedProgramSerializer", + "GraphModuleDeserializer", + "ExportedProgramDeserializer", +] + +log = logging.getLogger(__name__) + + +class SerializeError(RuntimeError): + pass + + +def _reverse_map(d: Dict[Any, Enum]): + return {v.value: k for k, v in d.items()} + + +MetaType = Union[ + FakeTensor, int, torch.SymInt, bool, torch.SymBool, ep.CustomObjArgument +] + + +ST_DELIMITER = ";" + +_TORCH_TO_SERIALIZE_DTYPE = { + torch.uint8: ScalarType.BYTE, + torch.int8: ScalarType.CHAR, + torch.int16: ScalarType.SHORT, + torch.int32: ScalarType.INT, + torch.int64: ScalarType.LONG, + torch.float16: ScalarType.HALF, + torch.float32: ScalarType.FLOAT, + torch.float64: ScalarType.DOUBLE, + torch.complex32: ScalarType.COMPLEXHALF, + torch.complex64: ScalarType.COMPLEXFLOAT, + torch.complex128: ScalarType.COMPLEXDOUBLE, + 
torch.bool: ScalarType.BOOL, + torch.bfloat16: ScalarType.BFLOAT16, +} + + +_SERIALIZE_TO_TORCH_DTYPE = _reverse_map(_TORCH_TO_SERIALIZE_DTYPE) # type: ignore[arg-type] + + +_TORCH_TO_SERIALIZE_LAYOUT = { + torch.sparse_coo: Layout.SparseCoo, + torch.sparse_csr: Layout.SparseCsr, + torch.sparse_csc: Layout.SparseCsc, + torch.sparse_bsr: Layout.SparseBsr, + torch.sparse_bsc: Layout.SparseBsc, + torch._mkldnn: Layout._mkldnn, # type: ignore[attr-defined] + torch.strided: Layout.Strided, +} + + +_SERIALIZE_TO_TORCH_LAYOUT = _reverse_map(_TORCH_TO_SERIALIZE_LAYOUT) # type: ignore[arg-type] + + +_TORCH_TO_SERIALIZE_MEMORY_FORMAT = { + torch.contiguous_format: MemoryFormat.ContiguousFormat, + torch.channels_last: MemoryFormat.ChannelsLast, + torch.channels_last_3d: MemoryFormat.ChannelsLast3d, + torch.preserve_format: MemoryFormat.PreserveFormat, +} + + +_SERIALIZE_TO_TORCH_MEMORY_FORMAT = _reverse_map(_TORCH_TO_SERIALIZE_MEMORY_FORMAT) # type: ignore[arg-type] + + +_SYM_INT_OPS = { + operator.mul, + operator.add, + operator.sub, + operator.floordiv, + operator.mod, + operator.pow, + torch.sym_int, + torch.sym_float, + torch.sym_ite, + torch.sym_max, + torch.sym_min, + torch.sym_sqrt, +} + + +_SYM_BOOL_OPS = { + operator.eq, + operator.ne, + operator.le, + operator.ge, + operator.lt, + operator.gt, + torch.sym_not, +} + + +assert not any(isinstance(op, torch._ops.OpOverload) for op in _SYM_INT_OPS) +assert not any(isinstance(op, torch._ops.OpOverload) for op in _SYM_BOOL_OPS) + + +@dataclass +class SerializedArtifact: + exported_program: bytes + state_dict: bytes + constants: bytes + example_inputs: bytes + + +@dataclass +class _SerializedProgram: + exported_program: ExportedProgram + state_dict: bytes + constants: bytes + example_inputs: bytes + + +def deserialize_device(d: Device) -> torch.device: + if d.index is None: + return torch.device(type=d.type) # type: ignore[call-overload] + return torch.device(type=d.type, index=d.index) + + +def serialize_sym_int(s: 
Union[int, torch.SymInt]) -> SymInt: + if isinstance(s, (torch.SymInt, sympy.Symbol, int)): + if symbolic_shapes.is_concrete_int(s): + return SymInt.create(as_int=int(s)) + else: + assert isinstance(s, (torch.SymInt, sympy.Symbol)) + if s.node.hint is None: + return SymInt.create(as_expr=SymExpr(str(s))) + else: + return SymInt.create( + as_expr=SymExpr(str(s), hint=SymExprHint.create(as_int=s.node.hint)) + ) + else: + raise SerializeError( + f"SymInt should be either symbol or int, got `{s}` of type `{type(s)}`" + ) + + +def serialize_sym_bool(s: Union[bool, torch.SymBool]) -> SymBool: + if isinstance(s, (torch.SymBool, bool)): + if symbolic_shapes.is_concrete_bool(s): + return SymBool.create(as_bool=bool(s)) + else: + return SymBool.create(as_expr=SymExpr(expr_str=str(s))) + else: + raise SerializeError( + f"SymBool should be either symbol or bool, got `{s}` of type `{type(s)}`" + ) + + +def serialize_tensor_meta(t: torch.Tensor) -> TensorMeta: + """ + Extract a TensorMeta describing `t`. + """ + return TensorMeta( + dtype=_TORCH_TO_SERIALIZE_DTYPE[t.dtype], + sizes=[serialize_sym_int(s) for s in t.shape], + requires_grad=t.requires_grad, + device=Device(type=t.device.type, index=t.device.index), + strides=[serialize_sym_int(s) for s in t.stride()], + storage_offset=serialize_sym_int(0), # TODO needs to be fixed. 
+ layout=_TORCH_TO_SERIALIZE_LAYOUT[t.layout], + ) + + +_CURRENT_DESERIALIZER: Optional["GraphModuleDeserializer"] = None + + +def _reduce_fake_tensor(fake_tensor: FakeTensor): + is_parameter = isinstance(fake_tensor, torch.nn.Parameter) + tensor_meta = serialize_tensor_meta(fake_tensor) + tensor_meta_bytes = json.dumps( + _dataclass_to_dict(tensor_meta), cls=EnumEncoder + ).encode("utf-8") + return _reconstruct_fake_tensor, (tensor_meta_bytes, is_parameter) + + +def _reconstruct_fake_tensor( + serialized_tensor_meta: bytes, is_parameter: bool +) -> FakeTensor: + # Deserialize the bytes into a TensorMeta + json_tensor_meta = json.loads(serialized_tensor_meta.decode("utf-8")) + tensor_meta = _dict_to_dataclass(TensorMeta, json_tensor_meta) + # Find the current fake mode + assert ( + _CURRENT_DESERIALIZER is not None + ), "Need access to current deserializer state" + fake_tensor = _CURRENT_DESERIALIZER.deserialize_tensor_meta(tensor_meta) + if is_parameter: + fake_tensor = torch.nn.Parameter(fake_tensor) # type: ignore[assignment] + return fake_tensor + + +def serialize_torch_artifact(artifact: Optional[Any]) -> bytes: + if artifact is None: + return b"" + + assert ( + FakeTensor not in copyreg.dispatch_table + ), "Refusing to stomp on existing FakeTensor reducer" + try: + copyreg.pickle(FakeTensor, _reduce_fake_tensor) + buffer = io.BytesIO() + # This is a workaround for backend's tensor deserialization problem: + # unpickleTensor() always create a tensor on the device where it was originally saved + # This behavior is bad for multi-gpu training, as we wish to directly load the tensor + # on the designated device. + # For now, we simply move the tensor to cpu before saving. + # TODO: this should be fixed by deserialization instead. 
+ torch.save(artifact, buffer) + return buffer.getvalue() + finally: + del copyreg.dispatch_table[FakeTensor] + + +def deserialize_torch_artifact(serialized: Union[Dict[str, Any], Tuple[Any, ...], bytes]): + if isinstance(serialized, (dict, tuple)): + return serialized + if len(serialized) == 0: + return {} + buffer = io.BytesIO(serialized) + buffer.seek(0) + # weights_only=False as we want to load custom objects here (e.g. ScriptObject) + artifact = torch.load(buffer, weights_only=False) + assert isinstance(artifact, (tuple, dict)) + return artifact + + +def _sympy_int_to_int(val: sympy.Expr, adjust: str): + # Convert simple sympy Integers into concrete int + if val in (sympy.oo, int_oo): + return math.inf + if val in (-sympy.oo, -int_oo): + return -math.inf + if isinstance(val, sympy.Integer): + return int(val) + + # TODO: Remove this adjustment when Ed gets rid of fractional ranges + log.warning( + "Export constraints cannot be non-integer expressions. Found " + "type %s, and value %s. 
We will attempt to %s " + "this value.", type(val), val, adjust + ) + + if adjust == "floor": + return math.floor(val) + elif adjust == "ceil": + return math.ceil(val) + else: + raise RuntimeError(f"Got invalid adjustment {adjust}") + + +def _int_to_sympy_int(val) -> sympy.Expr: + # Convert concrete int into simple sympy Integers + if val == math.inf: + return int_oo + if val == -math.inf: + return -int_oo + return sympy.Integer(val) + + +def serialize_range_constraints( + range_constraints: Dict[sympy.Symbol, ValueRanges] +) -> Dict[str, RangeConstraint]: + return { + str(k): RangeConstraint( + _sympy_int_to_int(v.lower, "ceil"), # type: ignore[arg-type] + _sympy_int_to_int(v.upper, "floor"), # type: ignore[arg-type] + ) + for k, v in range_constraints.items() + } + + +def _get_schema_from_target(target): + if isinstance(target, torch._ops.OpOverload): + return target._schema + elif type(target) in _serialization_registry: + return _serialization_registry[type(target)].op_schema(target) + raise RuntimeError(f"Cannot find schema for {type(target)}") + + +def _is_single_tensor_return(target) -> bool: + schema = _get_schema_from_target(target) + returns = schema.returns + return len(returns) == 1 and isinstance(returns[0].real_type, torch.TensorType) + + +def _is_single_tensor_list_return(target: Any) -> bool: + schema = _get_schema_from_target(target) + returns = schema.returns + + if len(returns) != 1: + return False + return_type = returns[0].real_type + return isinstance(return_type, torch.ListType) and isinstance( + return_type.getElementType(), torch.TensorType + ) + + +@dataclass +class GraphState: + inputs: List[Argument] = field(default_factory=list) + outputs: List[Argument] = field(default_factory=list) + nodes: List[Node] = field(default_factory=list) + tensor_values: Dict[str, TensorMeta] = field(default_factory=dict) + sym_int_values: Dict[str, SymInt] = field(default_factory=dict) + sym_bool_values: Dict[str, SymBool] = field(default_factory=dict) + 
is_single_tensor_return: bool = False + custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict) + + +class Final(type): + def __new__(metacls, name, bases, classdict): + for b in bases: + if isinstance(b, Final): + raise TypeError(f"type '{b.__name__}' is not an acceptable base type") + return type.__new__(metacls, name, bases, dict(classdict)) + + +@final +class GraphModuleSerializer(metaclass=Final): + def __init__( + self, + graph_signature: ep.ExportGraphSignature, + module_call_graph: List[ep.ModuleCallEntry], + ): + self.graph_state = GraphState() + self.graph_signature = graph_signature + self.module_call_graph = module_call_graph + self.custom_objs: Dict[str, torch._C.ScriptObject] = {} + self.duplicate_getitem_nodes: Dict[str, str] = {} + + @contextmanager + def save_graph_state(self): + saved = self.graph_state + self.graph_state = GraphState() + try: + yield + finally: + self.graph_state = saved + + def handle_placeholder(self, node: torch.fx.Node): + assert node.op == "placeholder" + if isinstance(node.meta["val"], torch.Tensor): + graph_input = Argument.create(as_tensor=TensorArgument(name=node.name)) + self.graph_state.tensor_values[node.name] = serialize_tensor_meta( + node.meta["val"] + ) + elif isinstance(node.meta["val"], torch.SymInt): + raise AssertionError("SymInt graph input is not implemented yet.") + elif isinstance(node.meta["val"], (int, bool, str, float, type(None))): + graph_input = self.serialize_input(node.meta["val"]) + elif isinstance(node.meta["val"], ep.CustomObjArgument): + class_fqn = node.meta["val"].class_fqn + graph_input = Argument.create( + as_custom_obj=CustomObjArgument(name=node.name, class_fqn=class_fqn) + ) + self.graph_state.custom_obj_values[node.name] = ( + self.serialize_script_obj_meta(node.meta["val"]) + ) + else: + raise AssertionError(f"Unimplemented graph input type: {node.meta['val']}") + self.graph_state.inputs.append(graph_input) + + def handle_output(self, node: torch.fx.Node): + 
assert node.op == "output" + assert len(node.args) == 1, "FX.Node's args should have one arg" + node_args = node.args[0] + if isinstance(node_args, torch.fx.Node): + # For singleton tensor returns + self.graph_state.is_single_tensor_return = True + self.graph_state.outputs = [self.serialize_input(node_args)] + else: + assert isinstance(node_args, (tuple, list)) + self.graph_state.outputs = [self.serialize_input(arg) for arg in node_args] + + def serialize_operator(self, target) -> str: + if isinstance(target, str): + return target + elif target.__module__.startswith("torch._ops"): + # TODO(zhxchen17) Maybe provide a function name helper in FX. + # From torch.fx.node._get_qualified_name + module = target.__module__.replace("torch._ops", "torch.ops") + return f"{module}.{target.__name__}" + else: # TODO(zhxchen17) Don't catch all here. + return f"{target.__module__}.{target.__name__}" + + def handle_call_function(self, node: torch.fx.Node): + assert node.op == "call_function" + + # getitem has been handled in the producer node, skip it here + if node.target is operator.getitem: + return + + meta_val = node.meta.get("val") + if ( + node.target in _SYM_INT_OPS + or node.target in _SYM_BOOL_OPS + or (meta_val is not None and isinstance(meta_val, (torch.SymInt, torch.SymBool))) + ): + assert len(node.kwargs) == 0 + ex_node = Node( + target=self.serialize_operator(node.target), + inputs=self.serialize_sym_op_inputs(node.target, node.args), + outputs=[ + Argument.create( + as_sym_int=self.serialize_sym_int_output(node.name, meta_val) + ) + if (node.target in _SYM_INT_OPS or isinstance(meta_val, torch.SymInt)) + else Argument.create( + as_sym_bool=self.serialize_sym_bool_output(node.name, meta_val) + ) + ], + metadata=self.serialize_metadata(node), + ) + elif isinstance(node.target, torch._ops.OpOverload): + ex_node = Node( + target=self.serialize_operator(node.target), + inputs=self.serialize_inputs(node.target, node.args, node.kwargs), + 
outputs=self.serialize_outputs(node), + # TODO: create a new tensor_values here, meta might have faketensor info + metadata=self.serialize_metadata(node), + ) + elif isinstance(node.target, torch._ops.HigherOrderOperator): + ex_node = Node( + target=self.serialize_operator(node.target), + inputs=self.serialize_hoo_inputs(node.args, node.kwargs), + outputs=self.serialize_hoo_outputs(node), + metadata=self.serialize_metadata(node), + ) + elif type(node.target) in _serialization_registry: + # Sanity check for unhandled serialization. + assert type(node.target) in _serialization_registry, f"{type(node.target)} is not supported in export serialization." + + handler = _serialization_registry[type(node.target)] + namespace = handler.namespace() + op_name = handler.to_op_name(node.target) + assert isinstance(namespace, str) and isinstance(op_name, str) + assert ":" not in namespace and ":" not in op_name + ex_node = Node( + target=f"#{namespace}:{op_name}", + inputs=self.serialize_inputs(node.target, node.args, node.kwargs), + outputs=self.serialize_outputs(node), + metadata=self.serialize_metadata(node), + ) + else: + raise SerializeError(f"Serializing {node.target} is not supported") + + self.graph_state.nodes.append(ex_node) + + def handle_get_attr(self, node): + pass + + def _output_node_at_index(self, node, index) -> Optional[torch.fx.Node]: + user_node = None + for user in node.users: + assert user.target is operator.getitem, f"{user} is not a getitem node" + if index == user.args[1]: + if user_node is None: + user_node = user + else: + # We want to deduplicate getitem nodes that are trying to + # index to the same index + self.duplicate_getitem_nodes[user.name] = user_node.name + return user_node + + def _output_node_name_at_index(self, node, index) -> str: + user_node = self._output_node_at_index(node, index) + if user_node is None: + return f"{node.name}_unused_{index}" + else: + return user_node.name + + def serialize_metadata(self, node: torch.fx.Node) -> 
Dict[str, str]: + ret = {} + if stack_trace := node.meta.get("stack_trace"): + ret["stack_trace"] = stack_trace + + if nn_module_stack := node.meta.get("nn_module_stack"): + + def export_nn_module_stack(val): + assert isinstance(val, tuple) and len(val) == 2 + path, ty = val + + assert isinstance(path, str) + assert isinstance(ty, str) + + return path + "," + ty + + # Serialize to "key,orig_path,type_str" + nn_module_list = [ + f"{k},{export_nn_module_stack(v)}" for k, v in nn_module_stack.items() + ] + ret["nn_module_stack"] = ST_DELIMITER.join(nn_module_list) + + if source_fn_st := node.meta.get("source_fn_stack"): + source_fn_list = [ + f"{source_fn[0]},{self.serialize_operator(source_fn[1])}" + for source_fn in source_fn_st + ] + ret["source_fn_stack"] = ST_DELIMITER.join(source_fn_list) + + if torch_fn := node.meta.get("torch_fn"): + ret["torch_fn"] = ST_DELIMITER.join(list(torch_fn)) + + if custom := node.meta.get("custom"): + try: + ret["custom"] = json.dumps(custom) + except Exception as e: + raise SerializeError( + f"Failed to serialize custom metadata for node {node.name} with error {e}" + ) from e + + return ret + + def serialize_script_obj_meta( + self, script_obj_meta: ep.CustomObjArgument + ) -> CustomObjArgument: + return CustomObjArgument( + name=script_obj_meta.name, + class_fqn=script_obj_meta.class_fqn, + ) + + def serialize_sym_op_inputs(self, op, args) -> List[NamedArgument]: + if isinstance(op, torch._ops.OpOverload): + args_names = [arg.name for arg in op._schema.arguments] + else: + assert op in _SYM_INT_OPS or op in _SYM_BOOL_OPS + args_names = list(inspect.signature(op).parameters.keys()) + serialized_args = [] + for args_name, arg in zip(args_names, args): + serialized_args.append( + NamedArgument(name=args_name, arg=self.serialize_input(arg)) + ) + return serialized_args + + def serialize_inputs( + self, + target: Any, # torch._ops.OpOverload and other custom operator types. 
+ args, + kwargs=None + ) -> List[NamedArgument]: + assert isinstance(target, (torch._ops.OpOverload, *_registered_extension_types())) + kwargs = kwargs or {} + serialized_args = [] + + schema = _get_schema_from_target(target) + + for i, schema_arg in enumerate(schema.arguments): + if schema_arg.name in kwargs: + serialized_args.append( + NamedArgument( + name=schema_arg.name, + arg=self.serialize_input(kwargs[schema_arg.name], schema_arg.type), + ) + ) + elif not schema_arg.kwarg_only and i < len(args): + serialized_args.append( + NamedArgument( + name=schema_arg.name, + arg=self.serialize_input(args[i], schema_arg.type), + ) + ) + else: + # We intentionally don't serialize the missing arguments + # with default values + pass + + return serialized_args + + def serialize_hoo_inputs(self, args, kwargs) -> List[NamedArgument]: + """ + For serializing HOO inputs since HOOs do not have a schema. + """ + inputs = [ + NamedArgument( + name="", + arg=self.serialize_input(a), + ) + for a in args + ] + inputs.extend( + [ + NamedArgument(name=name, arg=self.serialize_input(a)) + for name, a in kwargs.items() + ] + ) + return inputs + + def is_sym_int_arg(self, arg) -> bool: + return isinstance(arg, int) or ( + isinstance(arg, torch.fx.Node) + and arg.name in self.graph_state.sym_int_values + ) + + def is_sym_bool_arg(self, arg) -> bool: + return isinstance(arg, bool) or ( + isinstance(arg, torch.fx.Node) + and arg.name in self.graph_state.sym_bool_values + ) + + def serialize_input( + self, arg, arg_type: Optional[torch._C.Argument] = None + ) -> Argument: + import torch._inductor.ir as inductor_ir + + inductor_tensor_buffers = ( + inductor_ir.Buffer, + inductor_ir.ReinterpretView, + ) + + if isinstance(arg, torch.fx.Node): + if arg.op == "get_attr": + assert isinstance(arg.target, str) + attr = getattr(arg.graph.owning_module, arg.target) + + if isinstance(attr, torch.Tensor): + raise SerializeError( + "getattr nodes containing tensors should not appear in the graph" + ) + 
elif isinstance(attr, torch.fx.GraphModule): + with self.save_graph_state(): + graph = self.serialize_graph(attr) + return Argument.create( + as_graph=GraphArgument(name=arg.target, graph=graph) + ) + else: + raise SerializeError( + f"Unsupported getattr attribute {arg.target} with type: {type(attr)}" + ) + elif self.is_sym_int_arg(arg): + return Argument.create( + as_sym_int=SymIntArgument.create(as_name=arg.name) + ) + elif self.is_sym_bool_arg(arg): + return Argument.create( + as_sym_bool=SymBoolArgument.create(as_name=arg.name) + ) + elif isinstance(arg.meta["val"], ep.CustomObjArgument): + return Argument.create( + as_custom_obj=CustomObjArgument( + name=arg.name, class_fqn=arg.meta["val"].class_fqn + ) + ) + elif arg.name in self.duplicate_getitem_nodes: + dedup_name = self.duplicate_getitem_nodes[arg.name] + return Argument.create(as_tensor=TensorArgument(name=dedup_name)) + else: + return Argument.create(as_tensor=TensorArgument(name=arg.name)) + elif isinstance(arg, inductor_tensor_buffers): + # Other branches are for arguments in fx node. + # This is a special branch for handling buffers (representing tensor arguments) + # for inductor's ExternalFallbackNode + # export_extern_kernel_node() is using this function to serialize arguments + arg_name = arg.get_name() + assert arg_name is not None, "Buffer must have valid name" + return Argument.create(as_tensor=TensorArgument(name=arg_name)) + elif isinstance(arg, torch.SymInt): + # This is a special branch for handling SymInt args in inductor's + # ExternalFallbackNode. 
+ # For regular FX graph, SymInt arg should be a fx.Node with + # self.is_sym_int_arg(arg) being true + return Argument.create(as_sym_int=SymIntArgument.create(as_name=str(arg))) + elif isinstance(arg, bool): + return Argument.create(as_bool=arg) + elif isinstance(arg, str): + return Argument.create(as_string=arg) + elif isinstance(arg, int): + return Argument.create(as_int=arg) + elif isinstance(arg, float): + return Argument.create(as_float=arg) + elif arg is None: + return Argument.create(as_none=()) + elif isinstance(arg, (list, tuple)): + if len(arg) == 0: + if arg_type is not None: + if isinstance(arg_type, torch.OptionalType): + arg_type = arg_type.getElementType() # type: ignore[assignment] + assert isinstance(arg_type, torch.ListType) + elem_type = arg_type.getElementType() + if isinstance(elem_type, torch.OptionalType): + elem_type = elem_type.getElementType() + + if isinstance(elem_type, torch.BoolType): + return Argument.create(as_bools=[]) + elif isinstance(elem_type, torch.IntType): + return Argument.create(as_ints=[]) + elif isinstance(elem_type, torch.FloatType): + return Argument.create(as_floats=[]) + elif isinstance(elem_type, torch.StringType): + return Argument.create(as_strings=[]) + elif isinstance(elem_type, torch.TensorType): + return Argument.create(as_tensors=[]) + else: + # I believe empty symint lists default to ints, but + # please file an issue if this is not the case + raise SerializeError(f"Empty list with type {elem_type} nyi.") + else: + # We could serialize this by default to a tensor list. This + # is needed in the HOO case + log.warning( + "Unsure how to serialize the given empty list, " + "as we don't know what is the type of this argument. " + "Serializing it as a tensor list by default." 
+ ) + return Argument.create(as_tensors=[]) + + # Must check bool first, as bool is also treated as int + if all(isinstance(a, bool) for a in arg): + return Argument.create(as_bools=list(arg)) + elif all(isinstance(a, int) for a in arg): + return Argument.create(as_ints=list(arg)) + elif all(isinstance(a, float) for a in arg): + return Argument.create(as_floats=list(arg)) + elif all(isinstance(a, str) for a in arg): + return Argument.create(as_strings=list(arg)) + elif all(isinstance(a, torch.SymInt) for a in arg): + # This is a special branch for handling SymInt args in inductor's + # ExternalFallbackNode. + # For regular FX graph, SymInt arg should be a fx.Node with + # self.is_sym_int_arg(arg) being true + return Argument.create( + as_sym_ints=[SymIntArgument.create(as_name=str(a)) for a in arg] + ) + elif all(self.is_sym_int_arg(a) for a in arg): + # list of sym_ints + values = [] + for a in arg: + if isinstance(a, torch.fx.Node): + values.append(SymIntArgument.create(as_name=a.name)) + elif isinstance(a, int): + values.append(SymIntArgument.create(as_int=a)) + return Argument.create(as_sym_ints=values) + elif all(self.is_sym_bool_arg(a) for a in arg): + # list of sym_bools + values = [] + for a in arg: + if isinstance(a, torch.fx.Node): + values.append(SymBoolArgument.create(as_name=a.name)) + elif isinstance(a, bool): + values.append(SymBoolArgument.create(as_bool=a)) + return Argument.create(as_sym_bools=values) + elif all(isinstance(a, torch.fx.Node) for a in arg): + # list of tensors + arguments = [] + for a in arg: + if a.op == "get_attr": + raise SerializeError( + "getattr nodes containing tensors should not appear in the graph" + ) + arguments.append(TensorArgument(name=a.name)) + return Argument.create(as_tensors=arguments) + elif all(isinstance(a, (torch.fx.Node, type(None))) for a in arg): + # list of optional tensors + def serialize_optional_tensor_args(a): + if a is None: + return OptionalTensorArgument.create(as_none=()) + elif isinstance(a, 
torch.fx.Node): + return OptionalTensorArgument.create( + as_tensor=TensorArgument(name=a.name) + ) + else: + raise SerializeError(f"Unsupported list/tuple argument: {a}") + + return Argument.create( + as_optional_tensors=list(map(serialize_optional_tensor_args, arg)) + ) + elif all(isinstance(a, inductor_tensor_buffers) for a in arg): + # list of inductor buffers + return Argument.create( + as_tensors=[TensorArgument(name=a.get_name()) for a in arg], + ) + elif all( + isinstance(a, (*inductor_tensor_buffers, type(None))) for a in arg + ): + # list of inductor buffers as optional tensors + def serialize_optional_tensor_args(a): + if a is None: + return OptionalTensorArgument.create(as_none=()) + elif isinstance(a, inductor_tensor_buffers): + return OptionalTensorArgument.create( + as_tensor=TensorArgument(name=a.get_name()) + ) + else: + raise SerializeError(f"Unsupported list/tuple argument: {a}") + + return Argument.create( + as_optional_tensors=list(map(serialize_optional_tensor_args, arg)) + ) + else: + raise SerializeError( + f"Unsupported list/tuple argument type: {[type(a) for a in arg]}" + ) + elif isinstance(arg, torch.dtype): + return Argument.create(as_scalar_type=_TORCH_TO_SERIALIZE_DTYPE[arg]) + elif isinstance(arg, torch.device): + return Argument.create(as_device=Device(type=arg.type, index=arg.index)) + elif isinstance(arg, torch.memory_format): + return Argument.create( + as_memory_format=_TORCH_TO_SERIALIZE_MEMORY_FORMAT[arg] + ) + elif isinstance(arg, torch.layout): + return Argument.create(as_layout=_TORCH_TO_SERIALIZE_LAYOUT[arg]) + elif isinstance(arg, torch._C.ScriptObject): + if not ( + arg._has_method("__getstate__") # type: ignore[attr-defined] + and arg._has_method("__setstate__") # type: ignore[attr-defined] + ): + raise SerializeError( + f"Unable to serialize custom class {arg}. Please define " + "serialization methods via def_pickle()." 
+ ) + # Custom objects through torchind are serializable with pickle, + # through implementing the .def_pickle function. This should result + # in the object containing a __getstate__ and __setstate__ + # serialize/deserialize function. + custom_obj_name = f"_custom_obj_{len(self.custom_objs)}" + self.custom_objs[custom_obj_name] = arg + class_fqn = arg._type().qualified_name() # type: ignore[attr-defined] + return Argument.create( + as_custom_obj=CustomObjArgument(custom_obj_name, class_fqn) + ) + elif isinstance(arg, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + return Argument.create(as_operator=self.serialize_operator(arg)) + else: + raise SerializeError(f"Unsupported argument type: {type(arg)} with schema arg_type {arg_type}") + + def serialize_tensor_output(self, name, meta_val) -> TensorArgument: + assert name not in self.graph_state.tensor_values + self.graph_state.tensor_values[name] = serialize_tensor_meta(meta_val) + return TensorArgument(name=name) + + def serialize_sym_int_output(self, name, meta_val) -> SymIntArgument: + assert name not in self.graph_state.sym_int_values + self.graph_state.sym_int_values[name] = serialize_sym_int(meta_val) + return SymIntArgument.create(as_name=name) + + def serialize_sym_bool_output(self, name, meta_val) -> SymIntArgument: + assert name not in self.graph_state.sym_bool_values + self.graph_state.sym_bool_values[name] = serialize_sym_bool(meta_val) + return SymBoolArgument.create(as_name=name) + + def serialize_input_spec(self, spec: ep.InputSpec) -> InputSpec: + if spec.kind == ep.InputKind.USER_INPUT: + if isinstance(spec.arg, ep.ConstantArgument): + if isinstance(spec.arg.value, int): + constant_spec = ConstantValue.create(as_int=spec.arg.value) + elif isinstance(spec.arg.value, bool): + constant_spec = ConstantValue.create(as_bool=spec.arg.value) + elif isinstance(spec.arg.value, str): + constant_spec = ConstantValue.create(as_string=spec.arg.value) + elif isinstance(spec.arg.value, float): + 
constant_spec = ConstantValue.create(as_float=spec.arg.value) + elif spec.arg.value is None: + constant_spec = ConstantValue.create(as_none=()) + else: + raise SerializeError(f"Unhandled constant input {spec.arg.value} to serialize") + return InputSpec.create( + constant_input=ConstantInputSpec( + name=spec.arg.name, value=constant_spec + ) + ) + else: + return InputSpec.create( + user_input=UserInputSpec( + arg=self.serialize_argument_spec(spec.arg) + ) + ) + elif spec.kind == ep.InputKind.PARAMETER: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return InputSpec.create( + parameter=InputToParameterSpec( + arg=TensorArgument(name=spec.arg.name), + parameter_name=spec.target, + ) + ) + elif spec.kind == ep.InputKind.BUFFER: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + assert spec.persistent is not None + return InputSpec.create( + buffer=InputToBufferSpec( + arg=TensorArgument(name=spec.arg.name), + buffer_name=spec.target, + persistent=spec.persistent, + ) + ) + elif spec.kind == ep.InputKind.CONSTANT_TENSOR: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return InputSpec.create( + tensor_constant=InputToTensorConstantSpec( + arg=TensorArgument(name=spec.arg.name), + tensor_constant_name=spec.target, + ) + ) + elif spec.kind == ep.InputKind.CUSTOM_OBJ: + assert spec.target is not None + assert isinstance(spec.arg, ep.CustomObjArgument) + return InputSpec.create( + custom_obj=InputToCustomObjSpec( + arg=CustomObjArgument( + name=spec.arg.name, class_fqn=spec.arg.class_fqn + ), + custom_obj_name=spec.target, + ) + ) + elif spec.kind == ep.InputKind.TOKEN: + assert isinstance(spec.arg, ep.TokenArgument) + return InputSpec.create( + token=InputTokenSpec( + arg=TokenArgument(name=spec.arg.name), + ) + ) + else: + raise AssertionError(f"Unknown argument kind: {spec}") + + def serialize_output_spec(self, spec: ep.OutputSpec) -> OutputSpec: + if spec.kind == 
ep.OutputKind.USER_OUTPUT: + return OutputSpec.create( + user_output=UserOutputSpec(arg=self.serialize_argument_spec(spec.arg)) + ) + elif spec.kind == ep.OutputKind.LOSS_OUTPUT: + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + loss_output=LossOutputSpec(arg=TensorArgument(name=spec.arg.name)) + ) + elif spec.kind == ep.OutputKind.BUFFER_MUTATION: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + buffer_mutation=BufferMutationSpec( + arg=TensorArgument(name=spec.arg.name), + buffer_name=spec.target, + ) + ) + elif spec.kind == ep.OutputKind.GRADIENT_TO_PARAMETER: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + gradient_to_parameter=GradientToParameterSpec( + arg=TensorArgument(name=spec.arg.name), + parameter_name=spec.target, + ) + ) + elif spec.kind == ep.OutputKind.GRADIENT_TO_USER_INPUT: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + gradient_to_user_input=GradientToUserInputSpec( + arg=TensorArgument(name=spec.arg.name), + user_input_name=spec.target, + ) + ) + elif spec.kind == ep.OutputKind.USER_INPUT_MUTATION: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + user_input_mutation=UserInputMutationSpec( + arg=TensorArgument(name=spec.arg.name), + user_input_name=spec.target, + ) + ) + elif spec.kind == ep.OutputKind.TOKEN: + assert isinstance(spec.arg, ep.TokenArgument) + return OutputSpec.create( + token=OutputTokenSpec( + arg=TokenArgument(name=spec.arg.name), + ) + ) + else: + raise AssertionError(f"Unknown argument kind: {spec}") + + def serialize_signature(self, sig: ep.ExportGraphSignature) -> GraphSignature: + return GraphSignature( + input_specs=[self.serialize_input_spec(s) for s in sig.input_specs], + output_specs=[self.serialize_output_spec(s) for s in sig.output_specs], 
    def serialize_outputs(self, node: torch.fx.Node) -> List[Argument]:
        """For a given node, return the dataclass representing its output values.

        [NOTE: Multiple outputs] We handle aggregates differently than FX. For
        FX, it looks like:

            x = call_function("multiple_return", ...)
            element0 = call_function(getitem, x, 0)
            foo = call_function("use_output", element0)

        We do not want the intermediate `getitem` call, so our serialized thing looks like:

            element0, element1, element2 = call_function("multiple_return", ...)
            foo = call_function("use_output", element0)

        We want names to be consistent across these two schemes, so that we can
        mostly reuse the names coming from FX. This function computes a mapping from
        the FX representation to our representation, preserving the names.

        Args:
            node: a "call_function" fx node whose target is an OpOverload (or a
                registered extension type), so a schema is available.

        Returns:
            A list of serialized Arguments, one per schema return (or a single
            as_tensors Argument for a "-> Tensor[]" return).
        """
        assert node.op == "call_function" and isinstance(node.target, (torch._ops.OpOverload, *_registered_extension_types()))

        # The operator schema drives how many outputs we must emit.
        schema = _get_schema_from_target(node.target)
        returns = schema.returns

        if len(returns) == 0:
            return []

        meta_val = node.meta["val"]

        # Check single value return
        if _is_single_tensor_list_return(node.target):
            # e.g "-> Tensor[]": serialize as one as_tensors argument, naming each
            # element after the downstream getitem user (or a dummy if unused).
            tensor_args = []
            for idx, meta in enumerate(meta_val):
                name = self._output_node_name_at_index(node, idx)
                tensor_args.append(self.serialize_tensor_output(name, meta))
            return [Argument.create(as_tensors=tensor_args)]
        elif len(returns) == 1:
            # Single return: the fx node's own name is the output name.
            return [self.serialize_output(node.name, meta_val)]

        # There are a two possibilities at this point:
        # - This operator returns a tuple of Tensors, e.g. "-> (Tensor, Tensor)"
        # - This operator returns a tuple of mixed of Tensor and Tensors, e.g. "-> (Tensor, Tensor[])"
        #
        # Either way, start by gathering a list of TensorArguments with the correct names.
        # For consistent naming with FX, consult the downstream `getitem` node and
        # make sure our outputs have the same name.

        output_arguments = []
        for idx, (meta, return_schema) in enumerate(zip(meta_val, returns)):
            if meta is None:
                assert isinstance(
                    return_schema.real_type, (torch.OptionalType, torch.TensorType)
                )
                # When the return type is annotated as Tensor type, the op can also return an
                # undefined Tensor which will be implicitly converted to None in Python.
                output_arguments.append(Argument.create(as_none=()))
            elif isinstance(meta, FakeTensor):
                assert isinstance(return_schema.real_type, (torch.OptionalType, torch.TensorType))
                # Name this element after its getitem user so deserialization
                # reproduces the same value names.
                name = self._output_node_name_at_index(node, idx)
                output_arguments.append(self.serialize_output(name, meta))
            elif isinstance(meta, list):
                # for List[Tensor] return type
                assert isinstance(
                    return_schema.real_type, torch.ListType
                ) and isinstance(
                    return_schema.real_type.getElementType(), torch.TensorType
                )
                user_node = self._output_node_at_index(node, idx)
                assert user_node is not None

                args = []
                for i, m in enumerate(meta):
                    # None elements of the list are simply dropped here —
                    # NOTE(review): presumably rematerialized as None elsewhere; confirm.
                    if m is None:
                        continue
                    sub_user_node_name = self._output_node_name_at_index(user_node, i)
                    args.append(self.serialize_tensor_output(sub_user_node_name, m))
                output_arguments.append(Argument.create(as_tensors=args))
            elif isinstance(meta, (int, SymInt)):
                user_node_name = self._output_node_name_at_index(node, idx)
                output_arguments.append(self.serialize_output(user_node_name, meta))
            else:
                raise ValueError(
                    f"Unhandled output type {type(meta)} from node {node.format_node()}"
                )

        return output_arguments
+ + if len(meta_val) == 1: + assert isinstance(meta_val[0], torch.Tensor) + name = self._output_node_name_at_index(node, 0) + return [Argument.create(as_tensors=[self.serialize_tensor_output(name, meta_val[0])])] + + outputs = [] + for i, element_meta_val in enumerate(meta_val): + user_node = self._output_node_at_index(node, i) + if isinstance(element_meta_val, list): + # e.g "-> Tensor[]" + assert user_node is not None + + tensors = [] + for j, m in enumerate(element_meta_val): + if not isinstance(m, torch.Tensor): + raise SerializeError(f"Serialize list output with type {type(m)} nyi") + + name = self._output_node_name_at_index(user_node, j) + tensors.append(self.serialize_tensor_output(name, m)) + outputs.append(Argument.create(as_tensors=tensors)) + + else: + name = ( + user_node.name + if user_node is not None + else f"{node.name}_unused_{i}" + ) + + outputs.append(self.serialize_output(name, element_meta_val)) + + return outputs + else: + return [self.serialize_output(node.name, meta_val)] + + def serialize_output(self, name: str, meta_val: Any) -> Argument: + # Check single value return + if meta_val is None: + return Argument.create(as_none=()) + if isinstance(meta_val, torch.Tensor): + # e.g "-> Tensor" + return Argument.create( + as_tensor=self.serialize_tensor_output(name, meta_val) + ) + elif isinstance(meta_val, (int, torch.SymInt)): + # e.g "-> SymInt" + return Argument.create( + as_sym_int=self.serialize_sym_int_output(name, meta_val) + ) + elif isinstance(meta_val, torch.SymBool): + # e.g "-> SymBool" + return Argument.create( + as_sym_bool=self.serialize_sym_bool_output(name, meta_val) + ) + + # list outputs should've been handled earlier + raise SerializeError(f"Unable to serialize output {meta_val}") + + def _handle_getitem_users(self, node: torch.fx.Node) -> List[TensorArgument]: + meta_val = node.meta["val"] + + idx_to_name = {} + for user in node.users: + assert ( + user.target is operator.getitem + ), f"User node {user} of {node} is 
    def serialize_graph(self, graph_module: torch.fx.GraphModule) -> Graph:
        """Serialize an fx.GraphModule's graph into the schema `Graph` dataclass.

        Each fx node is dispatched to a `handle_<op>` method on this serializer
        (e.g. handle_placeholder, handle_call_function), which accumulates results
        into `self.graph_state`; the accumulated state is then packaged up.

        Raises:
            SerializeError: if any per-node handler raises; the original traceback
                is embedded in the message and the exception is chained.
        """
        assert isinstance(graph_module, torch.fx.GraphModule)
        for node in graph_module.graph.nodes:
            try:
                # Dynamic dispatch on the fx op kind ("placeholder",
                # "call_function", "output", ...).
                getattr(self, f"handle_{node.op}")(node)
            except Exception as e:
                raise SerializeError(
                    f"Failed serializing node {node} in graph: {node.format_node()}\n Original exception {traceback.format_exc()}"
                ) from e

        return Graph(
            inputs=self.graph_state.inputs,
            nodes=self.graph_state.nodes,
            tensor_values=self.graph_state.tensor_values,
            sym_int_values=self.graph_state.sym_int_values,
            sym_bool_values=self.graph_state.sym_bool_values,
            custom_obj_values=self.graph_state.custom_obj_values,
            outputs=self.graph_state.outputs,
            is_single_tensor_return=self.graph_state.is_single_tensor_return,
        )
@final
class ExportedProgramSerializer(metaclass=Final):
    """Top-level serializer: turns an ep.ExportedProgram into a _SerializedProgram
    (schema dataclasses plus serialized state dict / constants / example inputs).
    """

    def __init__(self, opset_version: Optional[Dict[str, int]] = None):
        # Opset versions to record in the serialized program; the "aten"
        # namespace always gets an entry, defaulting to the current maximum
        # operator version of this torch build.
        self.opset_version: Dict[str, int] = {}
        if opset_version:
            self.opset_version.update(opset_version)
        if "aten" not in self.opset_version:
            self.opset_version["aten"] = torch._C._get_max_operator_version()

    def serialize(self, exported_program: ep.ExportedProgram) -> _SerializedProgram:
        """Serialize an ExportedProgram.

        Args:
            exported_program: Exported Program to serialize

        Returns:
            A _SerializedProgram bundling the schema-level ExportedProgram with
            the separately serialized state dict, constants, and example inputs.
        """
        # Validate up front so we never serialize an ill-formed program.
        exported_program.validate()

        gm_serializer = GraphModuleSerializer(
            exported_program.graph_signature, exported_program.module_call_graph
        )
        serialized_graph_module = gm_serializer.serialize(exported_program.graph_module)
        serialized_range_constraints = serialize_range_constraints(
            exported_program.range_constraints
        )

        # TODO: Directly serialize exported_program.constants once
        # CustomClassHolders get stored in the ExportedProgram rather than in
        # the graph
        constants = {}
        for n, c in gm_serializer.custom_objs.items():
            constants[n] = c
        for n, t in exported_program.constants.items():
            # Custom objects collected during graph serialization must not
            # collide with the program's own constants.
            assert n not in constants
            constants[n] = t

        serialized_ep = ExportedProgram(
            graph_module=serialized_graph_module,
            opset_version=self.opset_version,
            range_constraints=serialized_range_constraints,
            schema_version=SchemaVersion(
                major=SCHEMA_VERSION[0],
                minor=SCHEMA_VERSION[1],
            ),
            verifiers=[v.dialect for v in exported_program.verifiers],
            torch_version=torch.__version__,
        )

        # Test canonical form is well defined.
        canonicalize(serialized_ep)

        # Proxy cannot be dumped, so we remove them.
        new_state_dict = remove_proxy_from_state_dict(
            exported_program.state_dict, in_place=False
        )
        return _SerializedProgram(
            serialized_ep,
            serialize_torch_artifact(new_state_dict),
            serialize_torch_artifact(constants),
            serialize_torch_artifact(exported_program.example_inputs),
        )
    def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]:
        """Deserialize a schema SymInt into either a plain int or a torch.SymInt.

        "as_int" values round-trip directly; "as_expr" values are re-sympified
        (reusing already-seen symbols where possible) and registered with the
        shape env together with any hint and serialized value ranges.

        Raises:
            SerializeError: if `s.type` is neither "as_expr" nor "as_int".
        """
        val = s.value
        if s.type == "as_expr":
            if val.hint is None:
                hint = None
            else:
                assert val.hint.type == "as_int"
                hint = val.hint.value

            if val.expr_str in self.symbol_name_to_symbol:
                # Expression already materialized earlier in this deserialization.
                sym = self.symbol_name_to_symbol[val.expr_str]
            else:
                sym = sympy.sympify(
                    val.expr_str,
                    locals={**self.sympy_functions, **self.symbol_name_to_symbol},
                )
                # NOTE(avik): Assumptions on symbols are not explicitly serialized.
                # This seems dangerous: it might cause unknown differences in shape env behavior
                # on deserialization? Probably deserves a follow-up.

                # Here we force symbols corresponding to SymInts to be at least integers.
                # Otherwise some expressions that the shape env would otherwise evaluate to False,
                # e.g., 2*s = 9, can have rational solutions, e.g., 9/2.
                # TODO: This is HIGHLY SUSPICIOUS ezyang(May 2024)
                # NOTE(review): the comprehension variable `s` shadows the
                # parameter `s` (harmless here, but easy to misread).
                sym = sym.subs(
                    {s: sympy.Symbol(s.name, integer=True) for s in sym.free_symbols}
                )
                # We need to check if the symbol has already been allocated,
                # self.symbol_name_to_symbol is not enough because the
                # integer-ification of symbols can induce simplification;
                # e.g., (2**s0 + 1) // 2 --> s0 when we know s0 is integral
                if isinstance(sym, sympy.Symbol) and sym not in self.shape_env.var_to_val:
                    self.symbol_name_to_symbol[val.expr_str] = sym
                    if hint is not None:
                        self.shape_env.add_var_to_val(sym, hint)

                    if vr := self.symbol_name_to_range.get(val.expr_str):
                        self.shape_env.constrain_symbol_range(
                            sym,
                            compiler_min=vr.lower,  # type: ignore[arg-type]
                            compiler_max=vr.upper,  # type: ignore[arg-type]
                        )
                else:
                    # Placeholders, in particular, can have shapes as symbolic expressions.
                    # We need to populate the shape env with the range constraints of their
                    # free symbols, otherwise evaluating such expressions will error.
                    self.symbol_name_to_symbol[val.expr_str] = sym
                    free_symbols = sym.free_symbols
                    for s in free_symbols:
                        if s.name not in self.symbol_name_to_symbol:
                            self.symbol_name_to_symbol[s.name] = s  # type: ignore[assignment]
                        if vr := self.symbol_name_to_range.get(s.name):
                            self.shape_env.constrain_symbol_range(
                                s,
                                compiler_min=vr.lower,  # type: ignore[arg-type]
                                compiler_max=vr.upper,  # type: ignore[arg-type]
                            )

            return self.shape_env.create_symintnode(sym, hint=hint)
        elif s.type == "as_int":
            assert isinstance(val, int)
            return val
        else:
            raise SerializeError(
                f"SymInt has invalid field type {s.type} with value {s.value}"
            )
    def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph:
        """Reconstruct an fx.Graph (into `self.graph`) from a serialized Graph.

        Phases:
          1. Populate `serialized_name_to_meta` from all serialized value tables
             (tensor / sym_int / sym_bool / custom_obj).
          2. Recreate placeholder nodes for graph inputs.
          3. Recreate call_function nodes via `deserialize_node`.
          4. Emit the single fx `output` node (unwrapping a lone tensor return
             when `is_single_tensor_return` is set).
        """
        # Handle the tensor metas.
        for name, tensor_value in serialized_graph.tensor_values.items():
            meta_val = self.deserialize_tensor_meta(tensor_value)
            self.serialized_name_to_meta[name] = meta_val

        for name, sym_int_value in serialized_graph.sym_int_values.items():
            self.serialized_name_to_meta[name] = self.deserialize_sym_int(sym_int_value)

        for name, sym_bool_value in serialized_graph.sym_bool_values.items():
            self.serialized_name_to_meta[name] = self.deserialize_sym_bool(
                sym_bool_value
            )

        for name, script_obj_meta in serialized_graph.custom_obj_values.items():
            self.serialized_name_to_meta[name] = self.deserialize_script_obj_meta(
                script_obj_meta
            )

        # Inputs: convert to placeholder nodes in FX.
        for i, input_ in enumerate(serialized_graph.inputs):
            if input_.type in ("as_tensor", "as_sym_int", "as_custom_obj"):
                node_name = input_.value.name
                placeholder_node = self.graph.placeholder(node_name)
                # FX might declare a name illegal (e.g. some nn.Modules use "input" as forward() arguments)
                # we will overwrite it
                placeholder_node.name = node_name
                self.sync_fx_node(node_name, placeholder_node)
            elif input_.type in (
                "as_int",
                "as_float",
                "as_bool",
                "as_none",
                "as_string",
            ):
                # Constant inputs take their name from the signature's input
                # spec (positionally aligned) and carry the constant as "val".
                node_name = self.signature.input_specs[i].arg.name
                placeholder_node = self.graph.placeholder(node_name)
                placeholder_node.meta["val"] = self.deserialize_input(input_)
            else:
                raise SerializeError(f"Invalid input type {input_}")

        # Nodes: convert to call_function nodes.
        for serialized_node in serialized_graph.nodes:
            try:
                target = self.deserialize_operator(serialized_node.target)
                self.deserialize_node(serialized_node, target)

            except Exception as e:
                raise SerializeError(
                    f"Failed deserializing node {serialized_node}\n Original exception {traceback.format_exc()}"
                ) from e

        # Outputs: convert to a single `output` node.
        outputs = []
        for output in serialized_graph.outputs:
            outputs.append(self.deserialize_graph_output(output))

        if serialized_graph.is_single_tensor_return:
            assert len(outputs) == 1
            outputs = outputs[0]  # type: ignore[assignment]
        else:
            outputs = tuple(outputs)  # type: ignore[assignment]

        output_node = self.graph.output(outputs)

        # Mirror the per-output "val" meta onto the output node itself.
        if serialized_graph.is_single_tensor_return:
            output_node.meta["val"] = output_node.args[0].meta["val"]
        else:
            output_node.meta["val"] = tuple(
                arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg
                for arg in output_node.args[0]
            )

        return self.graph
self.deserialize_sym_op_outputs(serialized_node, fx_node) + + elif isinstance(target, torch._ops.HigherOrderOperator): + args, kwargs = self.deserialize_hoo_inputs(serialized_node.inputs) + # If HOP returns a single tensor, name the + # newly-created node after it. This ensures that these tensor values + # have names that are consistent with serialized. + # + # HOPs don't have schema yet, just check the output lengths and as_tensor attribute + name = ( + serialized_node.outputs[0].as_tensor.name + if len(serialized_node.outputs) == 1 + and hasattr(serialized_node.outputs[0], "as_tensor") + else None + ) + fx_node = self.graph.create_node( + "call_function", target, args, kwargs, name + ) + self.deserialize_outputs(serialized_node, fx_node) + fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata)) + + elif isinstance(target, (torch._ops.OpOverload, *_registered_extension_types())): + # For convenience: if this node returns a single tensor, name the + # newly-created node after it. This ensures that these tensor values + # have names that are consistent with serialized. + name = ( + serialized_node.outputs[0].as_tensor.name + if _is_single_tensor_return(target) + else None # FX will generate a name for us. 
+ ) + args, kwargs = self.deserialize_inputs(target, serialized_node) + fx_node = self.graph.create_node( + "call_function", target, args, kwargs, name + ) + self.deserialize_outputs(serialized_node, fx_node) + else: + raise SerializeError( + f"Unsupported target type for node {serialized_node}: {type(target)}" + ) + + fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata)) + if fx_node.op not in ["placeholder", "output"] and "nn_module_stack" not in fx_node.meta: + fx_node.meta["nn_module_stack"] = {} # serialization throws away empty dicts + + def deserialize_input_spec(self, i: InputSpec) -> ep.InputSpec: + if i.type == "user_input": + return ep.InputSpec( + kind=ep.InputKind.USER_INPUT, + arg=self.deserialize_argument_spec(i.user_input.arg), + target=None, + ) + elif i.type == "parameter": + return ep.InputSpec( + kind=ep.InputKind.PARAMETER, + arg=ep.TensorArgument(name=i.parameter.arg.name), + target=i.parameter.parameter_name, + ) + elif i.type == "buffer": + return ep.InputSpec( + kind=ep.InputKind.BUFFER, + arg=ep.TensorArgument(name=i.buffer.arg.name), + target=i.buffer.buffer_name, + persistent=i.buffer.persistent, + ) + elif i.type == "tensor_constant": + return ep.InputSpec( + kind=ep.InputKind.CONSTANT_TENSOR, + arg=ep.TensorArgument(name=i.tensor_constant.arg.name), + target=i.tensor_constant.tensor_constant_name, + ) + elif i.type == "custom_obj": + return ep.InputSpec( + kind=ep.InputKind.CUSTOM_OBJ, + arg=ep.CustomObjArgument( + name=i.custom_obj.arg.name, class_fqn=i.custom_obj.arg.class_fqn + ), + target=i.custom_obj.custom_obj_name, + ) + elif i.type == "token": + return ep.InputSpec( + kind=ep.InputKind.TOKEN, + arg=ep.TokenArgument(name=i.token.arg.name), + target=None + ) + elif i.type == "constant_input": + return ep.InputSpec( + kind=ep.InputKind.USER_INPUT, + arg=ep.ConstantArgument( + name=i.constant_input.name, + value=self.deserialize_constant_input(i.constant_input.value) + ), + target=None, + ) + else: + raise 
AssertionError(f"Unknown input spec {i}") + + def deserialize_output_spec(self, o: OutputSpec) -> ep.OutputSpec: + if o.type == "user_output": + return ep.OutputSpec( + kind=ep.OutputKind.USER_OUTPUT, + arg=self.deserialize_argument_spec(o.user_output.arg), + target=None, + ) + elif o.type == "loss_output": + return ep.OutputSpec( + kind=ep.OutputKind.LOSS_OUTPUT, + arg=ep.TensorArgument(name=o.loss_output.arg.name), + target=None, + ) + elif o.type == "buffer_mutation": + return ep.OutputSpec( + kind=ep.OutputKind.BUFFER_MUTATION, + arg=ep.TensorArgument(name=o.buffer_mutation.arg.name), + target=o.buffer_mutation.buffer_name, + ) + elif o.type == "gradient_to_parameter": + return ep.OutputSpec( + kind=ep.OutputKind.GRADIENT_TO_PARAMETER, + arg=ep.TensorArgument(name=o.gradient_to_parameter.arg.name), + target=o.gradient_to_parameter.parameter_name, + ) + elif o.type == "gradient_to_user_input": + return ep.OutputSpec( + kind=ep.OutputKind.GRADIENT_TO_USER_INPUT, + arg=ep.TensorArgument(name=o.gradient_to_user_input.arg.name), + target=o.gradient_to_user_input.user_input_name, + ) + elif o.type == "user_input_mutation": + return ep.OutputSpec( + kind=ep.OutputKind.USER_INPUT_MUTATION, + arg=ep.TensorArgument(name=o.user_input_mutation.arg.name), + target=o.user_input_mutation.user_input_name, + ) + elif o.type == "token": + return ep.OutputSpec( + kind=ep.OutputKind.TOKEN, + arg=ep.TokenArgument(name=o.token.arg.name), + target=None + ) + else: + raise AssertionError(f"Unknown output spec {o}") + + def deserialize_signature(self, sig: GraphSignature) -> ep.ExportGraphSignature: + return ep.ExportGraphSignature( + input_specs=[self.deserialize_input_spec(i) for i in sig.input_specs], + output_specs=[self.deserialize_output_spec(o) for o in sig.output_specs], + ) + + def deserialize( + self, + serialized_graph_module: GraphModule, + serialized_state_dict: Union[Dict[str, torch.Tensor], bytes], + constants: Union[Dict[str, Any], bytes], + example_inputs: 
Optional[Union[Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]], bytes]] = None, + symbol_name_to_range: Optional[Dict[str, symbolic_shapes.ValueRanges]] = None, + ) -> Result: + global _CURRENT_DESERIALIZER + assert _CURRENT_DESERIALIZER is None + _CURRENT_DESERIALIZER = self + try: + self.shape_env = symbolic_shapes.ShapeEnv(assume_static_by_default=True) + self.fake_tensor_mode = FakeTensorMode( + allow_fallback_kernels=False, + allow_non_fake_inputs=True, + shape_env=self.shape_env, + ) + self.sympy_functions = { + # all torch.utils._sympy.functions should go here + # TODO(avik): find a better way to keep this collection in sync; + # e.g.., `exec('from torch.utils._sympy.functions import *', ...)` + # would work as long as the public API of that module is complete + "FloorDiv": torch.utils._sympy.functions.FloorDiv, + "ModularIndexing": torch.utils._sympy.functions.ModularIndexing, + "Where": torch.utils._sympy.functions.Where, + "PythonMod": torch.utils._sympy.functions.PythonMod, + "Mod": torch.utils._sympy.functions.Mod, + "CleanDiv": torch.utils._sympy.functions.CleanDiv, + "CeilToInt": torch.utils._sympy.functions.CeilToInt, + "FloorToInt": torch.utils._sympy.functions.FloorToInt, + "CeilDiv": torch.utils._sympy.functions.CeilDiv, + "LShift": torch.utils._sympy.functions.LShift, + "RShift": torch.utils._sympy.functions.RShift, + "PowByNatural": torch.utils._sympy.functions.PowByNatural, + "FloatPow": torch.utils._sympy.functions.FloatPow, + "FloatTrueDiv": torch.utils._sympy.functions.FloatTrueDiv, + "IntTrueDiv": torch.utils._sympy.functions.IntTrueDiv, + "IsNonOverlappingAndDenseIndicator": torch.utils._sympy.functions.IsNonOverlappingAndDenseIndicator, + "TruncToFloat": torch.utils._sympy.functions.TruncToFloat, + "TruncToInt": torch.utils._sympy.functions.TruncToInt, + "RoundToInt": torch.utils._sympy.functions.RoundToInt, + "RoundDecimal": torch.utils._sympy.functions.RoundDecimal, + "ToFloat": torch.utils._sympy.functions.ToFloat, + "Identity": 
    def sync_fx_node(self, name: str, fx_node: torch.fx.Node):
        """Bind `fx_node` to its serialized `name` and attach its "val" meta.

        Forces the FX node's name back to the serialized one (FX may have
        auto-renamed it), records the mapping for later argument lookups, and
        copies the pre-deserialized meta value for this name onto the node.

        Raises:
            SerializeError: if `name` has already been bound to a node.
        """
        if name in self.serialized_name_to_node:
            raise SerializeError(f"Node {name} has already been deserialized before.")
        # overwrite name
        fx_node.name = name
        self.serialized_name_to_node[name] = fx_node
        # Each serialized value is synced exactly once, so nothing should have
        # set "val" on this node before us.
        assert "val" not in fx_node.meta
        fx_node.meta["val"] = self.serialized_name_to_meta[name]
deserialize_sym_op_inputs(self, inputs): + return tuple(self.deserialize_input(input.arg) for input in inputs) + + def deserialize_inputs(self, target, serialized_node: Node): + schema_args = _get_schema_from_target(target).arguments + actual_args = { + input.name: self.deserialize_input(input.arg) + for input in serialized_node.inputs + } + args = [] + kwargs = {} + for schema_arg in schema_args: + is_positional = ( + not schema_arg.has_default_value() and not schema_arg.kwarg_only + ) + if is_positional: + args.append(actual_args[schema_arg.name]) + else: + if schema_arg.name in actual_args: + kwargs[schema_arg.name] = actual_args[schema_arg.name] + return tuple(args), kwargs + + def deserialize_hoo_inputs(self, inputs: List[NamedArgument]): + """ + For deserializing HOO inputs since HOOs do not have a schema. + """ + args = [] + kwargs = {} + for input_ in inputs: + if input_.name != "": + kwargs[input_.name] = self.deserialize_input(input_.arg) + else: + args.append(self.deserialize_input(input_.arg)) + return (tuple(args), kwargs) + + def deserialize_input(self, inp: Argument) -> Any: + value = inp.value + typ_ = inp.type + if typ_ == "as_none": + # None should converted as None, but is encoded as bool in serialized + # Convert serialized object to torch equivalent + return None + elif typ_ == "as_tensor": + return self.serialized_name_to_node[inp.as_tensor.name] + elif typ_ == "as_scalar_type": + return _SERIALIZE_TO_TORCH_DTYPE[inp.as_scalar_type] + elif typ_ == "as_memory_format": + return _SERIALIZE_TO_TORCH_MEMORY_FORMAT[inp.as_memory_format] + elif typ_ == "as_layout": + return _SERIALIZE_TO_TORCH_LAYOUT[inp.as_layout] + elif typ_ == "as_graph": + assert isinstance(value, GraphArgument) + with self.save_graph_module(): + self.deserialize_graph(value.graph) + submodule = ep._create_graph_module_for_export(self.module, self.graph) + self.module.register_module(value.name, submodule) + return self.graph.create_node( + "get_attr", + value.name, + 
name=value.name, + ) + elif typ_ == "as_device": + return deserialize_device(inp.as_device) + elif typ_ == "as_int": + return inp.as_int + elif typ_ == "as_float": + return inp.as_float + elif typ_ == "as_bool": + return inp.as_bool + elif typ_ == "as_string": + return inp.as_string + elif typ_ == "as_sym_int": + return self.deserialize_sym_argument(inp.as_sym_int) + elif typ_ == "as_sym_bool": + return self.deserialize_sym_argument(inp.as_sym_bool) + elif isinstance(value, list): + if len(value) == 0: + return [] + elif typ_ == "as_tensors": + result = [] + for arg in value: + result.append(self.serialized_name_to_node[arg.name]) + return result + elif typ_ in ("as_ints", "as_floats", "as_bools", "as_strings"): + # convert from serialized.python.types.List to python list + return list(value) + elif typ_ in ("as_sym_ints", "as_sym_bools"): + return [self.deserialize_sym_argument(arg) for arg in value] + elif typ_ == "as_optional_tensors": + + def deserialize_optional_tensor_args(a): + if a.type == "as_none": + return None + elif a.type == "as_tensor": + return self.serialized_name_to_node[a.value.name] + else: + raise SerializeError(f"Unhandled argument {inp}") + + return list(map(deserialize_optional_tensor_args, value)) + else: + raise SerializeError(f"Unhandled argument {inp}") + elif typ_ == "as_custom_obj": + if inp.as_custom_obj.name in self.serialized_name_to_node: + # Custom object has been lifted as an input + return self.serialized_name_to_node[inp.as_custom_obj.name] + return self.constants[inp.as_custom_obj.name] + elif typ_ == "as_operator": + return self.deserialize_operator(inp.as_operator) + else: + raise SerializeError(f"Unhandled argument {inp}") + + def deserialize_constant_input(self, inp: ConstantValue) -> Any: + if inp.type == "as_int": + return int(inp.as_int) + elif inp.type == "as_float": + return float(inp.as_float) + elif inp.type == "as_string": + return str(inp.as_string) + elif inp.type == "as_bool": + return bool(inp.as_bool) + elif 
inp.type == "as_none": + return None + else: + raise SerializeError(f"Unhandled constant argument {inp} to deserialize") + + def deserialize_sym_argument(self, sym_arg): + if isinstance(sym_arg, SymIntArgument): + if sym_arg.type == "as_int": + return sym_arg.as_int + elif sym_arg.type == "as_name": + return self.serialized_name_to_node[sym_arg.as_name] + elif isinstance(sym_arg, SymBoolArgument): + if sym_arg.type == "as_bool": + return sym_arg.as_bool + elif sym_arg.type == "as_name": + return self.serialized_name_to_node[sym_arg.as_name] + raise SerializeError(f"Unknown symbolic argument type: {sym_arg}") + + def deserialize_sym_op_outputs(self, serialized_node: Node, fx_node: torch.fx.Node): + self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node) + + def deserialize_outputs(self, serialized_node: Node, fx_node: torch.fx.Node): + # Check single value return + if len(serialized_node.outputs) == 0: + return + if ( + len(serialized_node.outputs) == 1 + and serialized_node.outputs[0].type == "as_tensor" + ): + self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node) + return + elif len(serialized_node.outputs) == 1 and isinstance( + serialized_node.outputs[0].value, (SymIntArgument, SymBoolArgument) + ): + self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node) + return + + self.deserialize_multiple_outputs(serialized_node, fx_node) + + def deserialize_multiple_outputs( + self, serialized_node: Node, fx_node: torch.fx.Node + ) -> None: + deserialized_metadata = self.deserialize_metadata(serialized_node.metadata) + + def generate_getitem( + meta_val, + fx_node: torch.fx.Node, + arg: Union[TensorArgument, SymIntArgument], + idx: int, + ): + if isinstance(arg, TensorArgument): + name = arg.name + elif isinstance(arg, SymIntArgument): + name = arg.as_name + else: + raise AssertionError( + f"generate_getitem got unknown argument type {type(arg)}" + ) + individual_output = self.graph.create_node( + "call_function", + 
operator.getitem, + (fx_node, idx), + name=name, + ) + self.sync_fx_node(name, individual_output) + meta_val.append(self.serialized_name_to_meta[name]) + # The derived `getitem` nodes should have the same stacktrace as the + # original `fx_node` + individual_output.meta.update(deserialized_metadata) + + def generate_getitems(meta_val, fx_node: torch.fx.Node, args): + for idx, arg in enumerate(args): + if isinstance(arg, Argument): + arg = arg.value + if isinstance(arg, (TensorArgument, SymIntArgument)): + generate_getitem(meta_val, fx_node, arg, idx) + elif isinstance(arg, (list, tuple)): + list_output = self.graph.create_node( + "call_function", + operator.getitem, + (fx_node, idx), + ) + meta_val.append([]) + generate_getitems(meta_val[-1], list_output, arg) + list_output.meta.update(deserialized_metadata) + list_output.meta["val"] = meta_val[-1] + else: + raise NotImplementedError(f"Unimplemented node output type: {arg}") + + # Convert multiple return types to FX format. + # In FX, each node only returns one value. So in order to represent + # multiple return values, we have to emit a `getitem` node for each + # return value. 
+ # This performs the inverse mapping of the `serialize_outputs` call in + # serialization, see [NOTE: Multiple outputs] + meta_val: List[Any] = [] + if len(serialized_node.outputs) == 1: + assert isinstance(serialized_node.outputs[0].value, list) + assert isinstance(serialized_node.outputs[0].value[0], TensorArgument) + generate_getitems(meta_val, fx_node, serialized_node.outputs[0].as_tensors) + else: + generate_getitems(meta_val, fx_node, serialized_node.outputs) + + # also update the metaval for `fx_node` to be a list(meta) + fx_node.meta["val"] = tuple(meta_val) + self.serialized_name_to_node[fx_node.name] = fx_node + + def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]: + ret: Dict[str, Any] = {} + if stack_trace := metadata.get("stack_trace"): + ret["stack_trace"] = stack_trace + + def deserialize_meta_func(serialized_target: str): + module = None + if serialized_target.startswith("torch.nn"): + module = torch.nn + serialized_target_names = serialized_target.split(".")[2:] + elif serialized_target.startswith("torch"): + module = torch + serialized_target_names = serialized_target.split(".")[1:] + else: + return self.deserialize_operator(serialized_target) + + target = module + for name in serialized_target_names: + if not hasattr(target, name): + return serialized_target + else: + target = getattr(target, name) + return target + + if nn_module_stack_str := metadata.get("nn_module_stack"): + # Originally serialized to "key,orig_path,type_str" + def import_nn_module_stack(key, path, ty): + return key, (path, ty) + + # Helper function that splits strings by commas except for those + # encapsulated by parens, which are valid traces. + # TODO: Currently this is needed due to indexing Sequential + # layers introducing names in the form "layer.slice(1, None, None)". + # If that naming is improved, this fancier splitting can probably be + # reverted to a simple split by comma. 
+ def metadata_split(metadata): + # Remove the parentheses and commas inside them + metadata = re.sub(r'\(.*?\)', '', metadata) + # Split the string by comma, except for those inside parentheses + return re.split(r'(? ep.ArgumentSpec: + if x.type == "as_tensor": + return ep.TensorArgument(name=x.as_tensor.name) + elif x.type == "as_sym_int": + return ep.SymIntArgument(name=x.as_sym_int.as_name) + elif x.type == "as_custom_obj": + return ep.ConstantArgument(name=x.as_custom_obj.name, value=self.deserialize_input(x)) + else: + return ep.ConstantArgument(name="", value=self.deserialize_input(x)) + + def deserialize_module_call_signature( + self, module_call_signature: ModuleCallSignature + ) -> ep.ModuleCallSignature: + return ep.ModuleCallSignature( + inputs=[ + self.deserialize_argument_spec(x) for x in module_call_signature.inputs + ], + outputs=[ + self.deserialize_argument_spec(x) for x in module_call_signature.outputs + ], + in_spec=treespec_loads(module_call_signature.in_spec), + out_spec=treespec_loads(module_call_signature.out_spec), + ) + + def deserialize_module_call_graph( + self, module_call_graph: List[ModuleCallEntry] + ) -> List[ep.ModuleCallEntry]: + return [ + ep.ModuleCallEntry( + fqn=entry.fqn, + signature=( + self.deserialize_module_call_signature(entry.signature) + if entry.signature + else None + ), + ) + for entry in module_call_graph + ] + + +@final +class ExportedProgramDeserializer(metaclass=Final): + def __init__(self, expected_opset_version: Optional[Dict[str, int]] = None): + self.expected_opset_version: Dict[str, int] = {} + if expected_opset_version: + self.expected_opset_version.update(expected_opset_version) + if "aten" not in self.expected_opset_version: + self.expected_opset_version["aten"] = torch._C._get_max_operator_version() + + def deserialize_range_constraints( + self, + symbol_name_to_range: Dict[str, symbolic_shapes.ValueRanges], + symbol_name_to_symbol: Dict[str, sympy.Symbol], + ) -> Dict[sympy.Symbol, ValueRanges]: + 
range_constraints = {} + for k, v in symbol_name_to_range.items(): + if symbol := symbol_name_to_symbol.get(k): + range_constraints[symbol] = v # type: ignore[arg-type] + else: + log.warning(f"Symbol {k} did not appear in the graph that was deserialized") # noqa: G004 + return range_constraints + + def deserialize( + self, + exported_program: ExportedProgram, + state_dict: Union[Dict[str, torch.Tensor], bytes], + constants: Union[Dict[str, torch.Tensor], bytes], + example_inputs: Optional[Union[Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]], bytes]] = None, + ) -> ep.ExportedProgram: + assert isinstance(exported_program, ExportedProgram) + version = exported_program.schema_version + + # TODO(zhxchen17) blocked on thrift schema refactor + if version.major != SCHEMA_VERSION[0] and not (version.major == 0 and version.minor == 0): + raise SerializeError( + f"Serialized schema version {exported_program.schema_version} " + f"does not match our current schema version {SCHEMA_VERSION}." + ) + + symbol_name_to_range = { + k: symbolic_shapes.ValueRanges( + _int_to_sympy_int(v.min_val), _int_to_sympy_int(v.max_val) + ) + for k, v in exported_program.range_constraints.items() + } + res = ( + GraphModuleDeserializer() + .deserialize( + exported_program.graph_module, + state_dict, + constants, + example_inputs, + symbol_name_to_range, + ) + ) + range_constraints = self.deserialize_range_constraints( + symbol_name_to_range, + res.names_to_symbols, + ) + + return ep.ExportedProgram( + root=res.graph_module, + graph=res.graph_module.graph, + graph_signature=res.signature, + state_dict=res.state_dict, # type: ignore[arg-type] + range_constraints=range_constraints, + module_call_graph=res.module_call_graph, + example_inputs=res.example_inputs, + constants=res.constants, + verifiers=[load_verifier(v) for v in exported_program.verifiers], + ) + + +class EnumEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, Enum): + return obj.value + if isinstance(obj, bytes): 
+ return base64.b64encode(obj).decode("utf-8") + return super().default(obj) + + +def _dataclass_to_dict(obj): + if isinstance(obj, _Union): + return {obj.type: _dataclass_to_dict(obj.value)} + elif dataclasses.is_dataclass(obj): + return { + f.name: _dataclass_to_dict(getattr(obj, f.name)) + for f in dataclasses.fields(obj) + if not (f.default is None and getattr(obj, f.name) is None) + } + elif isinstance(obj, list): + return [_dataclass_to_dict(x) for x in obj] + elif isinstance(obj, tuple): + return tuple(_dataclass_to_dict(x) for x in obj) + elif isinstance(obj, dict): + return {k: _dataclass_to_dict(v) for k, v in obj.items()} + else: + return obj + + +def serialize( + exported_program: ep.ExportedProgram, + opset_version: Optional[Dict[str, int]] = None, +) -> SerializedArtifact: + serialized_program = ExportedProgramSerializer(opset_version).serialize( + exported_program + ) + assert isinstance(serialized_program.exported_program, ExportedProgram) + + json_program = json.dumps( + _dataclass_to_dict(serialized_program.exported_program), cls=EnumEncoder + ) + json_bytes = json_program.encode("utf-8") + artifact = SerializedArtifact( + json_bytes, + serialized_program.state_dict, + serialized_program.constants, + serialized_program.example_inputs + ) + return artifact + + +def _dict_to_dataclass(cls, data): + assert not isinstance(cls, str), f"Unresolved class type: '{cls}'." 
+ if typing.get_origin(cls) == typing.Union and type(None) in typing.get_args(cls): + if data is None: + return None + ty_args = typing.get_args(cls) + assert len(ty_args) == 2 + return _dict_to_dataclass(ty_args[0], data) + elif isinstance(cls, type) and issubclass(cls, _Union): + assert isinstance(data, dict) + assert len(data) == 1 + _type = next(iter(data.keys())) + _value = next(iter(data.values())) + assert isinstance(_type, str) + field_type = cls.__annotations__[_type] + return cls.create(**{_type: _dict_to_dataclass(field_type, _value)}) + elif dataclasses.is_dataclass(cls): + obj = cls(**data) # type: ignore[assignment] + type_hints = typing.get_type_hints(cls) + for f in dataclasses.fields(cls): + name = f.name + new_field_obj = _dict_to_dataclass(type_hints[name], getattr(obj, name)) + setattr(obj, name, new_field_obj) + return obj + elif isinstance(data, list): + if len(data) == 0: + return data + d_type = typing.get_args(cls)[0] + return [_dict_to_dataclass(d_type, d) for d in data] + elif isinstance(data, dict): + v_type = typing.get_args(cls)[1] + return {k: _dict_to_dataclass(v_type, v) for k, v in data.items()} + return data + + +def deserialize( + artifact: SerializedArtifact, + expected_opset_version: Optional[Dict[str, int]] = None, +) -> ep.ExportedProgram: + assert isinstance(artifact.exported_program, bytes) + exported_program_str = artifact.exported_program.decode("utf-8") + exported_program_dict = json.loads(exported_program_str) + serialized_exported_program = _dict_to_dataclass(ExportedProgram, exported_program_dict) + return ( + ExportedProgramDeserializer(expected_opset_version) + .deserialize( + serialized_exported_program, + artifact.state_dict, + artifact.constants, + artifact.example_inputs, + ) + ) + + +def _canonicalize_graph( + sorted_inputs, sorted_outputs, graph +) -> Tuple[Graph, Dict[str, str]]: + def _get_argument(a: Argument): + if a.type == "as_none": + return None + elif a.type == "as_tensor": + return a.as_tensor + elif 
a.type == "as_tensors": + return a.as_tensors + elif a.type == "as_int": + return None + elif a.type == "as_ints": + return None + elif a.type == "as_float": + return None + elif a.type == "as_floats": + return None + elif a.type == "as_string": + return None + elif a.type == "as_strings": + return None + elif a.type == "as_sym_int": + return a.as_sym_int + elif a.type == "as_sym_ints": + return a.as_sym_ints + elif a.type == "as_scalar_type": + return None + elif a.type == "as_memory_format": + return None + elif a.type == "as_layout": + return None + elif a.type == "as_device": + return None + elif a.type == "as_bool": + return None + elif a.type == "as_bools": + return None + elif a.type == "as_sym_bool": + return a.as_sym_bool + elif a.type == "as_sym_bools": + return a.as_sym_bools + elif a.type == "as_graph": + return None + elif a.type == "as_optional_tensors": + return a.as_optional_tensors + elif a.type == "as_custom_obj": + return None + elif a.type == "as_operator": + return None + else: + raise AssertionError(f"Unknown input type to the ExportedProgram: {a}") + + # Stage 1: Reorder named items. 
+ def for_args(f, a): + assert isinstance(a, Argument) + pytree.tree_map(f, _get_argument(a)) + + def sort_nodes(nodes): + @dataclass + class Edges: + outs: List[int] + ins: int + + graph_inputs: Set[str] = set() + def_table: Dict[str, int] = {} + edges: Dict[int, Edges] = {} + candidates: List[Tuple[str, List[Tuple[str, List[int]]], int]] = [] + rank: Dict[str, int] = {} + ret: List[Node] = [] + + def get_name(a) -> Optional[str]: + if a is None: + return None + if isinstance(a, TensorArgument): + return a.name + elif isinstance(a, (SymIntArgument, SymBoolArgument)): + if a.type == "as_name": + return a.as_name + elif a.type in ("as_int", "as_bool"): + return None + else: + raise AssertionError(f"Unknown argument type: {a}") + elif isinstance(a, OptionalTensorArgument): + if a.type == "as_tensor": + return a.as_tensor.name + elif a.type == "as_none": + return None + else: + raise AssertionError(f"Unknown optional tensor type: {a}") + else: + raise AssertionError(f"Unknown argument type: {a}") + + for i in sorted_inputs: + + def add_input(a): + if s := get_name(a): + graph_inputs.add(s) + + for_args(add_input, i) + + for idx, node in enumerate(nodes): + + def add_def(a): + if s := get_name(a): + assert s not in def_table + def_table[s] = idx + + for o in node.outputs: + for_args(add_def, o) + + edges[idx] = Edges([], 0) + + for idx, user in enumerate(nodes): + + def add_edge(a): + if s := get_name(a): + if s not in def_table: + assert s in graph_inputs + return + src = def_table[s] + edges[src].outs.append(idx) + edges[idx].ins += 1 + + for i in user.inputs: + for_args(add_edge, i.arg) + + def add_rank(a): + if s := get_name(a): + assert s not in rank + rank[s] = len(rank) + + def get_rank(a): + if s := get_name(a): + return rank[s] + else: + return -1 + + for i in sorted_inputs: + for_args(add_rank, i) + + def add_candidate(idx: int): + def get_ranks(i): + ranks = [] + for_args(lambda x: ranks.append(get_rank(x)), i) + return ranks + + node = nodes[idx] + 
args_rank = [(a.name, get_ranks(a.arg)) for a in node.inputs] + heapq.heappush(candidates, (node.target, args_rank, idx)) + + for idx, e in edges.items(): + if e.ins == 0: + add_candidate(idx) + + while len(candidates) > 0: + _, _, idx = heapq.heappop(candidates) + node = nodes[idx] + for o in node.outputs: + for_args(add_rank, o) + ret.append(node) + assert idx in edges + for user in edges[idx].outs: + e = edges[user] + assert e.ins > 0 + e.ins -= 1 + if e.ins == 0: + add_candidate(user) + edges[idx].outs.clear() + + return ret + + sorted_nodes = sort_nodes(graph.nodes) + assert len(sorted_nodes) == len(graph.nodes) + + # Stage 2: Rename nodes. + name_table: Dict[str, str] = {} + + def rename_def(a): + def _rename(arg_name, values): + new_name = f"_{len(name_table)}" + assert arg_name not in name_table + name_table[arg_name] = new_name + assert arg_name in values + values[new_name] = values.pop(arg_name) + return new_name + + if a is None: + return + if isinstance(a, TensorArgument): + a.name = _rename(a.name, graph.tensor_values) + elif isinstance(a, SymIntArgument): + if a.type == "as_name": + a.as_name = _rename(a.as_name, graph.sym_int_values) + elif isinstance(a, SymBoolArgument): + if a.type == "as_name": + a.as_name = _rename(a.as_name, graph.sym_bool_values) + else: + raise AssertionError(f"Unknown argument type: {a}") + + def replace_use(a): + if a is None: + return + if isinstance(a, TensorArgument): + a.name = name_table.get(a.name, a.name) + elif isinstance(a, SymIntArgument): + if a.type == "as_name": + a.as_name = name_table.get(a.as_name, a.as_name) + elif isinstance(a, SymBoolArgument): + if a.type == "as_name": + a.as_name = name_table.get(a.as_name, a.as_name) + elif isinstance(a, OptionalTensorArgument): + if a.type == "as_tensor": + a.as_tensor.name = name_table.get(a.as_tensor.name, a.as_tensor.name) + else: + raise AssertionError(f"Unknown argument type: {a}") + + for i in sorted_inputs: + for_args(rename_def, i) + + for n in sorted_nodes: + 
for o in n.outputs: + for_args(rename_def, o) + + for n in sorted_nodes: + for i in n.inputs: + for_args(replace_use, i.arg) + + for o in sorted_outputs: + for_args(replace_use, o) + + # Stage 3: Remove unstable fields. + for n in sorted_nodes: + n.metadata.clear() + + # Stage 4: Aggregate values. + sorted_tensor_values = dict(sorted(graph.tensor_values.items(), key=operator.itemgetter(0))) + sorted_sym_int_values = dict( + sorted(graph.sym_int_values.items(), key=operator.itemgetter(0)) + ) + sorted_sym_bool_values = dict( + sorted(graph.sym_bool_values.items(), key=operator.itemgetter(0)) + ) + + # Stage 5: Recurse in subgraphs. + counter = 0 + for node in sorted_nodes: + for i in node.inputs: + a = i.arg + if a.type == "as_graph": + a.as_graph.graph, _ = _canonicalize_graph( + a.as_graph.graph.inputs, a.as_graph.graph.outputs, a.as_graph.graph + ) + a.as_graph.name = f"_g{counter}" + counter += 1 + + graph = Graph( + inputs=sorted_inputs, + outputs=sorted_outputs, + nodes=sorted_nodes, + tensor_values=sorted_tensor_values, + sym_int_values=sorted_sym_int_values, + sym_bool_values=sorted_sym_bool_values, + is_single_tensor_return=graph.is_single_tensor_return, + ) + return graph, name_table + + +def canonicalize(ep: ExportedProgram) -> ExportedProgram: + """ + Normalize a serialized ExportedProgram, so that different eager program which + shares the same semantics can get a single representation on disk. + + This function canonicalizes an ExportedProgram by: + + 1. Sorting nodes in topological order. + 2. Rename nodes to have unique names. + 3. Remove unstable fields. + 4. Aggregate the above program fields. + 5. Recurse in subgraphs. + + Args: + ep (ExportedProgram): The ExportedProgram to canonicalize. + + Returns: + ExportedProgram: The canonicalized exported program. 
+ """ + ep = copy.deepcopy(ep) + + opset_version = dict(sorted(ep.opset_version.items(), key=operator.itemgetter(0))) + range_constraints = dict(sorted(ep.range_constraints.items(), key=operator.itemgetter(0))) + module_call_graph = sorted(ep.graph_module.module_call_graph, key=lambda x: x.fqn) + signature = ep.graph_module.signature + graph = ep.graph_module.graph + + assert len(graph.inputs) == len(signature.input_specs) + assert len(graph.outputs) == len(signature.output_specs) + + def rank_input(inp) -> Tuple[int, Optional[str], int]: + idx, (arg, spec) = inp + assert isinstance(spec, InputSpec) + if spec.type == "user_input": + return 5, None, idx + elif spec.type == "parameter": + return 1, spec.parameter.parameter_name, idx + elif spec.type == "buffer": + return 2, spec.buffer.buffer_name, idx + elif spec.type == "tensor_constant": + return 3, spec.tensor_constant.tensor_constant_name, idx + elif spec.type == "custom_obj": + return 4, spec.custom_obj.custom_obj_name, idx + elif spec.type == "token": + return 0, None, idx + elif spec.type == "constant_input": + return 6, spec.constant_input.name, idx + else: + raise AssertionError(f"Unknown input type: {spec}") + + def rank_output(out) -> Tuple[int, Optional[str], int]: + idx, (arg, spec) = out + assert isinstance(spec, OutputSpec) + if spec.type == "user_output": + return 3, None, idx + elif spec.type == "loss_output": + return 3, None, idx + elif spec.type == "buffer_mutation": + return 1, spec.buffer_mutation.buffer_name, idx + elif spec.type == "gradient_to_parameter": + return 4, spec.gradient_to_parameter.parameter_name, idx + elif spec.type == "gradient_to_user_input": + return 5, None, idx + elif spec.type == "user_input_mutation": + return 2, None, idx + elif spec.type == "token": + return 0, None, idx + else: + raise AssertionError(f"Unknown output type: {spec}") + + sorted_ins = sorted( + enumerate(zip(graph.inputs, signature.input_specs)), key=rank_input + ) + + if len(sorted_ins) > 0: + 
sorted_inputs, input_specs = zip(*(i for idx, i in sorted_ins)) # type: ignore[assignment] + else: + sorted_inputs = () + input_specs = () + + sorted_outs = sorted( + enumerate(zip(graph.outputs, signature.output_specs)), key=rank_output + ) + sorted_outputs, output_specs = zip(*(i for idx, i in sorted_outs)) # type: ignore[assignment] + + sorted_graph, replace_table = _canonicalize_graph( + sorted_inputs, sorted_outputs, graph + ) + + def replace_input(inp): + assert isinstance(spec, InputSpec) + if spec.type == "user_input": + arg = spec.user_input.arg + if arg.type == "as_tensor": + t = arg.as_tensor + t.name = replace_table[t.name] + elif arg.type == "as_sym_int": + s = arg.as_sym_int + if s.type == "as_name": + s.as_name = replace_table[s.as_name] + elif s.type == "as_int": + pass + else: + raise AssertionError(f"Unknown sym_int type: {s}") + elif arg.type in ( + "as_none", + "as_bool", + "as_int", + "as_float", + "as_string", + "as_custom_obj", + ): + return + else: + raise AssertionError(f"Unknown input type: {arg}") + elif spec.type == "parameter": + t = spec.parameter.arg + t.name = replace_table[t.name] + elif spec.type == "buffer": + t = spec.buffer.arg + t.name = replace_table[t.name] + elif spec.type == "tensor_constant": + t = spec.tensor_constant.arg + t.name = replace_table[t.name] + elif spec.type == "custom_obj": + return + elif spec.type == "token": + tok = spec.token.arg + tok.name = replace_table[tok.name] + elif spec.type == "constant_input": + return + else: + raise AssertionError(f"Unknown input type: {spec}") + + def replace_output(out): + assert isinstance(spec, OutputSpec) + if spec.type == "user_output": + arg = spec.user_output.arg + if arg.type == "as_tensor": + t = arg.as_tensor + t.name = replace_table[t.name] + elif arg.type == "as_sym_int": + s = arg.as_sym_int + if s.type == "as_name": + s.as_name = replace_table[s.as_name] + elif s.type == "as_int": + pass + else: + raise AssertionError(f"Unknown sym_int type: {s}") + elif 
arg.type in ("as_none", "as_int", "as_float", "as_string"): + return + else: + raise AssertionError(f"Unknown input type: {arg}") + elif spec.type == "loss_output": + t = spec.loss_output.arg + t.name = replace_table[t.name] + elif spec.type == "buffer_mutation": + t = spec.buffer_mutation.arg + t.name = replace_table[t.name] + elif spec.type == "gradient_to_parameter": + t = spec.gradient_to_parameter.arg + t.name = replace_table[t.name] + elif spec.type == "gradient_to_user_input": + g = spec.gradient_to_user_input + g.arg.name = replace_table[g.arg.name] + g.user_input_name = replace_table[g.user_input_name] + elif spec.type == "user_input_mutation": + u = spec.user_input_mutation + u.arg.name = replace_table[u.arg.name] + u.user_input_name = replace_table[u.user_input_name] + elif spec.type == "token": + tok = spec.token.arg + tok.name = replace_table[tok.name] + else: + raise AssertionError(f"Unknown output type: {spec}") + + for spec in input_specs: + replace_input(spec) + + for spec in output_specs: + replace_output(spec) + + return ExportedProgram( + graph_module=GraphModule( + graph=sorted_graph, + signature=GraphSignature( + input_specs=list(input_specs), + output_specs=list(output_specs), + ), + module_call_graph=module_call_graph, + ), + opset_version=opset_version, + range_constraints=range_constraints, + schema_version=ep.schema_version, + verifiers=ep.verifiers, + torch_version=ep.torch_version, + ) + + +class ExtensionHandler: + """ + Base class for handling extension operators. 
+ """ + @classmethod + def namespace(cls) -> str: + raise NotImplementedError(f"{cls.__class__} namespace() must be implemented") + + @classmethod + def to_op_name(cls, op) -> str: + raise NotImplementedError(f"{cls.__class__} op_name() must be implemented") + + @classmethod + def from_op_name(cls, name: str): + raise NotImplementedError(f"{cls.__class__} op_name() must be implemented") + + @classmethod + def op_schema(cls, op) -> torch.FunctionSchema: + raise NotImplementedError(f"{cls.__class__} op_schema() must be implemented") + + +def register_extension( + op_type: Type[Any], + extension_handler: Type[ExtensionHandler], +): + """Register custom de/serialization method for a node with non-standard type.""" + assert issubclass(extension_handler, ExtensionHandler), f"Expected ExtensionHandler, got {extension_handler}." + assert op_type not in _serialization_registry, f"{op_type} is already registered." + assert isinstance(op_type, type) # Maybe a good idea to enforce this first. + assert not (op_type.__module__.startswith("torch") or op_type.__module__.startswith("builtins")) + assert extension_handler.namespace() not in _deserialization_registry + _serialization_registry[op_type] = extension_handler + _deserialization_registry[extension_handler.namespace()] = extension_handler + + +def _registered_extension_types(): + return tuple( + _serialization_registry.keys() + ) + + +# Registry to store all custom serialization implementations. +# The registry maps a operation to its serialization function (a callable), in their own +# namespace to avoid conflicts. +# Serialization: Op type --> custom handler. +# De-serialization: Namespace --> custom handler. 
+_serialization_registry: Dict[Type[Any], Type[ExtensionHandler]] = {} +_deserialization_registry: Dict[str, Type[ExtensionHandler]] = {} diff --git a/.venv/lib/python3.11/site-packages/torch/_export/serde/union.py b/.venv/lib/python3.11/site-packages/torch/_export/serde/union.py new file mode 100644 index 0000000000000000000000000000000000000000..b129e8dd9a89ef4870d7ef3dc724aca8ccae3a43 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/serde/union.py @@ -0,0 +1,70 @@ +# mypy: allow-untyped-defs +import functools +from dataclasses import fields +from typing import Hashable, Set + + +class _UnionTag(str): + _cls: Hashable + + @staticmethod + def create(t, cls): + tag = _UnionTag(t) + assert not hasattr(tag, "_cls") + tag._cls = cls + return tag + + def __eq__(self, cmp) -> bool: + assert isinstance(cmp, str) + other = str(cmp) + assert other in _get_field_names( + self._cls + ), f"{other} is not a valid tag for {self._cls}. Available tags: {_get_field_names(self._cls)}" + return str(self) == other + + def __hash__(self): + return hash(str(self)) + + +@functools.lru_cache(maxsize=None) +def _get_field_names(cls) -> Set[str]: + return {f.name for f in fields(cls)} + + +class _Union: + _type: _UnionTag + + @classmethod + def create(cls, **kwargs): + assert len(kwargs) == 1 + obj = cls(**{**{f.name: None for f in fields(cls)}, **kwargs}) # type: ignore[arg-type] + obj._type = _UnionTag.create(next(iter(kwargs.keys())), cls) + return obj + + def __post_init__(self): + assert not any(f.name in ("type", "_type", "create", "value") for f in fields(self)) # type: ignore[arg-type, misc] + + @property + def type(self) -> str: + try: + return self._type + except AttributeError as e: + raise RuntimeError( + f"Please use {type(self).__name__}.create to instantiate the union type." 
+ ) from e + + @property + def value(self): + return getattr(self, self.type) + + def __getattribute__(self, name): + attr = super().__getattribute__(name) + if attr is None and name in _get_field_names(type(self)) and name != self.type: # type: ignore[arg-type] + raise AttributeError(f"Field {name} is not set.") + return attr + + def __str__(self): + return self.__repr__() + + def __repr__(self): + return f"{type(self).__name__}({self.type}={getattr(self, self.type)})" diff --git a/.venv/lib/python3.11/site-packages/torch/_lazy/closure.py b/.venv/lib/python3.11/site-packages/torch/_lazy/closure.py new file mode 100644 index 0000000000000000000000000000000000000000..94c12c075a092b9f70db02e5f280f38c6f94f050 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_lazy/closure.py @@ -0,0 +1,135 @@ +# mypy: allow-untyped-defs +import os +import threading +from queue import Empty as EmptyQueue, Queue + +from torch._lazy.device_context import get_device_context + + +class ClosureHandler: + def __init__(self) -> None: + pass + + def run(self, closure): + """Run closure function + + Args: + closure: callable function to run + """ + closure() + + def __call__(self, closures): + for closure in closures: + self.run(closure) + + +class AsyncClosureHandler(ClosureHandler): + """Handler for Asynchronous Step Closures + Args: + max_queue_size: The maximum length of the closure queue after which + the training loop will block until closures are evaluated. + By default, a reasonable limit of a maximum of 100 on the queue. + This value can be set using the `XLA_MAX_ASYNC_QUEUE` environment + variable. 
+ """ + + def __init__(self, max_queue_size=100): + super().__init__() + self._closure_queue: Queue = Queue( + int(os.environ.get("LTC_MAX_ASYNC_QUEUE", max_queue_size)) + ) + self._closure_exception: Queue = Queue() + self._closure_lock = threading.Lock() + self._closure_event_loop_finished = threading.Event() + self._closure_event_loop = None + + def start_event_loop(self): + """Start closure event loop if not started""" + if self._closure_event_loop is None: + + def event_loop(): + # Run loop until closure event is set and closure queue is empty + while True: + try: + closure = self._closure_queue.get(block=True, timeout=3) + closure() + self._closure_queue.task_done() + except EmptyQueue: + with self._closure_lock: + if self._closure_queue.empty(): + self._closure_event_loop_finished.set() + return + except Exception as e: + self._closure_exception.put(e) + return + + self._closure_event_loop = threading.Thread(target=event_loop) + self._closure_event_loop.start() + + def run(self, closure): + with self._closure_lock: + self._closure_queue.put(closure, block=True) + if ( + self._closure_event_loop is None + or not self._closure_event_loop.is_alive() + ): + try: + e = self._closure_exception.get(block=False) + raise RuntimeError( + "Cannot run asynchronous closure due to previously raised exception" + ) from e + except EmptyQueue: + self._closure_event_loop = None + self.start_event_loop() + + +def add_step_closure(closure, args=(), run_async=False): + """Adds a closure to the list of the ones to be run at the end of the step. + Many times during model training there is the need to print/report (print to + console, post to tensorboard, etc...) information which require the content of + intermediary tensors to be inspected. + Inspecting different tensors content in different points of the model code + requires many executions and typically causes performance issues. 
+ Adding a step closure will ensure that it will be run after the barrier, when + all the live tensors will be already materialized to device data. + Live tensors which will include the ones captured by the closure arguments. + So using `add_step_closure()` will ensure a single execution will be + performed, even when multiple closures are queued, requiring multiple tensors + to be inspected. + Step closures will be run sequentially in the order they have been queued. + Note that even though using this API the execution will be optimized, it is + advised to throttle the printing/reporting events once every N steps. + Args: + closure (callable): The function to be called. + args (tuple): The arguments to be passed to the closure. + run_async: If True, run the closure asynchronously. + """ + devctx = get_device_context() + closures_type = "async_step_closures" if run_async else "step_closures" + step_closures = getattr(devctx, closures_type, None) + if step_closures is None: + step_closures = [] + setattr(devctx, closures_type, step_closures) + step_closures.append(lambda a=args: closure(*a)) + + +def run_step_closures(): + devctx = get_device_context() + async_step_closures = getattr(devctx, "async_step_closures", None) + if async_step_closures is not None: + devctx.async_step_closures = [] + async_closure_handler = getattr(devctx, "async_closure_handler", None) + if async_closure_handler is None: + async_closure_handler = AsyncClosureHandler() + devctx.async_closure_handler = async_closure_handler + async_closure_handler(async_step_closures) + + step_closures = getattr(devctx, "step_closures", None) + if step_closures is not None: + devctx.step_closures = [] + closure_handler = getattr(devctx, "closure_handler", None) + if closure_handler is None: + closure_handler = ClosureHandler() + devctx.closure_handler = closure_handler + closure_handler(step_closures) + return devctx diff --git a/.venv/lib/python3.11/site-packages/torch/_lazy/ir_cache.py 
b/.venv/lib/python3.11/site-packages/torch/_lazy/ir_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..a6e654566f29bce166eb52e721b694f3b1f7862b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_lazy/ir_cache.py @@ -0,0 +1,14 @@ +# mypy: allow-untyped-defs +import torch._C._lazy + + +def dump(dot_file_name: str): + """Dump TrieCache in the dot format""" + return torch._C._lazy._dump_ir_cache(dot_file_name) + + +def reset(): + """Clear TrieCache. This is needed in testing to avoid + node reusing between different tests. + """ + return torch._C._lazy._clear_ir_cache() diff --git a/.venv/lib/python3.11/site-packages/torch/_lazy/tensor_factory_functions.py b/.venv/lib/python3.11/site-packages/torch/_lazy/tensor_factory_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..3b8ddc8b11c7e036ba6beac440d04eb1835b26d4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_lazy/tensor_factory_functions.py @@ -0,0 +1,49 @@ +import torch + + +""" +tensor_factory_functions defines the list of torch functions that create tensors. +The list is grabbed by searching thru native_functions.yaml by the following +regular expression: + + cat native_functions.yaml | grep 'func:' | grep -v "Tensor.*->" | grep "[-]>.*Tensor" + +It's possible that new tensor factory functions are added making this list stale. +Use at your own risk or regenerate the list. 
+""" +tensor_factory_functions = ( + torch._cudnn_init_dropout_state, + torch.arange, + torch.bartlett_window, + torch.blackman_window, + torch._empty_affine_quantized, + torch.empty_strided, + torch.eye, + torch.full, + torch.from_file, + torch.hann_window, + torch.hamming_window, + torch.kaiser_window, + torch.linspace, + torch.logspace, + torch.ones, + torch.scalar_tensor, + torch.rand, + torch.randint, + torch.randn, + torch.randperm, + torch.range, + torch._efficientzerotensor, + torch.zeros, + torch.tril_indices, + torch.triu_indices, + # Note: the following functions match the regular expression search above but + # they are not available in the torch module. Comment out. + # torch._sparse_coo_tensor_with_dims, + # torch.fft_fftfreq, + # torch.fft_rfftfreq, +) + ( + # torch.tensor is special since it's not in native_functions.yaml + # add it separately + torch.tensor, +) diff --git a/.venv/lib/python3.11/site-packages/torch/_prims_common/__init__.py b/.venv/lib/python3.11/site-packages/torch/_prims_common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..61d0ba13b88f1515fa1cdd4534b166e55beaf826 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_prims_common/__init__.py @@ -0,0 +1,1996 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import operator +import warnings +from contextlib import nullcontext +from enum import Enum +from functools import reduce +from typing import ( + Any, + Callable, + cast, + List, + NamedTuple, + Optional, + overload, + Sequence, + Tuple, + Type, + TYPE_CHECKING, + Union, +) +from typing_extensions import deprecated, TypeAlias + + +if TYPE_CHECKING: + # Import the following modules during type checking to enable code intelligence features, + # such as auto-completion in tools like pylance, even when these modules are not explicitly + # imported in user code. 
+ + import sympy + +import torch +from torch import sym_float, sym_int, sym_max + + +ShapeType: TypeAlias = Union[torch.Size, List[int], Tuple[int, ...]] +StrideType: TypeAlias = Union[List[int], Tuple[int, ...]] +DimsType: TypeAlias = Union[int, List[int], Tuple[int, ...]] +DimsSequenceType: TypeAlias = Union[List[int], Tuple[int, ...]] +# TODO: Type[torch.SymInt], Type[torch.SymFloat] +NumberTypeType: TypeAlias = Union[Type[bool], Type[int], Type[float], Type[complex]] +# TODO: This needs a lot more type annotations +# NumberType = Union[bool, int, float, complex, torch.SymInt, torch.SymFloat] +NumberType: TypeAlias = Union[bool, int, float, complex] +RealNumberType: TypeAlias = Union[bool, int, float] + +Number = (bool, int, float, complex, torch.SymInt, torch.SymFloat, torch.SymBool) +# I don't call it Integral because numbers.Integral includes bool, but IntLike +# does not +Dim = int +IntLike = (int, torch.SymInt) +FloatLike = (float, torch.SymFloat) +BoolLike = (bool, torch.SymBool) +IntWithoutSymInt = int +FloatWithoutSymFloat = float +DeviceLikeType: TypeAlias = Union[str, torch.device, int] +Tensor = torch.Tensor + + +torch_function_passthrough = { + torch.device, + torch.sym_not, + torch.sym_float, + torch.sym_int, + torch.sym_max, + torch.sym_min, + torch._sym_sqrt, # type: ignore[attr-defined] + torch.sym_ite, + torch.Tensor.dim, + torch.Tensor.ndim.__get__, # type: ignore[attr-defined] + torch.Tensor.numel, + torch.Tensor.size, + torch.Tensor.storage_offset, + torch.Tensor.stride, + torch.Tensor.dtype.__get__, # type: ignore[attr-defined] + torch.Tensor.is_sparse.__get__, # type: ignore[attr-defined] + torch.Tensor.shape.__get__, # type: ignore[attr-defined] + torch.Tensor.device.__get__, # type: ignore[attr-defined] + torch.Tensor.requires_grad.__get__, # type: ignore[attr-defined] + torch.Tensor.layout.__get__, # type: ignore[attr-defined] + torch.Tensor.is_contiguous, + # For TorchRefsMode only + torch.Tensor.__format__, + torch.Tensor.__repr__, + 
torch.Tensor.requires_grad.__get__, # type: ignore[attr-defined] + torch.Tensor.__getitem__, +} + + +TensorLikeType = torch.Tensor +TensorLike = torch.Tensor +TensorSequenceType: TypeAlias = Union[List[TensorLikeType], Tuple[TensorLikeType, ...]] +TensorOrNumberLikeType: TypeAlias = Union[TensorLikeType, NumberType] + +CustomOutParamAnnotation = "__custom_out_param__" + + +def same_shape(a: ShapeType, b: ShapeType, *, allow_rhs_unbacked=False) -> bool: + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if len(a) != len(b): + return False + + for x, y in zip(a, b): + if allow_rhs_unbacked: + # TODO: We should check that the symbols are consistent + # with each other + if isinstance(y, torch.SymInt): + continue + # NB: Naively, you would not expect to have to do an oblivious guard + # here because there is seemingly no broadcasting here, but in fact we + # use this in some situations to determine if we need to do an expand + # on the tensor because they don't line up, so you can definitely end + # up trying to prove u0 != 1 in this situation. See + # python test/test_proxy_tensor.py -k test_cumsum_unbacked + if guard_size_oblivious(x != y): + return False + + return True + + +def _maybe_get_pytype(t): + if t is torch.SymFloat: + return float + elif t is torch.SymInt: + return int + elif t is torch.SymBool: + return bool + else: + return t + + +# TODO: look at using torch.testing.assert_close instead with an option +# to just compare metadata +def compare_tensor_meta( + a: TensorLikeType, + b: TensorLikeType, + check_strides=False, + *, + allow_rhs_unbacked=False, + check_conj=True, +): + """ + Checks that two tensor likes have the same shape, + dtype and device. + + In the future this will validate additional metadata, like + strides. 
+ """ + assert isinstance(a, TensorLike) + assert isinstance(b, TensorLike) + + if not same_shape(a.shape, b.shape, allow_rhs_unbacked=allow_rhs_unbacked): + msg = f"Shapes {a.shape} and {b.shape} are not equal!" + raise AssertionError(msg) + + if a.dtype != b.dtype: + msg = f"Dtypes {a.dtype} and {b.dtype} are not equal!" + raise AssertionError(msg) + + if a.device != b.device: + # Handles special cuda:0 vs cuda case + # TODO: we should review why this happens and see about fixing it + if (str(a.device) == "cuda:0" or str(a.device) == "cuda") and ( + str(b.device) == "cuda:0" or str(b.device) == "cuda" + ): + pass + else: + msg = f"Devices {a.device} and {b.device} are not equal!" + raise AssertionError(msg) + + # Stride checking is currently disabled, see https://github.com/pytorch/pytorch/issues/78050 + if check_strides: + same_strides, idx = check_significant_strides(a, b) + if not same_strides: + msg = f"Stride mismatch! Strides are {a.stride()} and {b.stride()} (mismatched at {idx})!" + raise RuntimeError(msg) + + if a.storage_offset() != b.storage_offset(): + msg = f"Storage offset mismatch! Storage offsets are {a.storage_offset()} and {b.storage_offset()}!" + raise RuntimeError(msg) + + if check_conj: + if a.is_conj() != b.is_conj(): + raise RuntimeError( + f"Conj mismatch! is_conj is set to {a.is_conj()} and {b.is_conj()}" + ) + + if a.is_neg() != b.is_neg(): + raise RuntimeError( + f"Neg mismatch! 
is_neg is set to {a.is_neg()} and {b.is_neg()}" + ) + + +def _check_strides_helper( + a: TensorLikeType, b: TensorLikeType, *, only_cuda=True, significant_only=True +) -> Tuple[bool, Optional[int]]: + # NOTE: only on CUDA because CPU elementwise strides are incorrect in PyTorch + # See https://github.com/pytorch/pytorch/issues/77553 + # Only compares strides that are "meaningful" -- strides for dimensions with length > 1 + # and for tensors with more than one element + if ( + not only_cuda or a.device.type == "cuda" or b.device.type == "cuda" + ) and a.numel() > 0: + for idx in range(a.ndim): + check = not significant_only or a.shape[idx] > 1 + if a.stride()[idx] != b.stride()[idx] and check: + return False, idx + + return True, None + + +def check_significant_strides( + a: TensorLikeType, b: TensorLikeType, *, only_cuda=True +) -> Tuple[bool, Optional[int]]: + return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=True) + + +def check_all_strides( + a: TensorLikeType, b: TensorLikeType, *, only_cuda=True +) -> Tuple[bool, Optional[int]]: + return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=False) + + +# This function is equivalent to compute_contiguous() from TensorImpl.cpp +def is_contiguous(a: TensorLikeType) -> bool: + """ + Tests whether a tensor is contiguous or not. + + Tensors are contiguous when they have no elements, + one element, or when they have "nested" strides. 
+ """ + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if guard_size_oblivious(a.numel() < 2): + return True + + expected_stride = 1 + for x, y in reversed(tuple(zip(a.shape, a.stride()))): + # Skips checking strides when a dimension has length 1 + if guard_size_oblivious(x == 1): + continue + + if guard_size_oblivious(y != expected_stride): + return False + expected_stride = expected_stride * x + + return True + + +# This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp +def is_channels_last_contiguous_2d(a: Tensor) -> bool: + # NHWC or not channels last 2D contiguous + if a.ndim != 4: + return False + + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + expected_stride = 1 + for idx in (1, 3, 2, 0): + length = a.shape[idx] + if guard_size_oblivious(length == 1): + continue + + stride = a.stride()[idx] + if guard_size_oblivious(stride != expected_stride): + return False + + expected_stride *= length + + return True + + +def is_channels_last_contiguous_3d(a: Tensor) -> bool: + # NDHWC or not channels last 3D contiguous + if a.ndim != 5: + return False + + expected_stride = 1 + for idx in (1, 4, 3, 2, 0): + length = a.shape[idx] + if length == 1: + continue + + stride = a.stride()[idx] + if stride != expected_stride: + return False + + expected_stride *= length + + return True + + +_memory_formats = { + torch.contiguous_format, + torch.preserve_format, + torch.channels_last, + torch.channels_last_3d, +} + + +def validate_memory_format(memory_format: torch.memory_format): + torch._check( + memory_format in _memory_formats, + lambda: f"Received unknown memory format {memory_format}!", + ) + + +def is_contiguous_for_memory_format( # type: ignore[return] + a: Tensor, *, memory_format: torch.memory_format +) -> bool: + validate_memory_format(memory_format) + + if memory_format == torch.contiguous_format: + return is_contiguous(a) + if memory_format == torch.channels_last: + return 
is_channels_last_contiguous_2d(a) + if memory_format == torch.channels_last_3d: + return is_channels_last_contiguous_3d(a) + + torch._check( + False, + lambda: f"is_contiguous received unsupported memory format {memory_format}", + ) + + +# NOTE: that tensors with no elements and channels last is ??? +def is_channels_last_contiguous(a: Tensor) -> bool: + """ + True when a tensor is channels-last contiguous. + + This requires that: + + - the tensor is conceptually either 4 (NHWC) or 5 (NDHWC) dimensions + - if we name the tensor's dimensions NCHW or NCDHW, then the strides are such that the + stride of the 'C' dimension (Cs) is 1 and the strides corresponding to + each dimension (Xs) can be ordered Cs <= Ws <= Hs <= (Ds) <= Ns and are + "nested" -- so Ws = Cs * Cl, where Cl is the length of the 'C' dimension, + for example. + """ + return is_channels_last_contiguous_2d(a) or is_channels_last_contiguous_3d(a) + + +def is_non_overlapping_and_dense(a: Tensor) -> bool: + """ + True when a tensor is non-overlapping and dense. + + A tensor is non-overlapping and dense when there exists a permutation of + its dimensions that is contiguous. + """ + + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if a.is_sparse: + return False + + # Short-circuits if the tensor is already contiguous or channels-last contiguous + if is_contiguous(a) or is_channels_last_contiguous(a): + return True + + # The following is equivalent to compute_non_overlapping_and_dense in TensorImpl.cpp + + # Short-circuits for tensors of rank one, which are + # non-overlapping and "dense" if their stride is one + if a.ndim == 1: + return a.stride()[0] == 1 + + # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous + # Sorts (length, stride) pairs by stride + # + # This sort is done in a size-oblivious way, which helps if we do a + # comparison like 2048*u0 > u0; we just want this to return True + # (and not worry about what if u0 is zero). 
+ class K(NamedTuple): + size: int + stride: int + + def __lt__(self, other): + return guard_size_oblivious(self.stride < other.stride) + + def __gt__(self, other): + return guard_size_oblivious(self.stride > other.stride) + + def __le__(self, other): + return guard_size_oblivious(self.stride <= other.stride) + + def __ge__(self, other): + return guard_size_oblivious(self.stride >= other.stride) + + def __eq__(self, other): + return guard_size_oblivious(self.stride == other.stride) + + lengths_and_strides = sorted(map(K, a.shape, a.stride())) + + expected_stride = 1 + for length, stride in lengths_and_strides: + if guard_size_oblivious(length == 1): + continue + + if stride != expected_stride: + return False + + expected_stride *= length + + return True + + +# NOTE: Based on the implementation in TensorIterator.cpp, but note that +# the note [Computing output strides] is incorrect, because it +# says that strides will be preserved even if they are not +# "non overlapping and dense", but this is incorrect. The +# output of elementwise operations are always given +# non overlapping and dense strides. +# This is also INCORRECT because it does not model TensorIterator's +# short-circuit, which can cause different strides. +def compute_elementwise_output_logical_to_physical_perm( + *tensors, _skip_checks=False +) -> List[int]: + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if not _skip_checks and len(tensors) == 0: + msg = "Can't compute elementwise output strides for zero tensors!" + raise ValueError(msg) + + if not _skip_checks: + check_same_shape(*tensors, allow_cpu_scalar_tensors=True) + + # Filters the tensors to actual tensors + if not _skip_checks: + tensors = tuple( + a + for a in tensors + if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a) + ) + + # Short-circuits for CPU scalar case + if len(tensors) == 0: + return [] + + # Short-circuits for shapes with zero or one dimensions + # TODO: are these necessary? 
+ ndim = tensors[0].ndim + if ndim == 0: + return [] + if ndim == 1: + return [0] + + # Short-circuits if contiguous or channels last, following the fake fast path. + # This reduces the number of guards we end up making + is_contiguous = True + is_channels_last = True + for t in tensors: + is_contiguous = is_contiguous and t.is_contiguous( + memory_format=torch.contiguous_format + ) + is_channels_last = is_channels_last and t.is_contiguous( + memory_format=torch.channels_last + ) + + if is_contiguous and not is_channels_last: + return list(range(ndim)) + + if is_channels_last and not is_contiguous: + return [0, *list(range(2, ndim)), 1] + + shape = tensors[0].shape + + def should_swap(idx_a, idx_b): + for tensor in tensors: + stride_a = tensor.stride()[idx_a] + stride_b = tensor.stride()[idx_b] + + if guard_size_oblivious(stride_a == 0) or guard_size_oblivious( + stride_b == 0 + ): + continue + + if guard_size_oblivious(stride_a < stride_b): + return -1 + + if guard_size_oblivious(stride_a > stride_b): + return 1 + + # stride_a == stride_b + if guard_size_oblivious(shape[idx_a] > shape[idx_b]): + return 1 + + # Note: this case is hit if all strides are zero, + # or all strides are equal and all dimensions have the same length + return 0 + + # The "sort" order for the permutation is back-to-front, but + # the natural order for permutations is front-to-back. Do the + # sorting back-to-front and then reverse it on output. 
+ # + # also, note this returns the logical to physical shape permutation + perm = list(reversed(range(ndim))) + + # insertion sort with support for ambiguous comparisons + for i in range(1, ndim): + dim1 = i + for dim0 in reversed(range(i)): + comparison = should_swap(perm[dim0], perm[dim1]) + if comparison > 0: + perm[dim0], perm[dim1] = perm[dim1], perm[dim0] + dim1 = dim0 + elif comparison < 0: + break + + return list(reversed(perm)) + + +def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]: + """ + Computes the output strides for elementwise operations. + """ + if len(tensors) == 0: + msg = "Can't compute elementwise output strides for zero tensors!" + raise ValueError(msg) + + check_same_shape(*tensors, allow_cpu_scalar_tensors=True) + + # Filters the tensors to actual tensors + tensors = tuple( + a for a in tensors if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a) + ) + + # Short-circuits for CPU scalar case + if len(tensors) == 0: + return () + + ndim = tensors[0].ndim + shape = tensors[0].shape + + if ndim == 0: + return () + if ndim == 1: + return (1,) + + logical_to_physical_perm = compute_elementwise_output_logical_to_physical_perm( + *tensors, _skip_checks=True + ) + permuted_shape = apply_perm(shape, logical_to_physical_perm) # to physical + + new_strides = make_contiguous_strides_for(permuted_shape) + permuted_strides = apply_perm( + new_strides, invert_perm(logical_to_physical_perm) + ) # to logical + + return tuple(permuted_strides) + + +# Identity permutation is [0, 1, 2] +def apply_perm(inp, perm): + ndim = len(inp) + permuted_inp = [-1] * ndim + for idx, x in enumerate(perm): + permuted_inp[idx] = inp[x] + return permuted_inp + + +def invert_perm(perm): + ndim = len(perm) + new_perm = [-1] * ndim + for idx, x in enumerate(perm): + new_perm[x] = idx + return new_perm + + +# +# Common helper functions +# + + +def validate_dim_length(length: int): + """ + Validates that an object represents a valid + dimension length. 
+ """ + + if isinstance(length, (int, torch.SymInt)): + torch._check_is_size(length) + else: + # sometimes called with sympy expression by inductor + assert length >= 0 + + +def validate_shape(shape: ShapeType): + """ + Validates that a sequence represents a valid shape. + """ + + assert isinstance(shape, Sequence), type(shape) + for l in shape: + validate_dim_length(l) + + +def validate_strides(strides: StrideType): + """ + Verifies the object specifies valid strides. + """ + + assert isinstance(strides, Sequence) + for stride in strides: + assert stride >= 0 + + +def validate_idx(rank: int, idx: int): + """ + Validates that idx is a valid index for the given shape. + Assumes the index is already canonicalized. + """ + + assert isinstance(idx, Dim) + assert isinstance(rank, Dim) + + assert idx >= 0 and idx < rank or idx == 0 + + +def validate_dimension_indices(rank: int, indices: DimsSequenceType): + for idx in indices: + validate_idx(rank, idx) + + +def validate_exclusive_idx(rank: int, ex_idx: int): + """ + Validates that ex_idx is a valid exclusive index + for the given shape. + """ + + assert isinstance(ex_idx, Dim) + assert isinstance(rank, Dim) + assert ex_idx > 0 and ex_idx <= rank + + +# "Wraps" a dim (up to one time) for the given rank, allowing dims to be +# specified using negative indices. If `wrap_scalar` is true then scalar +# tensors of rank 0 will allow dimensions in the range [-1, 0]. Otherwise, +# idx should be in the range [-rank, rank-1]. 
+def canonicalize_dim(rank: int, idx: int, wrap_scalar: bool = True) -> int: + if rank < 0: + msg = f"Rank cannot be negative but got {rank}" + raise IndexError(msg) + + if rank == 0: + if not wrap_scalar: + msg = f"Dimension specified as {idx} but tensor has no dimensions" + raise IndexError(msg) + rank = 1 + + if idx >= 0 and idx < rank: + return idx + + if idx < 0: + _idx = idx + rank + else: + _idx = idx + + if _idx < 0 or _idx >= rank: + # Same error message as in aten/src/ATen/WrapDimUtils.h:49 + msg = f"Dimension out of range (expected to be in range of [{-rank}, {rank - 1}], but got {idx})" + raise IndexError(msg) + + return _idx + + +# Takes a dimension or sequence of dimensions and "wraps" them, +# mapping negative offsets to positive ones +@overload +def canonicalize_dims( + rank: int, indices: Sequence[int], wrap_scalar: bool = True +) -> Tuple[int, ...]: + pass + + +@overload +def canonicalize_dims(rank: int, indices: int, wrap_scalar: bool = True) -> int: + pass + + +def canonicalize_dims(rank, indices, wrap_scalar=True): + if isinstance(indices, Dim): + return canonicalize_dim(rank, indices, wrap_scalar) + + return tuple(canonicalize_dim(rank, x, wrap_scalar) for x in indices) + + +def is_valid_permutation(rank: int, perm: DimsSequenceType) -> bool: + """ + Validates that perm is a permutation of length rank. + """ + + return isinstance(perm, Sequence) and sorted(perm) == list(range(rank)) + + +def is_same_shape(a: Sequence, b: Sequence) -> bool: + """ + Compares two shapes a and b, returning True if they are the same + (their ranks and corresponding lengths match) and False otherwise. + """ + + return tuple(a) == tuple(b) + + +def is_cpu_scalar_tensor(a: Any) -> bool: + return isinstance(a, TensorLike) and a.ndim == 0 and a.device.type == "cpu" + + +def check_same_device(*args, allow_cpu_scalar_tensors): + """ + Checks that all Tensors in args have the same device. 
+ + Raises a RuntimeError when: + - args contains an object whose type is not Tensor or Number + - two Tensor objects in args have different devices, unless one is a CPU scalar tensor and allow_cpu_scalar_tensors is True + """ + # Short-circuits if all (one or fewer) arguments are trivially on the same device + if len(args) <= 1: + return + + # Note: cannot initialize device to the first arg's device (it may not have one) + device = None + for arg in args: + if isinstance(arg, Number): + continue + elif isinstance(arg, TensorLike): + if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg): + continue + + if device is None: + device = arg.device + + if device != arg.device: + msg = ( + "Tensor on device " + + str(arg.device) + + " is not on the expected device " + + str(device) + + "!" + ) + raise RuntimeError(msg) + else: + msg = ( + "Unexpected type when checking for same device, " + str(type(arg)) + "!" + ) + raise RuntimeError(msg) + + +def canonicalize_device(device: DeviceLikeType) -> torch.device: + if isinstance(device, torch.device): + return device + + assert isinstance(device, str) + return torch.device(device) + + +# Asserts if any of the following are true: +# - a non-scalar or non-Tensor is given +# - the shape of any tensors is distinct +def check_same_shape(*args, allow_cpu_scalar_tensors: bool): + """ + Checks that all Tensors in args have the same shape. + + Raises a RuntimeError when: + - args contains an object whose type is not Tensor or Number + - two Tensor objects in args have different devices + """ + shape = None + + for arg in args: + if isinstance(arg, Number): + continue + elif isinstance(arg, TensorLike): + if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg): + continue + + if shape is None: + shape = arg.shape + + if not is_same_shape(shape, arg.shape): + msg = f"Shape {arg.shape} is not the expected shape {shape}!" 
+ raise RuntimeError(msg) + else: + msg = ( + "Unexpected type when checking for same shape, " + str(type(arg)) + "!" + ) + raise RuntimeError(msg) + + +# Acquires a common shape, if it exists, from one or more tensor arguments, +# filtering number arguments +def extract_shape(*args, allow_cpu_scalar_tensors: bool) -> Optional[ShapeType]: + shape = None + scalar_shape = None + + for arg in args: + if isinstance(arg, Number): + continue + elif isinstance(arg, TensorLike): + if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg): + scalar_shape = arg.shape + continue + + if shape is None: + shape = arg.shape + + if not is_same_shape(shape, arg.shape): + return None + else: + return None + + return shape if shape is not None else scalar_shape + + +# Extracts dimensions that might be passed either as a list/tuple or as varargs. +# A typical case is Tensor.permute . +def extract_dims_from_varargs( + dims: Union[DimsSequenceType, Tuple[DimsSequenceType, ...]] +) -> DimsSequenceType: + if dims and isinstance(dims[0], Sequence): + assert len(dims) == 1 + dims = cast(Tuple[DimsSequenceType], dims) + return dims[0] + else: + return cast(DimsSequenceType, dims) + + +def extract_shape_from_varargs( + shape: Union[ShapeType, Tuple[ShapeType]], + validate=True, +) -> Tuple[int, ...]: + """ + Returns a shape from varargs. + + In PyTorch, operations that accept shapes often accept them as varargs, like + foo(*shape). However a user can pass the shape as a sequence of integers, + like this: + + foo(1, 2, 3) + + or as a sequence of integers + + foo((1, 2, 3)) + + In the first case shape will be a tuple of integers, and in the second case it's a tuple + containing a tuple of integers. This validates those inputs and canonicalizes them + to a tuple of integers. 
+ """ + + # Handles tuple unwrapping + if len(shape) == 1 and isinstance(shape[0], Sequence): + shape = shape[0] + + if validate: + validate_shape(shape) # type: ignore[arg-type] + return shape # type: ignore[return-value] + + +def infer_size_shapes(a: ShapeType, b: ShapeType) -> Tuple[int, ...]: + ndim = max(len(a), len(b)) + expandedSizes = [0] * ndim + + for i in range(ndim - 1, -1, -1): + offset = ndim - 1 - i + dimA = len(a) - 1 - offset + dimB = len(b) - 1 - offset + sizeA = a[dimA] if dimA >= 0 else 1 + sizeB = b[dimB] if dimB >= 0 else 1 + + torch._check( + (sizeA == sizeB) or (sizeA == 1) or (sizeB == 1), + lambda: ( + f"The size of tensor a ({sizeA}) must match the size of " + f"tensor b ({sizeB}) at non-jagged dimension {i}" + ), + ) + + # 1s map to the other size (even 0) + expandedSizes[i] = sizeB if sizeA == 1 else sizeA + + return tuple(expandedSizes) + + +def infer_size(shape: ShapeType, numel: int) -> Tuple[int, ...]: + """ + Infers the size of a dim with size -1, if it exists. + Also checks that new shape is compatible with the number of elements. 
+ """ + dim = None + newsize = 1 + for i, d in enumerate(shape): + if d == -1: + torch._check(dim is None, lambda: "only one dimension can be inferred") + dim = i + elif d >= 0: + newsize *= d + else: + torch._check(False, lambda: f"invalid shape dimension {d}") + if dim is None: + torch._check( + numel == newsize, + lambda: f"shape '{list(shape)}' is invalid for input of size {numel}", + ) + else: + from torch.fx.experimental.symbolic_shapes import definitely_true + + torch._check( + newsize != 0, + lambda: ( + f"cannot reshape tensor of 0 elements into shape {list(shape)} because the " + f"unspecified dimension size -1 can be any value and is ambiguous" + if definitely_true(numel == 0) + else f"shape '{list(shape)}' is invalid for input of size {numel}" + ), + ) + torch._check( + numel % newsize == 0, + lambda: f"shape '{list(shape)}' is invalid for input of size {numel}", + ) + # Convert to list to produce a compatible error message with core + # PyTorch, which prints sequences in square brackets. + shape = list(shape) + shape[dim] = numel // newsize + # NB: This is pretty important when you have unbacked SymInts. + # Suppose you have (i0, 12) resizing into (2, -1, 12). The old + # range for i0 is typically [2, inf], which means if you divide + # by two the new range should be [1, inf]. But this is bad news + # if you have an unbacked SymInt: we need to reapply the unsound + # assumption that the size is >= 2. 
+ torch._check_is_size(shape[dim]) + return tuple(shape) + + +_integer_dtypes = ( + torch.uint8, + torch.uint16, + torch.uint32, + torch.uint64, + torch.int8, + torch.int16, + torch.int32, + torch.int64, +) +_low_precision_dtypes = (torch.float16, torch.bfloat16, torch.complex32) +_complex_dtypes = (torch.complex32, torch.complex64, torch.complex128) + + +def is_boolean_dtype(dtype: torch.dtype) -> bool: + assert isinstance(dtype, torch.dtype) + return dtype is torch.bool + + +def is_integer_dtype(dtype: torch.dtype) -> bool: + assert isinstance(dtype, torch.dtype) + return dtype in _integer_dtypes + + +def is_low_precision_dtype(dtype: torch.dtype) -> bool: + assert isinstance(dtype, torch.dtype) + return dtype in _low_precision_dtypes + + +def is_float_dtype(dtype: torch.dtype) -> bool: + assert isinstance(dtype, torch.dtype) + return dtype.is_floating_point + + +def is_complex_dtype(dtype: torch.dtype) -> bool: + assert isinstance(dtype, torch.dtype) + return dtype in _complex_dtypes + + +def is_grad_dtype(dtype: torch.dtype) -> bool: + """ + Checks if the dtype can require a gradient. + """ + return dtype.is_floating_point or is_complex_dtype(dtype) + + +_complex_to_real_dtype_map = { + torch.complex128: torch.float64, + torch.complex64: torch.float32, + torch.complex32: torch.float16, +} + +_real_to_complex_dtype_map = { + torch.float16: torch.complex32, + torch.bfloat16: torch.complex64, + torch.float32: torch.complex64, + torch.float64: torch.complex128, +} + + +def corresponding_real_dtype(dtype: torch.dtype) -> torch.dtype: + return _complex_to_real_dtype_map[dtype] + + +def corresponding_complex_dtype(dtype: torch.dtype) -> torch.dtype: + return _real_to_complex_dtype_map[dtype] + + +def dtype_to_type(dtype: torch.dtype) -> type: + """ + Computes the corresponding Python type (AKA "type kind") for the + given dtype. 
+ """ + assert isinstance(dtype, torch.dtype) + + if dtype is torch.bool: + return bool + if dtype in _integer_dtypes: + return int + if dtype.is_floating_point: + return float + if dtype in _complex_dtypes: + return complex + + raise ValueError("Invalid dtype!") + + +def dtype_to_type_ctor(dtype: torch.dtype) -> Callable[[NumberType], NumberType]: + """ + Computes the corresponding Python type constructor for the + given dtype. + """ + assert isinstance(dtype, torch.dtype) + + if dtype is torch.bool: + return lambda x: bool(x) + if dtype in _integer_dtypes: + return sym_int + if dtype.is_floating_point: + return sym_float + if dtype in _complex_dtypes: + # TODO: type error here is real, replace with sym_complex + return lambda x: complex(x) # type: ignore[arg-type] + + raise ValueError("Invalid dtype!") + + +def type_to_dtype(typ: type) -> torch.dtype: + """ + Computes the corresponding dtype for a Number type. + """ + + assert isinstance(typ, type) + + if typ in (bool, torch.SymBool): + return torch.bool + if typ in (int, torch.SymInt): + return torch.long + if typ in (float, torch.SymFloat): + return torch.get_default_dtype() + # TODO: sym_complex_float? + if typ is complex: + return corresponding_complex_dtype(torch.get_default_dtype()) + + raise ValueError(f"Invalid type {typ}!") + + +def get_dtype(x: Union[torch.Tensor, NumberType]): + if isinstance(x, torch.Tensor): + return x.dtype + else: + return type_to_dtype(type(x)) + + +_ordered_types = (bool, int, float, complex) + + +def check_fp_or_complex( + dtype: torch.dtype, fn_name: str, allow_low_precision_dtypes: bool = True +): + """ + Checks whether the input is floating point or complex. + If allow_low_precision_dtypes is True, it allows having float16, bfloat16, and complex32 + """ + torch._check( + is_float_dtype(dtype) or is_complex_dtype(dtype), + lambda: f"{fn_name}: Expected a floating point or complex tensor as input. 
Got {dtype}", + ) + torch._check( + allow_low_precision_dtypes or not is_low_precision_dtype(dtype), + lambda: f"{fn_name}: Half precision dtypes not supported. Got {dtype}", + ) + + +def check_is_matrix(A: TensorLikeType, f_name: str, arg_name: str = "A"): + torch._check( + len(A.shape) >= 2, + lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.", + ) + + +def get_higher_type(a: type, b: type) -> type: + """ + Returns the higher of the two given Number types. + + The types are ordered bool -> int -> float -> complex. + """ + a, b = _maybe_get_pytype(a), _maybe_get_pytype(b) + # Type checking + if a not in _ordered_types or b not in _ordered_types: + raise RuntimeError(f"Expected builtin numeric types, found {a}, {b}") + + if a is b: + return a + + for typ in _ordered_types: + if a is typ: + return b + if b is typ: + return a + + raise ValueError("Unknown Python scalar type!") + + +# Returns the higher of two torch datatypes a and b or, if the two +# are not ordered relative to each other, the next +# higher datatype +def get_higher_dtype( + a: Optional[Union[torch.dtype, TensorLikeType, NumberType]], + b: Optional[Union[torch.dtype, TensorLikeType, NumberType]], +) -> Optional[torch.dtype]: + """ + Computes the "lowest" datatype that is weakly + "higher" than both a and b. 
+ """ + + # Type checking + assert a is None or isinstance(a, (torch.dtype, TensorLike, Number)) + assert b is None or isinstance(b, (torch.dtype, TensorLike, Number)) + + def _extract_dtype( + x: Optional[Union[torch.dtype, TensorLikeType, NumberType]] + ) -> Optional[torch.dtype]: + if x is None: + return None + if isinstance(x, torch.dtype): + return x + if isinstance(x, TensorLike): + return x.dtype + if isinstance(x, Number): + return type_to_dtype(type(x)) + + raise RuntimeError("Unexpected type given to _extract_dtype!") + + a, b = _extract_dtype(a), _extract_dtype(b) + + if a is b: + return a + + if a is None: + return b + + if b is None: + return a + + ordered_datatypes = ( + (torch.bool,), + (torch.uint8, torch.int8), + (torch.int16,), + (torch.int32,), + (torch.int64,), + (torch.float16, torch.bfloat16), + (torch.float32,), + (torch.float64,), + (torch.complex32,), + (torch.complex64,), + (torch.complex128,), + ) + + for idx, dtypes in enumerate(ordered_datatypes): + if a in dtypes and b in dtypes: + return ordered_datatypes[idx + 1][0] + if a in dtypes: + return b + if b in dtypes: + return a + + raise RuntimeError("Unexpected termination!") + + +def check_pin_memory(pin_memory: bool): + torch._check_not_implemented( + not pin_memory, lambda: "PrimTorch does not support pinned memory" + ) + + +def check_layout(layout: torch.layout): + torch._check_not_implemented( + layout == torch.strided, lambda: f"PrimTorch doesn't support layout={layout}" + ) + + +# TODO: maybe unify with can_cast_to? +def is_weakly_lesser_type(a: type, b: type) -> bool: + """ + Compares two types, a and b, returning True if a is weakly "less" than b. + + The comparison is determined by the following type ordering: bool, int, float, complex. 
+ """ + + a, b = _maybe_get_pytype(a), _maybe_get_pytype(b) + + if a not in _ordered_types or b not in _ordered_types: + raise RuntimeError(f"Expected builtin numeric types, found {a}, {b}") + + for typ in _ordered_types: + if a == typ: + return True + if b == typ: + return False + + raise RuntimeError("Unexpected termination!") + + +def can_safe_cast_to(*, cast_to: torch.dtype, cast_from: torch.dtype) -> bool: + for fn in (is_complex_dtype, is_float_dtype, is_integer_dtype, is_boolean_dtype): + if fn(cast_to): + return True + if fn(cast_from): + return False + + raise ValueError(f"Received unknown dtypes {cast_to}, {cast_from}!") + + +def check_same_dtype(*args): + """ + Checks that all Tensors in args have the same device and that all Numbers have the + same corresponding Python type. + + Raises a RuntimeError when: + - args contains an object whose type is not Tensor or Number + - two Tensors objects in args have different dtypes + - two Number objects in args have different types + - there are Tensors and Numbers in args, and one of those Tensors corresponding + Python types is different from the type of one of those Numbers + """ + full_dtype = None + scalar_type = None + + for arg in args: + if isinstance(arg, Number): + # Scalar type checking is disabled (and may be removed in the future) + continue + # if scalar_type is None: + # scalar_type = type(arg) + + # if scalar_type is not type(arg): + # msg = ( + # "Scalar of type " + # + str(type(arg)) + # + " is not the expected type of " + # + str(scalar_type) + # + "!" + # ) + # raise RuntimeError(msg) + elif isinstance(arg, TensorLike): + if full_dtype is None: + full_dtype = arg.dtype + if scalar_type is None: + scalar_type = dtype_to_type(arg.dtype) + + if full_dtype is not arg.dtype: + msg = ( + "Tensor with dtype " + + str(arg.dtype) + + " is not the expected dtype of " + + str(full_dtype) + + "!" 
+ ) + raise RuntimeError(msg) + + arg_type = dtype_to_type(arg.dtype) + if arg_type is not scalar_type: + msg = ( + "Tensor with corresponding Python type " + + str(arg_type) + + " is not the expected type of " + + str(scalar_type) + + "!" + ) + raise RuntimeError(msg) + else: + msg = ( + "Unexpected type when checking for same dtype, " + str(type(arg)) + "!" + ) + raise RuntimeError(msg) + + +# Maps datatypes to their computation types for elementwise operations +_computation_dtype_map = { + torch.bfloat16: torch.float32, + torch.float16: torch.float32, + torch.complex32: torch.complex64, +} + + +def get_computation_dtype(dtype: torch.dtype) -> torch.dtype: + return _computation_dtype_map.get(dtype, dtype) + + +_cpu_acc_type_map = { + torch.bfloat16: torch.float64, + torch.float16: torch.float64, + torch.float32: torch.float64, + torch.complex32: torch.complex128, + torch.complex64: torch.complex128, +} + + +def get_acc_type(dtype: torch.dtype, device: torch.device) -> torch.dtype: + # Equivalent to at::toAccumulateType, prefer computation_dtype where possible + if device.type == "cpu": + return _cpu_acc_type_map.get(dtype, dtype) + else: + return get_computation_dtype(dtype) + + +class ELEMENTWISE_TYPE_PROMOTION_KIND(Enum): + DEFAULT = (0,) + NO_OPMATH = (1,) + INT_TO_FLOAT = (2,) + ALWAYS_BOOL = (3,) + COMPLEX_TO_FLOAT = (4,) + BOOL_TO_LONG = (5,) + + +class REDUCTION_OUTPUT_TYPE_KIND(Enum): + SAME = (0,) + COMPLEX_TO_FLOAT = (1,) # for complex types outputs corresponding real type + KEEP_PROMOTED_TYPE = (2,) # keep output in opmath type, needed for mean + ALWAYS_BOOL = (3,) + + +# Describes the return type of the primitive: +# +# - NEW, a new tensor is created +# - VIEW, a view of an input tensor is returned +# - INPLACE, one or more input tensors is modified +# +# these descriptors are mututally exclusive and exhaustive. 
class RETURN_TYPE(Enum):
    """How a primitive produces its result: a newly allocated tensor (NEW),
    a view of an input (VIEW), a write into an input (INPLACE), or nothing (NONE)."""

    NEW = (0,)
    VIEW = (1,)
    INPLACE = (2,)
    NONE = (3,)


# TODO: when NumberType contains the sym types, can simplify this
def number_type(
    x: Union[NumberType, torch.SymInt, torch.SymFloat, torch.SymBool]
) -> Type:
    """Returns the Python number type of x, mapping symbolic numbers
    (SymInt/SymFloat/SymBool) to their plain counterparts."""
    if isinstance(x, torch.SymInt):
        return int
    elif isinstance(x, torch.SymFloat):
        return float
    elif isinstance(x, torch.SymBool):
        return bool
    else:
        return type(x)


def expr_type(x: sympy.Basic) -> Type:
    """Returns the Python number type (bool, int, or float) of a sympy expression."""
    import sympy

    if x.kind is sympy.core.kind.BooleanKind:
        return bool
    elif x.is_integer:  # type: ignore[attr-defined]
        return int
    else:
        # NB: Not strictly correct, but we don't support SymPy complex or bool.
        return float


# TODO: document type promotion kinds
def elementwise_dtypes(
    *_args,
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
) -> Tuple[torch.dtype, torch.dtype]:
    """
    Computes the computation and result dtypes for elementwise type promotion
    on the given arguments and with the given elementwise type promotion kind.

    Note that not all inputs to an elementwise operation necessarily participate in type promotion.
    For example, the "alpha" parameter of torch.add does not participate in type promotion,
    although it may be cast to the Python type corresponding to the computation dtype that
    the type promotion algorithm determines.

    Default elementwise type promotion, which all other type promotion kinds tweak (see below),
    first decides which of four ordered types to use:

    bool -> integer -> floating point -> complex

    The selected type is the "lowest" type in the above list such that all number arguments
    have a weakly "lower" type and all tensor arguments have a weakly lower corresponding
    type for their dtype.

    Once the type is determined, the particular result dtype is found. The dtypes are
    partially ordered as follows:

    bool -> uint8, int8 -> int16 -> int32 -> int64 ->
      float16, bfloat16 -> float32 -> float64 -> complex32 -> complex64 -> complex128

    The result dtype is selected by:
      - if no tensor's dtype has the same corresponding type as the one selected,
        then the result dtype is the (default) dtype corresponding to the selected type
        (for example, 1.5 + an integer tensor has a result dtype of the default floating point dtype)
      - if the result type is complex then the dtype is:
        -  the default complex dtype if there are no floating point or complex tensors
        -  if there are floating point or complex tensors with one or more dimensions, then
            the complex dtype corresponding to the highest corresponding complex dtype among those tensors
            (for example, double + cfloat -> cdouble)
        -  if there are only floating point or complex tensors with zero dimensions, then
            the complex dtype corresponding to the highest corresponding complex dtype among those tensors
      - if the first two cases do not apply, the result dtype is the highest dtype among
        all tensors with one or more dimensions of the output type, and if there are no such
        tensors then it's the highest dtype among all tensors with zero dimensions of the output type
        (for example, long + half -> half, even if the half tensor has zero dimensions)

    The "corresponding complex dtypes" are:
      float16    -> complex32
      bfloat16   -> complex64
      float32    -> complex64
      float64    -> complex128
      complex32  -> complex32
      complex64  -> complex64
      complex128 -> complex128

    The DEFAULT type promotion kind computes per above, and then uses the result dtype to pick a computation
    dtype by mapping low precision floating point and complex dtypes as follows:

      float16   -> float32
      bfloat16  -> float32
      complex32 -> complex64

    This is referred to as "op math", and the NO_OPMATH type promotion kind disables this mapping, making the
    computation dtype the same as the result dtype when it's selected. NO_OPMATH is appropriate for kernels
    which perform no mathematical operations on their tensors (see below for examples).

    The INT_TO_FLOAT type promotion kind maps boolean and integer result dtypes to the default floating point dtype,
    and computation dtypes to the appropriate op math dtype.

    The COMPLEX_TO_FLOAT type promotion kind maps complex result dtypes to the corresponding float dtype, following this
    mapping:

      complex32  -> float16
      complex64  -> float32
      complex128 -> float64

    Note that COMPLEX_TO_FLOAT derives the computation dtype as the DEFAULT setting does.

    The BOOL_TO_LONG type promotion kind maps boolean computation and result dtypes to long.

    The ALWAYS_BOOL type promotion kind always sets the result dtype to bool.

    Example operators for each type promotion option:
      DEFAULT                 : add
      NO_OPMATH               : where, nextafter, cat
      INT_TO_FLOAT            : sin
      COMPLEX_TO_FLOAT        : abs
      BOOL_TO_LONG            : pow
      ALWAYS_BOOL             : eq

    """

    # None arguments do not participate in promotion.
    args = tuple(x for x in _args if x is not None)

    highest_type: type = bool

    # Import sympy locally, as importing it eagerly at a module level is too slow
    # See https://dev-discuss.pytorch.org/t/delving-into-what-happens-when-you-import-torch/1589
    import sympy

    # Pass 1: determine the highest participating Python type kind.
    for x in args:
        if not isinstance(x, (Number, TensorLike, sympy.Basic)):
            msg = f"Unexpected type {str(type(x))} when computing elementwise type promotion!"
            raise ValueError(msg)

        if isinstance(x, Number):
            highest_type = get_higher_type(highest_type, number_type(x))
        elif isinstance(x, sympy.Basic):
            highest_type = get_higher_type(highest_type, expr_type(x))
        else:
            # x is a TensorLike
            highest_type = get_higher_type(highest_type, dtype_to_type(x.dtype))

    result_dtype = None

    def _find_highest_dtype_filtered(
        args, filter, *, float_as_complex=False
    ) -> Optional[torch.dtype]:
        # Returns the highest dtype among tensors whose dtype passes `filter`,
        # preferring tensors with one or more dimensions over zero-dim tensors.
        zero_dim_tensor_dtype = None
        one_plus_dim_tensor_dtype = None
        for x in args:
            if isinstance(x, TensorLike) and filter(x.dtype):
                _dtype = x.dtype
                if float_as_complex and is_float_dtype(_dtype):
                    _dtype = corresponding_complex_dtype(_dtype)
                if x.ndim == 0:
                    zero_dim_tensor_dtype = get_higher_dtype(
                        zero_dim_tensor_dtype, _dtype
                    )
                else:
                    # x.ndim > 0
                    one_plus_dim_tensor_dtype = get_higher_dtype(
                        one_plus_dim_tensor_dtype, _dtype
                    )

        # Prefers dtype of tensors with one or more dimensions
        if one_plus_dim_tensor_dtype is not None:
            return one_plus_dim_tensor_dtype

        return zero_dim_tensor_dtype

    # Pass 2: pick the result dtype for the selected type kind.
    if highest_type is float:
        result_dtype = _find_highest_dtype_filtered(args, is_float_dtype)
        result_dtype = (
            torch.get_default_dtype() if result_dtype is None else result_dtype
        )
    elif highest_type is complex:
        result_dtype = _find_highest_dtype_filtered(
            args,
            lambda x: is_float_dtype(x) or is_complex_dtype(x),
            float_as_complex=True,
        )
        if result_dtype is None:
            result_dtype = corresponding_complex_dtype(torch.get_default_dtype())
    elif highest_type is int:
        result_dtype = _find_highest_dtype_filtered(args, is_integer_dtype)
        result_dtype = torch.long if result_dtype is None else result_dtype
    else:
        # highest_type is bool
        result_dtype = torch.bool

    # Pass 3: apply the promotion-kind tweak and derive the computation dtype.
    if type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT:
        return get_computation_dtype(result_dtype), result_dtype
    elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH:
        return result_dtype, result_dtype
    elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT:
        if is_integer_dtype(result_dtype) or is_boolean_dtype(result_dtype):
            result_dtype = torch.get_default_dtype()
        return get_computation_dtype(result_dtype), result_dtype
    elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT:
        # NOTE: computation can still occur in a complex dtype
        computation_dtype = get_computation_dtype(result_dtype)
        if is_complex_dtype(result_dtype):
            result_dtype = corresponding_real_dtype(result_dtype)
        return computation_dtype, result_dtype
    elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG:
        if is_boolean_dtype(result_dtype):
            return torch.long, torch.long
        return get_computation_dtype(result_dtype), result_dtype
    elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL:
        return get_computation_dtype(result_dtype), torch.bool
    else:
        raise ValueError(f"Unknown type promotion kind {str(type_promotion_kind)}")
We are doing it explicitly here + inp_dtype = dtype if dtype is not None else arg.dtype + computation_dtype = get_computation_dtype(inp_dtype) + if ( + output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.SAME + or output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT + ): + result_dtype = dtype if dtype else arg.dtype + if ( + output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT + and is_complex_dtype(result_dtype) + ): + result_dtype = corresponding_real_dtype(result_dtype) + elif output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE: + result_dtype = None + else: # ALWAYS_BOOL + result_dtype = torch.bool + return computation_dtype, result_dtype + + +# This function's logic is borrowed from the following functions defined in C++: +# batched_matrix_contiguous_strides and contiguous_strides +def make_contiguous_strides_for( + shape: ShapeType, row_major: bool = True +) -> Tuple[int, ...]: + """ + Returns the strides of a contiguous tensor if row_major + If row_major=True, it returns the strides of a contiguous batch of Fortran-contiguous matrices + This is often used when calling external libraries like BLAS/LAPACK/cuSolver... 
+ """ + # contiguous_strides from c10/util/strides.h + validate_shape(shape) + if not shape: + return () + + from torch.fx.experimental.symbolic_shapes import is_nested_int + + multiplier = 1 + strides = [] + for l in reversed(shape): + strides.append(multiplier) + multiplier *= l if is_nested_int(l) else sym_max(l, 1) + + result = tuple(reversed(strides)) + + # batched_matrix_contiguous_strides from aten/src/ATen/native/LinearAlgebraUtils.h + if row_major: + return result + else: + if len(shape) < 2: + return result + return result[:-2] + (1, max(shape[-2], 1)) + + +def make_channels_last_1d_strides_for(shape: ShapeType) -> Tuple[int, ...]: + torch._check( + len(shape) == 3, + lambda: "Only tensors of rank 3 can use the channels_last_1d memory format", + ) + + multiplier = 1 + strides = [0] * 3 + for idx in (1, -1, 0): + # NOTE: intentionally divergence from make_contiguous_strides_for + # This is consistent with eager + strides[idx] = multiplier + multiplier *= shape[idx] + + return tuple(strides) + + +def make_channels_last_2d_strides_for(shape: ShapeType) -> Tuple[int, ...]: + # TODO: maybe inform the user of channels_last_3d if rank of the tensor is 5? 
+ torch._check( + len(shape) == 4, + lambda: "Only tensors of rank 4 can use the channels_last memory format", + ) + + multiplier = 1 + strides = [0] * 4 + for idx in (1, -1, -2, 0): + # NOTE: intentionally divergence from make_contiguous_strides_for + # This is consistent with eager + strides[idx] = multiplier + multiplier *= shape[idx] + + return tuple(strides) + + +def make_channels_last_3d_strides_for(shape: ShapeType) -> Tuple[int, ...]: + torch._check( + len(shape) == 5, + lambda: "Only tensors of rank 5 can use the channels_last_3d memory format", + ) + + multiplier = 1 + strides = [0] * 5 + for idx in (1, -1, -2, -3, 0): + # NOTE: intentionally divergence from make_contiguous_strides_for + # This is consistent with eager + strides[idx] = multiplier + multiplier *= shape[idx] + + return tuple(strides) + + +def make_channels_last_strides_for(shape: ShapeType) -> Tuple[int, ...]: + ndim = len(shape) if isinstance(shape, Sequence) else 1 + if ndim == 3: + return make_channels_last_1d_strides_for(shape) + elif ndim == 4: + return make_channels_last_2d_strides_for(shape) + elif ndim == 5: + return make_channels_last_3d_strides_for(shape) + else: + raise RuntimeError( + f"no channels last format strides exist in {ndim} dimensions" + ) + + +def compute_reduction_output_shape( + shape: ShapeType, dimensions: Sequence +) -> Tuple[int, ...]: + for idx in dimensions: + validate_idx(len(shape), idx) + + new_shape = [] + for idx in range(len(shape)): + if idx in dimensions: + continue + + new_shape.append(shape[idx]) + + return tuple(new_shape) + + +def validate_no_repeating_dims(dims: Sequence): + if len(dims) != len(set(dims)): + raise RuntimeError("duplicate value in the list of dims") + + +def reduction_dims(shape: ShapeType, dims: Optional[Sequence]) -> Tuple[int, ...]: + if dims is None: + return tuple(range(len(shape))) + dims = tuple(canonicalize_dim(len(shape), idx) for idx in dims) + validate_no_repeating_dims(dims) + return dims + + +def set_correction( + 
unbiased: Optional[bool] = None, + correction: Optional[NumberType] = None, +) -> float: + if correction is not None and unbiased is not None: + raise RuntimeError("cannot specify both correction and unbiased arguments") + elif correction is None and unbiased is None: + correction = 1.0 + elif correction is None and unbiased is not None: + correction = 0.0 if unbiased is False else 1.0 + # NB: we don't actually support symint here, but it's harmless to accept + if not isinstance(correction, (IntLike, FloatLike)): + raise ValueError("correction argument should be integer or float") + if correction < 0: + raise ValueError("correction argument should be non-negative") + return sym_float(correction) + + +def compute_required_storage_length( + shape: ShapeType, strides: StrideType, storage_offset: int +) -> int: + """Computes the minimum storage size to hold the given tensor geometry. + + Example + ======= + + This is the size of a newly allocated tensor's storage, in units of elements + + >>> t = torch.empty((10, 20)) + >>> compute_required_storage_length(t.shape, t.stride(), t.storage_offset()) + 200 + + >>> # xdoctest: +SKIP(failing) + >>> t2 = torch.empty_strided((1, 2, 3), (5, 7, 11)) + >>> size = compute_required_storage_length(t2.shape, t2.stride(), t2.storage_offset()) + >>> size == t.storage().size() + True + + A valid tensor may have a larger storage size, but never smaller + + >>> slice = torch.empty(100)[20:40] + >>> slice.storage().size() + 100 + + >>> compute_required_storage_length(slice.shape, slice.stride(), slice.storage_offset()) + 40 + + """ + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + # Short-circuits if the shape has no elements + if guard_size_oblivious(reduce(operator.mul, shape, 1) == 0): + return 0 + + max_offset = sum((x - 1) * y for x, y in zip(shape, strides)) + # +1 to account for the first element which offsets are taken from + return 1 + storage_offset + max_offset + + +def check_in_bounds_for_storage( + a: 
torch.TypedStorage, shape: ShapeType, strides: StrideType, storage_offset: int +): + """ + Determines if the given shape, strides, and offset are valid for the given storage. + """ + + required_length = compute_required_storage_length(shape, strides, storage_offset) + if a.size() < required_length: + msg = ( + f"Can't view a storage of size {a.size()} with an offset of {storage_offset}, " + f"shape of {str(shape)}, and strides of {str(strides)}, " + f"which requires a storage of size {required_length}" + ) + raise ValueError(msg) + + +# NOTE: This function should ideally be removed, but some Meta internal models +# packaged with `torch.package` are using it, so it will have to be removed +# at some point in the future when those models no longer use this function. +@deprecated( + "`torch._prims_common.check` is deprecated and will be removed in the future. " + "Please use `torch._check*` functions instead.", + category=FutureWarning, +) +def check( + b: bool, s: Callable[[], str], exc_type: Type[Exception] = RuntimeError +) -> None: + """ + Helper function for raising an error_type (default: RuntimeError) if a boolean condition fails. + Error message is a callable producing a string (to avoid wasting time + string formatting in non-error case, and also to make it easier for torchdynamo + to trace.) + + .. note:: This function is planned for removal in the future. Please use + `torch._check*` functions instead. 
+ """ + torch._check_with(exc_type, b, s) + + +# This combines is_channels_last_strides_2d and is_channels_last_strides_3d in +# c10/core/MemoryFormat.h into one function +def are_strides_like_channels_last( + shape: Sequence[int], strides: Sequence[int] +) -> bool: + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + ndim = len(shape) + + if ndim == 4: + # Check for channels_last_2d + dim_order = [1, 3, 2, 0] + elif ndim == 5: + # Check for channels_last_3d + dim_order = [1, 4, 3, 2, 0] + else: + return False + + if guard_size_oblivious(strides[1] == 0): + return False + + min = 0 + for d in dim_order: + if guard_size_oblivious(shape[d] == 0): + return False + if guard_size_oblivious(strides[d] < min): + return False + if d == 0 and min == strides[1]: + return False + min = strides[d] + if guard_size_oblivious(strides[d] > 1): + min *= shape[d] + return True + + +def suggest_memory_format(x: TensorLikeType) -> torch.memory_format: + if x.layout != torch.strided: + return torch.contiguous_format + + if are_strides_like_channels_last(x.shape, x.stride()): + return torch.channels_last if x.ndim == 4 else torch.channels_last_3d + + return torch.contiguous_format + + +def prod(xs: Sequence[NumberType]) -> NumberType: + """Product of elements in input sequence. Returns 1 for empty sequence""" + return reduce(operator.mul, xs, 1) + + +def is_expandable_to(shape: ShapeType, desired: ShapeType) -> bool: + """Checks if a shape can be expanded to another shape. + This is equivalent to checking if the two shapes are broadcastable. 
def is_expandable_to(shape: ShapeType, desired: ShapeType) -> bool:
    """Checks if a shape can be expanded to another shape.
    This is equivalent to checking if the two shapes are broadcastable.
    """
    # This is a Python implementation of
    # aten/src/ATen/ExpandUtils.h:is_expandable_to
    if len(shape) > len(desired):
        return False
    # Compare trailing dims; a source dim must either match or be 1.
    return all(s == d or s == 1 for s, d in zip(reversed(shape), reversed(desired)))


def mask_tensor(mask: TensorLikeType, t: TensorLikeType):
    """
    Similar to torch.where(mask, t, 0) but if t is boolean,
    result is also boolean and not promoted to int.
    """
    # torch.where(mask, t, False) is equivalent
    # but feels hacky and might break in the future
    return mask.logical_and(t) if t.dtype is torch.bool else torch.where(mask, t, 0)


def get_aten_op(fn: Callable, name: str):
    """
    Given the __module__ of reference and its name, it returns
    (our best guess of) the ATen name of the associated operation

    Note: In ATen, the __name__ of a function within a module often
    starts by the module name. E.g. linalg_eigh, or special_zeta
    """
    module = fn.__module__
    prefix = "torch._refs"
    assert module.startswith(prefix)
    # We want to go from .special / .nn.functional
    # to special_ and nn_functional_ prefixes on the ATen name
    suffix = module[len(prefix) :]
    if suffix:
        suffix = suffix[1:].replace(".", "_") + "_"
    return getattr(torch._ops.ops.aten, f"{suffix}{name}")


def dtype_or_default(dtype: Optional[torch.dtype]) -> torch.dtype:
    """Returns ``dtype``, or the current default dtype when it is None."""
    return torch.get_default_dtype() if dtype is None else dtype


def device_or_default(device: Optional[DeviceLikeType]) -> DeviceLikeType:
    """Returns ``device``, or the CPU device when it is None."""
    return torch.device("cpu") if device is None else device


def layout_or_default(layout: Optional[torch.layout]) -> torch.layout:
    """Returns ``layout``, or torch.strided when it is None."""
    return torch.strided if layout is None else layout
def clone_preserve_strides(x):
    """
    Clones ``x`` while preserving its exact size/stride/storage_offset geometry
    (a plain ``clone()`` would produce a contiguous tensor).
    """
    needed_size = compute_required_storage_length(
        x.size(), x.stride(), x.storage_offset()
    )
    # Our eager implementations for *_scatter ops are all primitives w.r.t autograd,
    # so these as_strided() calls are not seen by autograd.
    # We need to mimic this behavior in our ref/prim implementations.
    # TODO: a better way to handle this would be with a new op, "_unsafe_as_strided"
    # We should revisit this when we add a compositional as_strided op,
    # and also as part of https://github.com/pytorch/pytorch/issues/90507
    try:
        old = torch._C._dispatch_tls_is_dispatch_key_excluded(
            torch._C.DispatchKey.ADInplaceOrView
        )
        torch._C._dispatch_tls_set_dispatch_key_excluded(
            torch._C.DispatchKey.ADInplaceOrView, True
        )
        # Clone the minimal flat buffer covering x's geometry, then re-view it
        # with x's original size/stride/offset.
        buffer = torch.as_strided(x, (needed_size,), (1,), 0).clone()
        return torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset())
    finally:
        # Restore the previous exclusion state even if the clone above throws.
        torch._C._dispatch_tls_set_dispatch_key_excluded(
            torch._C.DispatchKey.ADInplaceOrView, old
        )


def alert_not_deterministic(caller: str):
    """Warns or errors (per the global deterministic-algorithms setting) that
    ``caller`` has no deterministic implementation."""
    if torch.are_deterministic_algorithms_enabled():
        if torch.is_deterministic_algorithms_warn_only_enabled():
            warnings.warn(
                f"{caller} does not have a deterministic implementation, but you set "
                f"'torch.use_deterministic_algorithms(True, warn_only=True)'. "
                f"You can file an issue at https://github.com/pytorch/pytorch/issues "
                f"to help us prioritize adding deterministic support for this operation."
            )
        else:
            torch._check(
                False,
                lambda: (
                    f"{caller} does not have a deterministic implementation, but you set "
                    f"'torch.use_deterministic_algorithms(True)'. You can turn off "
                    f"determinism just for this operation, or you can use the "
                    f"'warn_only=True' option, if that's acceptable for your application. "
                    f"You can also file an issue at https://github.com/pytorch/pytorch/issues "
                    f"to help us prioritize adding deterministic support for this operation."
                ),
            )


class CUDARngStateHelper:
    """Static helpers for reading/writing the CUDA RNG state (seed + philox offset)."""

    @staticmethod
    def get_torch_state_as_tuple(fake_mode=nullcontext()):
        # Returns (seed, offset) as tensors, optionally created under the given
        # (e.g. fake-tensor) mode context. Raises if CUDA is unavailable.
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA not available")

        with fake_mode:
            seed = torch.tensor(torch.cuda.initial_seed())
            offset = torch.tensor(torch.cuda._get_rng_state_offset())
            return seed, offset

    @staticmethod
    def set_torch_state_tensor(seed, offset):
        # Rng state is [64-bit seed, 64-bit offset]
        seed_portion = seed.reshape([1]).view(torch.uint8)
        offset_portion = offset.reshape([1]).view(torch.uint8)
        new_state = torch.cat([seed_portion, offset_portion])
        torch.cuda.set_rng_state(new_state)

    @staticmethod
    def set_new_offset(relative_offset):
        # Advances the philox offset; relative_offset is a scalar tensor.
        torch.cuda._set_rng_state_offset(relative_offset.item())
_T = TypeVar("_T")
_P = ParamSpec("_P")


@overload
def _maybe_convert_to_dtype(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
    pass


@overload
def _maybe_convert_to_dtype(a: NumberType, dtype: torch.dtype) -> NumberType:
    pass


@overload
def _maybe_convert_to_dtype(a: Sequence, dtype: torch.dtype) -> Sequence:
    pass


@overload
def _maybe_convert_to_dtype(a: None, dtype: torch.dtype) -> None:
    pass


# TODO: implement ref.cast with an option to enforce safe casting
def _maybe_convert_to_dtype(a, dtype):
    """
    Casts ``a`` (a tensor, Python number, sequence thereof, or None) to ``dtype``;
    tensors already of ``dtype`` are returned unchanged.
    """
    # Passthrough None because some functions wrapped with type promotion
    # wrapper might have optional args
    if a is None:
        return None
    if isinstance(a, TensorLike):
        return a.to(dtype) if a.dtype != dtype else a
    if isinstance(a, Number):
        return utils.dtype_to_type_ctor(dtype)(a)  # type: ignore[arg-type]
    if isinstance(a, Sequence):
        return tuple(_maybe_convert_to_dtype(x, dtype) for x in a)
    raise ValueError(f"Received type {type(a)} that is neither a tensor or a number!")


def _maybe_convert_to_type(a: NumberType, typ: type) -> NumberType:
    """
    Converts the Python number ``a`` to ``typ``, refusing conversions that would
    move down the bool -> int -> float -> complex ordering.
    """
    if not isinstance(a, Number):
        raise ValueError(f"Found unknown type {type(a)} when trying to convert scalars!")
    if not utils.is_weakly_lesser_type(type(a), typ):
        raise ValueError(
            f"Scalar {a} of type {type(a)} cannot be safely cast to type {typ}!"
        )
    return typ(a)
+ raise ValueError(msg) + + return typ(a) + + +def _annotation_has_type(*, typ, annotation): + if hasattr(annotation, "__args__"): + for a in annotation.__args__: + if _annotation_has_type(typ=typ, annotation=a): + return True + return False + + return typ is annotation + + +class elementwise_type_promotion_wrapper: + """ + Adds elementwise type promotion to a Python reference implementation. + + Takes two kwargs, type_promoting_args and type_promotion_kind. + + type_promoting_args must be a string Sequence specifiying the argument names of all + arguments that participate in type promotion (and should be type promoted). If the + arg specifies a Sequence-type then every element of the Sequence will participate in + type promotion. + + type_promotion_kind must be one of the kinds specified by ELEMENTWISE_TYPE_PROMOTION_KIND. + See its documentation for details. + + The return_dtype will be coerced to the wrapped function's dtype arg if it is available and + not None. + + Other type promotion behavior, like validating the Python type of scalar arguments, must + be handled separately. 
+ """ + + def __init__( + self, + *, + type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND, + type_promoting_args: Optional[Sequence[str]] = None, + ): + self.type_promoting_arg_names = type_promoting_args + self.type_promotion_kind = type_promotion_kind + + def __call__(self, fn: Callable) -> Callable: + sig = inspect.signature(fn) + + @wraps(fn) + def _fn(*args, **kwargs): + bound = sig.bind(*args, **kwargs) + type_promoting_args = tuple( + bound.arguments[x] + for x in self.type_promoting_arg_names # type: ignore[union-attr] + if x in bound.arguments.keys() + ) + + flattened_type_promoting_args = pytree.arg_tree_leaves(*type_promoting_args) + compute_dtype, result_dtype = utils.elementwise_dtypes( + *flattened_type_promoting_args, + type_promotion_kind=self.type_promotion_kind, + ) + + promoted_args = { + x: _maybe_convert_to_dtype(bound.arguments[x], compute_dtype) + for x in self.type_promoting_arg_names # type: ignore[union-attr] + if x in bound.arguments.keys() + } + bound.arguments.update(promoted_args) + + result = fn(**bound.arguments) + + # Override the return_dtype if a dtype arg is present and not None + if "dtype" in bound.arguments: + maybe_dtype = bound.arguments["dtype"] + if maybe_dtype: # dtype cannot be None + result_dtype = maybe_dtype + + if isinstance(result, TensorLike): + return _maybe_convert_to_dtype(result, result_dtype) + if isinstance(result, Sequence): + return tuple(_maybe_convert_to_dtype(x, result_dtype) for x in result) + raise AssertionError(f"Unhandled result type: {type(result)}") + + _fn.__signature__ = sig # type: ignore[attr-defined] + return _fn + + +# Returns True if resize is necessary +def _resize_output_check(out: TensorLikeType, shape: ShapeType): + # If the shapes are correct there's nothing to do + if utils.same_shape(out.shape, shape): + return False + if out.numel() != 0: + msg = ( + f"An output with one or more elements was resized since it had shape {str(out.shape)} " + "which does not match the required output 
shape {str(shape)}. " + "This behavior is deprecated, and in a future PyTorch release outputs will not " + "be resized unless they have zero elements. " + "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)." + ) + warnings.warn(msg) + return True + + +# TODO: handle tuples of tensors +def _maybe_resize_out( + out: TensorLikeType, + shape: ShapeType, + memory_format: Optional[torch.memory_format] = None, +): + if _resize_output_check(out, shape): + return out.resize_(shape, memory_format=memory_format) + else: + return out + + +def is_cpu_scalar(x: TensorLikeType) -> bool: + return x.dim() == 0 and x.device.type == "cpu" + + +def _safe_copy_out( + *, copy_from: TensorLikeType, copy_to: TensorLikeType, exact_dtype: bool = False +): + # Checks same device + if not is_cpu_scalar(copy_from) and copy_from.device != copy_to.device: + msg = ( + f"Attempting to copy from device {copy_from.device} " + f"to device {copy_to.device}, but cross-device copies are not allowed!" 
+ ) + raise RuntimeError(msg) + + # Checks safe cast + if exact_dtype: + torch._check( + copy_from.dtype == copy_to.dtype, + lambda: f"Expected out tensor to have dtype {copy_from.dtype} " + f"but got {copy_to.dtype} instead", + ) + else: + torch._check( + utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype), + lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, " + "but this can't be cast because it is not safe!", + ) + + return copy_to.copy_(copy_from) + + +def out_wrapper( + *out_names: str, + exact_dtype: bool = False, + pass_is_out: bool = False, + preserve_memory_format: bool = False, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + # The wrapped function needs to convert the output parameters to ensure + # compatibility between the Python API (which always uses "out" as the + # parameter name and may be a tuple) and the Aten API (which may have + # multiple output parameters and use different parameter names such as + # "grad_input", "indices" or "values".) + + default_out_names = ("out",) + if len(out_names) == 0: + # Use default in out name + out_names = default_out_names + + is_tensor = len(out_names) == 1 + + def maybe_compute_memory_format(t): + return utils.suggest_memory_format(t) if preserve_memory_format else None + + def _out_wrapper(fn: Callable[_P, _T]) -> Callable[_P, _T]: + """ + Adds the out parameter to a Python reference. 
+ """ + out_type = ( + TensorLikeType + if is_tensor + else Tuple[tuple(TensorLikeType for _ in range(len(out_names)))] + ) + return_type = ( + TensorLikeType + if is_tensor + else NamedTuple( + f"return_types_{fn.__name__}", [(o, TensorLikeType) for o in out_names] + ) + ) + + sig = inspect.signature(fn) + factory_kwargs = ("device", "dtype") + is_factory_fn = all(p in sig.parameters for p in factory_kwargs) + + @wraps(fn) + def _fn(*args: _P.args, out=None, **kwargs: _P.kwargs): + if is_factory_fn and out is not None: + for k in factory_kwargs: + out_attr = getattr(out, k) + if k not in kwargs: + kwargs[k] = out_attr + if pass_is_out: + result = fn(*args, is_out=(out is not None), **kwargs) # type: ignore[arg-type] + else: + result = fn(*args, **kwargs) + assert ( + isinstance(result, TensorLike) + and is_tensor + or isinstance(result, Tuple) # type: ignore[arg-type] + and len(result) == len(out_names) # type: ignore[arg-type] + ) + if out is not None: + # Naively you might expect this assert to be true, but + # it's not: + # + # assert type(out) == type(result) + # + # The reason is that functions under this wrapper can + # get registered to the Meta dispatch key, and that + # means they can be executed in a context where tensor + # subclasses are disabled (with no_dispatch), which is a + # handy way for an is-a tensor subclass (e.g., + # FakeTensor) to have the normal meta backend create a + # meta tensor, to be wrapped once it gets returned. + # In this situation, you will get a FakeTensor as + # the output tensor, but not the result--which will + # be a normal meta tensor, but this is perfectly + # harmless. 
+ if is_tensor: + assert isinstance(out, TensorLike) + # These two operations are done in-place + _maybe_resize_out( + out, result.shape, maybe_compute_memory_format(result) # type: ignore[union-attr] + ) + _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype) # type: ignore[arg-type] + else: + assert isinstance(out, Tuple) # type: ignore[arg-type] + torch._check_type( + len(out) == len(result), # type: ignore[arg-type] + lambda: f"expected tuple of {len(result)} elements but got {len(out)}", # type: ignore[arg-type] + ) + for r, o in zip(result, out): # type: ignore[arg-type] + # These two operations are done in-place + _maybe_resize_out(o, r.shape, maybe_compute_memory_format(r)) + _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=exact_dtype) # type: ignore[arg-type] + else: + out = result + # mypy does not see through the definition of out_type given that it's in a different scope + return out if is_tensor else return_type(*out) # type: ignore[operator] + + out_param = inspect.Parameter( + "out", + kind=inspect.Parameter.KEYWORD_ONLY, + default=None, + annotation=out_type, + ) + # Mark that the function now returns a tuple + assert isinstance(sig.return_annotation, str) or sig.return_annotation in ( + sig.empty, + out_type, + ) + params = *sig.parameters.values(), out_param + + # If there's a Parameter.VAR_KEYWORD parameter (like **kwds), it must appear + # after the out= parameter, which is Parameter.KEYWORD_ONLY. Sorting by + # Parameter.kind guarantees that all the parameters are in legal order. 
+ params = sorted(params, key=lambda p: p.kind) + + _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined] + parameters=params, return_annotation=return_type # type: ignore[arg-type] + ) + + _fn.__annotations__ = dict(getattr(fn, "__annotations__", {})) + _fn.__annotations__["out"] = out_type + _fn.__annotations__["return"] = return_type + + # In the special case of having a single tensor out parameter with a + # name other than out, add a special annotation to name the parameter + if is_tensor and out_names != default_out_names: + _fn.__annotations__[CustomOutParamAnnotation] = out_names[0] + + # Add an indicator attribute that can be used in special cases + # where having a function wrapped by `out_wrapper` is not desirable e.g. + # jit + _fn._torch_decompositions_out_wrapper = f"This function is wrapped by {out_wrapper.__module__}.out_wrapper" # type: ignore[attr-defined] + + return _fn + + return _out_wrapper + + +def _maybe_remove_out_wrapper(fn: Callable): + return inspect.unwrap( + fn, + stop=lambda f: not hasattr(f, "_torch_decompositions_out_wrapper"), + ) + + +def backwards_not_supported(prim): + def redispatch_prim(args, kwargs): + with torch._C._AutoDispatchBelowAutograd(): + old = torch._C._dispatch_tls_is_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView + ) + return prim(*args, **kwargs) + + class BackwardsNotSupported(torch.autograd.Function): + @staticmethod + def forward(ctx, args_spec, *flat_args): + args, kwargs = tree_unflatten(flat_args, args_spec) # type: ignore[arg-type] + return redispatch_prim(args, kwargs) + + @staticmethod + def backward(ctx, *args): + raise RuntimeError("backwards not supported on prim") + + @wraps(prim) + def _autograd_impl(*args, **kwargs): + flat_args, args_spec = tree_flatten((args, kwargs)) + if torch.is_grad_enabled() and any( + a.requires_grad for a in flat_args if isinstance(a, torch.Tensor) + ): + # TODO: There is a subtle bug here: prims like copy_to + # return their input argument 
after mutating it; and custom + # autograd function will incorrectly turn the result into + # a view which will fail test_python_ref_executor tests. + # At the moment, we sidestep this by observing that the + # unit tests don't ever try to run the executor with + # autograd, so we don't exercise the buggy case, but if + # you ever want to feed autograd through this, be aware + # of it! We need a way of properly implementing autograd + # for mutating operations in Python to do this. + return BackwardsNotSupported.apply(args_spec, *flat_args) + else: + return redispatch_prim(args, kwargs) + + return _autograd_impl + + +# TODO: when tracing this will add torch tensors and not TensorMeta objects +# to the trace -- we should fix this by adding a tracing context and NumberMeta classes +# TODO: this wrapper is currently untested +def elementwise_unary_scalar_wrapper(fn: Callable) -> Callable: + """ + Allows unary operators that accept tensors to work with Python numbers. + """ + sig = inspect.signature(fn) + + @wraps(fn) + def _fn(*args, **kwargs): + if len(args) > 0 and isinstance(args[0], Number): + dtype = utils.type_to_dtype(type(args[0])) + args_ = list(args) + args_[0] = torch.tensor(args[0], dtype=dtype) + result = fn(*args_, **kwargs) + assert isinstance(result, torch.Tensor) + return result.item() + + return fn(*args, **kwargs) + + _fn.__signature__ = sig # type: ignore[attr-defined] + return _fn diff --git a/.venv/lib/python3.11/site-packages/torch/contrib/__init__.py b/.venv/lib/python3.11/site-packages/torch/contrib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/torch/contrib/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/contrib/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b854ff559bcae46785689656452bd651245cf04 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/torch/contrib/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/contrib/__pycache__/_tensorboard_vis.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/contrib/__pycache__/_tensorboard_vis.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f00728bb1c7dfbfff13e520a8140a56feec93f83 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/contrib/__pycache__/_tensorboard_vis.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/contrib/_tensorboard_vis.py b/.venv/lib/python3.11/site-packages/torch/contrib/_tensorboard_vis.py new file mode 100644 index 0000000000000000000000000000000000000000..ed1445dd7bce648bc4ac80a2782d72cf0faba2e0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/contrib/_tensorboard_vis.py @@ -0,0 +1,143 @@ +# mypy: allow-untyped-defs +import time +from collections import defaultdict +from functools import partial +from typing import DefaultDict + +import torch + + +# Unfortunately it doesn't seem as if there was any way to get TensorBoard to do +# anything without having TF installed, and so this file has a hard dependency on it +# as well. It really is a debugging tool, so it doesn't matter. 
+try: + from tensorflow.core.util import event_pb2 + from tensorflow.core.framework import graph_pb2 + from tensorflow.python.summary.writer.writer import FileWriter +except ImportError: + raise ImportError("TensorBoard visualization of GraphExecutors requires having " + "TensorFlow installed") from None + + +def dump_tensorboard_summary(graph_executor, logdir): + with FileWriter(logdir) as w: + pb_graph = visualize(graph_executor) + evt = event_pb2.Event(wall_time=time.time(), graph_def=pb_graph.SerializeToString()) + w.add_event(evt) + + +def visualize(graph, name_prefix='', pb_graph=None, executors_it=None): + """Visualizes an independent graph, or a graph executor.""" + value_map = {} + pb_graph = pb_graph or graph_pb2.GraphDef() + + if isinstance(graph, torch._C.GraphExecutorState): + visualize_graph_executor(graph, name_prefix, pb_graph, + partial(visualize, pb_graph=pb_graph)) + return pb_graph + + # Set up an input node + input_node = pb_graph.node.add(op='input', name=name_prefix + 'input') + for i, value in enumerate(graph.param_node().outputs()): + value_map[value.unique()] = name_prefix + 'input:' + str(i) + + visualize_rec(graph, value_map, name_prefix, pb_graph, executors_it) + + # Gather all outputs + return_node = pb_graph.node.add(op='output', name=name_prefix + 'output') + for value in graph.return_node().inputs(): + return_node.input.append(value_map[value.unique()]) + + return pb_graph + + +def visualize_graph_executor(state, name_prefix, pb_graph, inline_graph): + """Append the state of a given GraphExecutor to the graph protobuf. + + Args: + state (GraphExecutor or GraphExecutorState): GraphExecutor to display. + name_prefix (str): Name prefix of the containing subgraph. + pb_graph (GraphDef): graph to append to. + inline_graph (Callable): a function that handles setting up a value_map, + so that some graphs in here can be inlined. 
This is necessary, because + this will simply be `visualize` for the top-level GraphExecutor, + or `inline_graph` for all nested ones. + + The signature should look like (Graph, name_prefix) -> (). + It will be called exactly once. + + The strategy is to embed all different configurations as independent subgraphs, + while inlining the original graph as the one that actually produces the values. + """ + if state.autograd_fallback_graph is not None: + visualize(graph=state.autograd_fallback_graph, + name_prefix=name_prefix + 'autograd_fallback/', + pb_graph=pb_graph, + executors_it=iter(state.autograd_fallback.executors())) + + for i, (arg_spec, plan) in enumerate(state.execution_plans.items()): + subgraph_name = name_prefix + f'plan{i}/' + + # Create a disconnected node that will keep information regarding the input + # types of this trace. This is unfortunately a bit too verbose to be included + # in the subgraph name. + input_kinds = pb_graph.node.add(op='INPUT_KIND', name=subgraph_name) + input_kinds.attr['inputs'].s = repr(arg_spec).encode('ascii') + + visualize(plan.graph, subgraph_name, pb_graph, iter(plan.code.executors())) + + # Show gradient as an independent subgraph of this plan + if plan.grad_executor is not None: + grad_subgraph_name = subgraph_name + 'grad/' + visualize(plan.grad_executor, grad_subgraph_name, pb_graph) + + return inline_graph(state.graph, name_prefix + 'original/') + + +def visualize_rec(graph, value_map, name_prefix, pb_graph, executors_it=None): + """Recursive part of visualize (basically skips setting up the input and output nodes).""" + def inline_graph(subgraph, name, node): + rec_value_map = {inp.unique(): value_map[val.unique()] + for inp, val in zip(subgraph.inputs(), node.inputs())} + visualize_rec(graph=subgraph, + value_map=rec_value_map, + name_prefix=name, + pb_graph=pb_graph) + for out, val in zip(subgraph.outputs(), node.outputs()): + value_map[val.unique()] = rec_value_map[out.unique()] + + op_id_counter: 
DefaultDict[str, int] = defaultdict(int) + + def name_for(node): + kind = node.kind()[node.kind().index('::') + 2:] + op_id_counter[kind] += 1 + return kind, name_prefix + kind + '_' + str(op_id_counter[kind]) + + def add_fusion_group(node): + op, name = name_for(node) + inline_graph(node.g('Subgraph'), name + '/', node) + + def add_graph_executor(node): + op, name = name_for(node) + if executors_it is None: + add_node(node) + else: + ge = next(executors_it) + visualize_graph_executor(ge, name + '/', pb_graph, + partial(inline_graph, node=node)) + + def add_node(node): + if node.kind() == 'prim::FusionGroup': + return add_fusion_group(node) + elif node.kind() == 'prim::GraphExecutor': + return add_graph_executor(node) + op, name = name_for(node) + pb_node = pb_graph.node.add(op=op, name=name) + for value in node.inputs(): + pb_node.input.append(value_map[value.unique()]) + # TODO: handle attrs + for i, value in enumerate(node.outputs()): + value_map[value.unique()] = name + ':' + str(i) + + for node in graph.nodes(): + add_node(node) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/subgraph_rewriter.py b/.venv/lib/python3.11/site-packages/torch/fx/subgraph_rewriter.py new file mode 100644 index 0000000000000000000000000000000000000000..7f2cb743d2cdd9ba09c605432a144ccba97327da --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/subgraph_rewriter.py @@ -0,0 +1,348 @@ +from .graph_module import GraphModule +from .graph import Graph +from .node import Node +from ._symbolic_trace import symbolic_trace +from ._compatibility import compatibility + +import copy +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union, TYPE_CHECKING +import torch + +if TYPE_CHECKING: + from .passes.utils.matcher_with_name_node_map_utils import InternalMatch + +__all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filters', "ReplacedPatterns"] + +@compatibility(is_backward_compatible=True) +class 
Match(NamedTuple): + # Node from which the match was found + anchor: Node + # Maps nodes in the pattern subgraph to nodes in the larger graph + nodes_map: Dict[Node, Node] + +@compatibility(is_backward_compatible=False) +@dataclass +class ReplacedPatterns: + # Node from which the match was found + anchor: Node + # Maps nodes in the pattern subgraph to nodes in the larger graph + nodes_map: Dict[Node, Node] + # List of nodes that were added into the graph + replacements: List[Node] + +def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None: + gm.delete_all_unused_submodules() + + if isinstance(replacement, GraphModule): + replacement.graph.lint() + + def try_get_attr(gm: torch.nn.Module, target: str) -> Optional[Any]: + module_path, _, attr_name = target.rpartition(".") + try: + mod: torch.nn.Module = gm.get_submodule(module_path) + except AttributeError: + return None + attr = getattr(mod, attr_name, None) + return attr + + for node in gm.graph.nodes: + if node.op == "call_module" or node.op == "get_attr": + + gm_attr = try_get_attr(gm, node.target) + replacement_attr = try_get_attr(replacement, node.target) + + # CASE 1: This target already exists as an attribute in our + # result GraphModule. Whether or not it exists in + # `replacement`, the existing submodule takes precedence. + if gm_attr is not None: + continue + + # CASE 2: The target exists as an attribute in `replacement` + # only, so we need to copy it over. 
+ elif replacement_attr is not None: + new_attr = copy.deepcopy(replacement_attr) + if isinstance(replacement_attr, torch.nn.Module): + gm.add_submodule(node.target, new_attr) + else: + setattr(gm, node.target, new_attr) + + # CASE 3: The target doesn't exist as an attribute in `gm` + # or `replacement` + else: + raise RuntimeError('Attempted to create a "', node.op, + '" node during subgraph rewriting ' + f"with target {node.target}, but " + "the referenced attribute does not " + "exist in the replacement GraphModule") + + gm.graph.lint() + + +@compatibility(is_backward_compatible=True) +def replace_pattern( + gm: GraphModule, + pattern: Union[Callable, GraphModule], + replacement: Union[Callable, GraphModule] +) -> List[Match]: + """ + Matches all possible non-overlapping sets of operators and their + data dependencies (``pattern``) in the Graph of a GraphModule + (``gm``), then replaces each of these matched subgraphs with another + subgraph (``replacement``). + + Args: + ``gm``: The GraphModule that wraps the Graph to operate on + ``pattern``: The subgraph to match in ``gm`` for replacement + ``replacement``: The subgraph to replace ``pattern`` with + + Returns: + List[Match]: A list of ``Match`` objects representing the places + in the original graph that ``pattern`` was matched to. The list + is empty if there are no matches. ``Match`` is defined as: + + .. code-block:: python + + class Match(NamedTuple): + # Node from which the match was found + anchor: Node + # Maps nodes in the pattern subgraph to nodes in the larger graph + nodes_map: Dict[Node, Node] + + Examples: + + .. 
code-block:: python + + import torch + from torch.fx import symbolic_trace, subgraph_rewriter + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, w1, w2): + m1 = torch.cat([w1, w2]).sum() + m2 = torch.cat([w1, w2]).sum() + return x + torch.max(m1) + torch.max(m2) + + def pattern(w1, w2): + return torch.cat([w1, w2]).sum() + + def replacement(w1, w2): + return torch.stack([w1, w2]) + + traced_module = symbolic_trace(M()) + + subgraph_rewriter.replace_pattern(traced_module, pattern, replacement) + + The above code will first match ``pattern`` in the ``forward`` + method of ``traced_module``. Pattern-matching is done based on + use-def relationships, not node names. For example, if you had + ``p = torch.cat([a, b])`` in ``pattern``, you could match + ``m = torch.cat([a, b])`` in the original ``forward`` function, + despite the variable names being different (``p`` vs ``m``). + + The ``return`` statement in ``pattern`` is matched based on its + value only; it may or may not match to the ``return`` statement in + the larger graph. In other words, the pattern doesn't have to extend + to the end of the larger graph. + + When the pattern is matched, it will be removed from the larger + function and replaced by ``replacement``. If there are multiple + matches for ``pattern`` in the larger function, each non-overlapping + match will be replaced. In the case of a match overlap, the first + found match in the set of overlapping matches will be replaced. + ("First" here being defined as the first in a topological ordering + of the Nodes' use-def relationships. In most cases, the first Node + is the parameter that appears directly after ``self``, while the + last Node is whatever the function returns.) + + One important thing to note is that the parameters of the + ``pattern`` Callable must be used in the Callable itself, + and the parameters of the ``replacement`` Callable must match + the pattern. 
The first rule is why, in the above code block, the + ``forward`` function has parameters ``x, w1, w2``, but the + ``pattern`` function only has parameters ``w1, w2``. ``pattern`` + doesn't use ``x``, so it shouldn't specify ``x`` as a parameter. + As an example of the second rule, consider replacing + + .. code-block:: python + + def pattern(x, y): + return torch.neg(x) + torch.relu(y) + + with + + .. code-block:: python + + def replacement(x, y): + return torch.relu(x) + + In this case, ``replacement`` needs the same number of parameters + as ``pattern`` (both ``x`` and ``y``), even though the parameter + ``y`` isn't used in ``replacement``. + + After calling ``subgraph_rewriter.replace_pattern``, the generated + Python code looks like this: + + .. code-block:: python + + def forward(self, x, w1, w2): + stack_1 = torch.stack([w1, w2]) + sum_1 = stack_1.sum() + stack_2 = torch.stack([w1, w2]) + sum_2 = stack_2.sum() + max_1 = torch.max(sum_1) + add_1 = x + max_1 + max_2 = torch.max(sum_2) + add_2 = add_1 + max_2 + return add_2 + """ + match_and_replacements = _replace_pattern(gm, pattern, replacement) + return [Match(anchor=m.anchor, nodes_map=m.nodes_map) for m in match_and_replacements] + + +# Experimental API, not backward compatible +@compatibility(is_backward_compatible=False) +def replace_pattern_with_filters( + gm: GraphModule, + pattern: Union[Callable, Graph, GraphModule], + replacement: Union[Callable, Graph, GraphModule], + match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None, + ignore_literals: bool = False, +) -> List[ReplacedPatterns]: + """ + See replace_pattern for documentation. This function is an overload with an additional match_filter argument. + + Args: + ``match_filters``: A list of functions that take in + (match: InternalMatch, original_graph: Graph, pattern_graph: Graph) and return a boolean indicating + whether the match satisfies the condition. + See matcher_utils.py for definition of InternalMatch. 
+ """ + + return _replace_pattern(gm, pattern, replacement, match_filters, ignore_literals) + + +def _replace_pattern( + gm: GraphModule, + pattern: Union[Callable, Graph, GraphModule], + replacement: Union[Callable, Graph, GraphModule], + match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None, + ignore_literals: bool = False, +) -> List[ReplacedPatterns]: + + from torch.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch + + if match_filters is None: + match_filters = [] + + # Get the graphs for `gm`, `pattern`, `replacement` + original_graph: Graph = gm.graph + + if isinstance(pattern, GraphModule): + pattern_graph = pattern.graph + elif isinstance(pattern, Graph): + pattern_graph = pattern + else: + pattern_graph = symbolic_trace(pattern).graph + + if isinstance(replacement, GraphModule): + replacement_graph = replacement.graph + elif isinstance(replacement, Graph): + replacement_graph = replacement + else: + replacement_graph = symbolic_trace(replacement).graph + + matcher = SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False, + remove_overlapping_matches=True, ignore_literals=ignore_literals) + _matches: List[InternalMatch] = matcher.match(original_graph) + + # Filter out matches that don't match the filter + _matches = [ + m for m in _matches + if all(match_filter(m, original_graph, pattern_graph) + for match_filter in match_filters) + ] + + replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"] + + # As we progressively replace nodes, we'll need to keep track of how the match results should change + match_changed_node: Dict[Node, Node] = {} + + match_and_replacements = [] + for match in _matches: + + # Build connecting between replacement graph's input and original graph input producer node + + # Initialize `val_map` with mappings from placeholder nodes in + # `replacement` to their corresponding node in `original_graph` + assert 
len(match.placeholder_nodes) == len(replacement_placeholders) + val_map: Dict[Node, Node] = {} + for rn, gn in zip(replacement_placeholders, match.placeholder_nodes): + if isinstance(gn, Node): + val_map[rn] = match_changed_node.get(gn, gn) + if gn != val_map[rn]: + # Update match.placeholder_nodes and match.nodes_map with the node that replaced gn + gn_ind = match.placeholder_nodes.index(gn) + match.placeholder_nodes[gn_ind] = match_changed_node[gn] + map_key = list(match.nodes_map.keys())[list(match.nodes_map.values()).index(gn)] + match.nodes_map[map_key] = match_changed_node[gn] + else: + val_map[rn] = gn + + # Copy the replacement graph over + user_nodes: Set[Node] = set() + for n in match.returning_nodes: + user_nodes.update(n.users) + assert user_nodes, "The returning_nodes should have at least one user node" + + if len(user_nodes) == 1: + first_user_node = next(iter(user_nodes)) + else: + # If there are multiple user nodes, we need to find the first user node + # in the current execution order of the `original_graph` + for n in original_graph.nodes: + if n in user_nodes: + first_user_node = n + break + + with original_graph.inserting_before(first_user_node): # type: ignore[possibly-undefined] + copied_returning_nodes = original_graph.graph_copy(replacement_graph, val_map) + + if isinstance(copied_returning_nodes, Node): + copied_returning_nodes = (copied_returning_nodes, ) + + # Get a list of nodes that have been replaced into the graph + replacement_nodes: List[Node] = [v for v in val_map.values() if v not in match.placeholder_nodes] + + # Hook the output Node of the replacement subgraph in to the + # original Graph at the correct location + assert len(match.returning_nodes) == len(copied_returning_nodes) # type: ignore[arg-type] + for gn, copied_node in zip(match.returning_nodes, copied_returning_nodes): # type: ignore[arg-type] + gn.replace_all_uses_with(copied_node) + match_changed_node[gn] = copied_node + # Remove the original nodes + for node in 
reversed(pattern_graph.nodes): + if node.op != "placeholder" and node.op != "output": + gn = match.nodes_map[node] + gm.graph.erase_node(gn) + + match_and_replacements.append( + ReplacedPatterns( + anchor=match.anchors[0], + nodes_map=match.nodes_map, + replacements=replacement_nodes + ) + ) + + # Update the passed-in GraphModule to reflect the new state of + # `original_graph` + gm.recompile() + + # If `replacement` was an nn.Module, we'll need to make sure that + # all the submodules have been copied over correctly + if isinstance(replacement, torch.nn.Module): + _replace_attributes(gm, replacement) + + return match_and_replacements diff --git a/.venv/lib/python3.11/site-packages/torch/multiprocessing/__init__.py b/.venv/lib/python3.11/site-packages/torch/multiprocessing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..745c180d8c415c4e52472864d26381e0872f3354 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/multiprocessing/__init__.py @@ -0,0 +1,100 @@ +# mypy: allow-untyped-defs +"""torch.multiprocessing is a wrapper around the native :mod:`multiprocessing` module. + +It registers custom reducers, that use shared memory to provide shared +views on the same data in different processes. Once the tensor/storage is moved +to shared_memory (see :func:`~torch.Tensor.share_memory_`), it will be possible +to send it to other processes without making any copies. + +The API is 100% compatible with the original module - it's enough to change +``import multiprocessing`` to ``import torch.multiprocessing`` to have all the +tensors sent through the queues or shared via other mechanisms, moved to shared +memory. + +Because of the similarity of APIs we do not document most of this package +contents, and we recommend referring to very good docs of the original module. 
+""" +import multiprocessing +import sys + +import torch + +from .reductions import init_reductions + + +__all__ = ["set_sharing_strategy", "get_sharing_strategy", "get_all_sharing_strategies"] + + +from multiprocessing import * # noqa: F403 + + +__all__ += multiprocessing.__all__ # noqa: PLE0605 type: ignore[attr-defined] + + +# This call adds a Linux specific prctl(2) wrapper function to this module. +# See https://github.com/pytorch/pytorch/pull/14391 for more information. +torch._C._multiprocessing_init() + + +"""Add helper function to spawn N processes and wait for completion of any of +them. This depends `mp.get_context` which was added in Python 3.4.""" +from .spawn import ( + ENV_VAR_PARALLEL_START, + ProcessContext, + ProcessExitedException, + ProcessRaisedException, + spawn, + SpawnContext, + start_processes, +) + + +if sys.platform == "darwin" or sys.platform == "win32": + _sharing_strategy = "file_system" + _all_sharing_strategies = {"file_system"} +else: + _sharing_strategy = "file_descriptor" + _all_sharing_strategies = {"file_descriptor", "file_system"} + + +def set_sharing_strategy(new_strategy): + """Set the strategy for sharing CPU tensors. + + Args: + new_strategy (str): Name of the selected strategy. Should be one of + the values returned by :func:`get_all_sharing_strategies()`. + """ + global _sharing_strategy + assert new_strategy in _all_sharing_strategies + _sharing_strategy = new_strategy + + +def get_sharing_strategy(): + """Return the current strategy for sharing CPU tensors.""" + return _sharing_strategy + + +def get_all_sharing_strategies(): + """Return a set of sharing strategies supported on a current system.""" + return _all_sharing_strategies + + +def _set_thread_name(name: str) -> None: + """Set the name of the current thread. + + Args: + name (str): Name of the current thread. + """ + torch._C._set_thread_name(name) + + +def _get_thread_name() -> str: + """Get the name of the current thread. 
+ + Returns: + str: Name of the current thread. + """ + return torch._C._get_thread_name() + + +init_reductions() diff --git a/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf5483cdf86ad934e6a280a365a7f9e7527f6613 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/_atfork.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/_atfork.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10b2fd17e5185900f69d9f77c9c5ff18eb9cc524 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/_atfork.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/pool.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/pool.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..509e10ccdf530ccce4609c57c7e731fcb344d3cc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/pool.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/queue.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/queue.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f03aae66f002d618323f51d091a560523b65e5d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/queue.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/reductions.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/reductions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b0c4b228f626ad60132d828b17cf7bec33b546d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/reductions.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/spawn.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/spawn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2baaa083c796c7eb6cec1f4138266279010bb50 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/spawn.cpython-311.pyc differ