|
|
|
|
|
|
|
|
import contextlib
|
|
|
import copy
|
|
|
import dataclasses
|
|
|
import functools
|
|
|
import operator
|
|
|
import types
|
|
|
import warnings
|
|
|
from collections import defaultdict, namedtuple
|
|
|
from collections.abc import Iterator
|
|
|
from contextlib import contextmanager
|
|
|
from typing import Any, Callable, final, Optional, TYPE_CHECKING, Union
|
|
|
|
|
|
from torch._guards import tracing, TracingContext
|
|
|
from torch._higher_order_ops.utils import autograd_not_implemented
|
|
|
from torch._library.fake_class_registry import FakeScriptObject
|
|
|
from torch._subclasses.fake_impls import (
|
|
|
_deregister_op_impl,
|
|
|
_is_op_registered_to_fake_rule,
|
|
|
register_op_impl,
|
|
|
)
|
|
|
from torch._subclasses.fake_tensor import FakeTensorMode
|
|
|
from torch.fx._symbolic_trace import _ConstantAttributeType
|
|
|
from torch.fx._utils import first_call_function_nn_module_stack
|
|
|
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
|
|
|
from torch.fx.immutable_collections import immutable_dict, immutable_list
|
|
|
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import sympy
|
|
|
|
|
|
from torch.utils._sympy.value_ranges import ValueRanges
|
|
|
|
|
|
import torch
|
|
|
import torch.utils._pytree as pytree
|
|
|
from torch._export.utils import (
|
|
|
_collect_all_valid_cia_ops,
|
|
|
_collect_and_set_constant_attrs,
|
|
|
_collect_param_buffer_metadata,
|
|
|
_detect_fake_mode_from_gm,
|
|
|
_fakify_params_buffers,
|
|
|
_get_decomp_for_cia,
|
|
|
_is_preservable_cia_op,
|
|
|
_name_hoo_subgraph_placeholders,
|
|
|
_override_graph_signature_for_temp_registered_constants,
|
|
|
_overwrite_signature_for_non_persistent_buffers,
|
|
|
_populate_param_buffer_metadata_to_new_gm,
|
|
|
_register_constants_as_buffers,
|
|
|
_rename_without_collisions,
|
|
|
_special_op_to_preserve_cia,
|
|
|
placeholder_naming_pass,
|
|
|
)
|
|
|
from torch._export.verifier import Verifier
|
|
|
from torch._guards import detect_fake_mode
|
|
|
from torch._subclasses.fake_tensor import unset_fake_temporarily
|
|
|
from torch.export._tree_utils import is_equivalent, reorder_kwargs
|
|
|
from torch.export.decomp_utils import CustomDecompTable
|
|
|
from torch.fx._compatibility import compatibility
|
|
|
from torch.fx.passes.infra.pass_base import PassResult
|
|
|
from torch.fx.passes.infra.pass_manager import PassManager
|
|
|
|
|
|
from .graph_signature import (
|
|
|
ArgumentSpec,
|
|
|
ConstantArgument,
|
|
|
CustomObjArgument,
|
|
|
ExportGraphSignature,
|
|
|
InputKind,
|
|
|
InputSpec,
|
|
|
OutputKind,
|
|
|
OutputSpec,
|
|
|
SymBoolArgument,
|
|
|
SymFloatArgument,
|
|
|
SymIntArgument,
|
|
|
TensorArgument,
|
|
|
TokenArgument,
|
|
|
)
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
"ExportedProgram",
|
|
|
"ModuleCallEntry",
|
|
|
"ModuleCallSignature",
|
|
|
"default_decompositions",
|
|
|
]
|
|
|
|
|
|
|
|
|
PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
|
|
|
|
|
|
|
|
|
@dataclasses.dataclass
|
|
|
class ModuleCallSignature:
|
|
|
inputs: list[ArgumentSpec]
|
|
|
outputs: list[ArgumentSpec]
|
|
|
in_spec: pytree.TreeSpec
|
|
|
out_spec: pytree.TreeSpec
|
|
|
forward_arg_names: Optional[list[str]] = None
|
|
|
|
|
|
def replace_all_uses_with(self, original_node, new_node):
|
|
|
for i in self.inputs:
|
|
|
if i.name == original_node.name:
|
|
|
i.name = new_node.name
|
|
|
for o in self.outputs:
|
|
|
if o.name == original_node.name:
|
|
|
o.name = new_node.name
|
|
|
|
|
|
|
|
|
@dataclasses.dataclass
|
|
|
class ModuleCallEntry:
|
|
|
fqn: str
|
|
|
signature: Optional[ModuleCallSignature] = None
|
|
|
|
|
|
|
|
|
def _disable_prexisiting_fake_mode(fn):
|
|
|
@functools.wraps(fn)
|
|
|
def wrapper(*args, **kwargs):
|
|
|
with unset_fake_temporarily():
|
|
|
return fn(*args, **kwargs)
|
|
|
|
|
|
return wrapper
|
|
|
|
|
|
|
|
|
def _fx_collection_equivalence_fn(
    spec1_type: Optional[type],
    spec1_context: pytree.Context,
    spec2_type: Optional[type],
    spec2_context: pytree.Context,
) -> bool:
    """Compare two pytree node descriptions, unifying mutable/immutable containers.

    ``dict``/``immutable_dict`` subclasses are treated as one type, and so are
    ``list``/``immutable_list`` subclasses; for such pairs only the contexts
    are compared. Any other pair (including one where a type is ``None``)
    must have the identical type and equal contexts.
    """
    if spec1_type is not None and spec2_type is not None:
        container_families = (
            (dict, immutable_dict),
            (list, immutable_list),
        )
        for family in container_families:
            if issubclass(spec1_type, family) and issubclass(spec2_type, family):
                return spec1_context == spec2_context

    # Contexts are compared only when the types are the very same object.
    return spec1_type is spec2_type and spec1_context == spec2_context
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Autograd dispatch keys, one per backend. For every op it overrides,
# `_override_composite_implicit_decomp` installs an `autograd_not_implemented`
# kernel under each of these keys that has no Python kernel yet.
_AUTOGRAD_ALIAS_BACKEND_KEYS_TO_OVERRIDE = [
    getattr(torch._C.DispatchKey, f"Autograd{backend}")
    for backend in (
        "CPU",
        "CUDA",
        "Meta",
        "XLA",
        "Lazy",
        "IPU",
        "XPU",
        "MPS",
        "HPU",
        "PrivateUse1",
        "PrivateUse2",
        "PrivateUse3",
    )
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Backend dispatch keys. For every op it overrides,
# `_override_composite_implicit_decomp` registers the op's original
# CompositeImplicitAutograd decomposition under each of these keys that has no
# Python kernel yet.
_BACKEND_KEYS_TO_OVERRIDE = [
    getattr(torch._C.DispatchKey, backend)
    for backend in (
        "CPU",
        "CUDA",
        "Meta",
        "XLA",
        "Lazy",
        "IPU",
        "XPU",
        "MPS",
        "HPU",
    )
]
|
|
|
|
|
|
|
|
|
@contextmanager
|
|
|
def _override_composite_implicit_decomp(cia_ops_to_callable):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
saved_tables = {}
|
|
|
patched_ops = set()
|
|
|
for op_overload, decomp_callable in cia_ops_to_callable.items():
|
|
|
saved_tables[op_overload] = op_overload.py_kernels.copy()
|
|
|
patched_ops.add(op_overload)
|
|
|
for override_dispatch_key in _AUTOGRAD_ALIAS_BACKEND_KEYS_TO_OVERRIDE:
|
|
|
if override_dispatch_key not in op_overload.py_kernels:
|
|
|
|
|
|
op_overload.py_impl(override_dispatch_key)(
|
|
|
autograd_not_implemented(op_overload, deferred_error=True)
|
|
|
)
|
|
|
|
|
|
|
|
|
orig_cia_callable = _get_decomp_for_cia(op_overload)
|
|
|
if torch._C.DispatchKey.CompositeImplicitAutograd in op_overload.py_kernels:
|
|
|
del op_overload.py_kernels[torch._C.DispatchKey.CompositeImplicitAutograd]
|
|
|
|
|
|
op_overload.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)(
|
|
|
decomp_callable
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _force_dispatch_to_orig_cia_callable(fake_tensor_mode, op, *args, **kwargs):
|
|
|
orig_cia_callable = kwargs["original_callable"]
|
|
|
del kwargs["original_callable"]
|
|
|
with fake_tensor_mode:
|
|
|
return orig_cia_callable(*args, **kwargs)
|
|
|
|
|
|
if not _is_op_registered_to_fake_rule(op_overload):
|
|
|
register_op_impl(op_overload)(
|
|
|
functools.partial(
|
|
|
_force_dispatch_to_orig_cia_callable,
|
|
|
original_callable=orig_cia_callable,
|
|
|
)
|
|
|
)
|
|
|
|
|
|
for key in _BACKEND_KEYS_TO_OVERRIDE:
|
|
|
if key not in op_overload.py_kernels:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
op_overload.py_impl(key)(orig_cia_callable)
|
|
|
|
|
|
try:
|
|
|
yield
|
|
|
finally:
|
|
|
for op in patched_ops:
|
|
|
op.py_kernels.clear()
|
|
|
op.py_kernels.update(saved_tables[op])
|
|
|
op._dispatch_cache.clear()
|
|
|
_deregister_op_impl(op)
|
|
|
|
|
|
|
|
|
def _split_decomp_table_to_cia_and_python_decomp(
|
|
|
decomp_table: dict[torch._ops.OperatorBase, Callable],
|
|
|
) -> tuple[dict[torch._ops.OperatorBase, Callable], ...]:
|
|
|
all_preservable_cia_ops = set(_collect_all_valid_cia_ops())
|
|
|
cia_ops_to_callable = {}
|
|
|
|
|
|
for op in list(decomp_table.keys()):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if op in all_preservable_cia_ops:
|
|
|
cia_ops_to_callable[op] = decomp_table[op]
|
|
|
all_preservable_cia_ops.remove(op)
|
|
|
del decomp_table[op]
|
|
|
|
|
|
|
|
|
|
|
|
elif _is_preservable_cia_op(op):
|
|
|
op_name = op.name()
|
|
|
assert not op_name.startswith("aten"), "This should be a custom op"
|
|
|
cia_ops_to_callable[op] = decomp_table[op]
|
|
|
|
|
|
|
|
|
|
|
|
for k in all_preservable_cia_ops:
|
|
|
cia_ops_to_callable[k] = _special_op_to_preserve_cia
|
|
|
|
|
|
return cia_ops_to_callable, decomp_table
|
|
|
|
|
|
|
|
|
def default_decompositions() -> "CustomDecompTable":
|
|
|
"""
|
|
|
This is the default decomposition table which contains decomposition of
|
|
|
all ATEN operators to core aten opset. Use this API together with
|
|
|
:func:`run_decompositions()`
|
|
|
"""
|
|
|
return CustomDecompTable()
|
|
|
|
|
|
|
|
|
def _decompose_and_get_gm_with_new_signature_constants(
    ep,
    *,
    cia_to_decomp: dict[torch._ops.OperatorBase, Callable],
    python_decomp_table: dict[torch._ops.OperatorBase, Callable],
    joint_loss_index: Optional[int],
    decompose_custom_triton_ops,
):
    """Re-trace ``ep`` with the given decompositions.

    Two code paths exist:

    * Inference path, taken when ``joint_loss_index`` is ``None`` and
      ``ep.graph_signature.backward_signature`` is ``None``: ``ep.module()``
      is re-exported through ``_export_to_aten_ir`` and the resulting
      signature is patched up with the helper passes called below.
    * Joint path, otherwise: ``ep.graph_module`` is traced with
      ``aot_export_module`` and a new ``ExportGraphSignature`` is assembled
      from the old signature and the traced graph.

    Args:
        ep: the exported program to decompose.
        cia_to_decomp: op -> callable mapping handed to
            ``_override_composite_implicit_decomp``.
        python_decomp_table: decomposition table handed to the tracer.
        joint_loss_index: index of the loss among the user outputs, or
            ``None``; a non-``None`` value requests joint tracing.
        decompose_custom_triton_ops: forwarded to ``_export_to_aten_ir`` on
            the inference path; on the joint path a falsy value wraps tracing
            in ``_disable_custom_triton_op_functional_decomposition``.

    Returns:
        ``(graph_module, graph_signature, state_dict)``. On the joint path the
        state dict is ``ep.state_dict`` itself.
    """
    from torch._export.passes.lift_constants_pass import _materialize_and_lift_constants
    from torch._functorch.aot_autograd import aot_export_module
    from torch.export._trace import (
        _disable_custom_triton_op_functional_decomposition,
        _export_to_aten_ir,
        _ignore_backend_decomps,
        _verify_nn_module_stack,
        _verify_placeholder_names,
        _verify_stack_trace,
    )
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    def _is_joint_ir_decomp(ep, joint_loss_index):
        return (
            joint_loss_index is not None
            or ep.graph_signature.backward_signature is not None
        )

    if not _is_joint_ir_decomp(ep, joint_loss_index):
        mod = ep.module()

        # Parameters/buffers as they are before subclass unwrapping.
        wrapped_params_buffers = {
            **dict(mod.named_parameters(remove_duplicate=False)),
            **dict(mod.named_buffers(remove_duplicate=False)),
        }

        from torch._functorch._aot_autograd.subclass_parametrization import (
            unwrap_tensor_subclass_parameters,
        )

        unwrap_tensor_subclass_parameters(mod)
        # Same collection again, after unwrap_tensor_subclass_parameters ran.
        unwrapped_params_buffers = {
            **dict(mod.named_parameters(remove_duplicate=False)),
            **dict(mod.named_buffers(remove_duplicate=False)),
        }

        fake_mode = _detect_fake_mode_from_gm(ep.graph_module)
        if fake_mode is None:
            fake_mode = FakeTensorMode(shape_env=ShapeEnv(), export=True)

        out_spec = mod._out_spec

        orig_arg_names = mod.graph._codegen.pytree_info.orig_args

        # Wrap a non-list/tuple output spec into a one-element tuple spec.
        if out_spec.type not in (list, tuple):
            out_spec = pytree.TreeSpec(tuple, None, [out_spec])

        mod.graph._codegen = _PyTreeCodeGen(
            _PyTreeInfo(
                orig_arg_names,
                mod._in_spec,
                out_spec,
            )
        )

        mod.recompile()

        _collect_and_set_constant_attrs(ep.graph_signature, ep.constants, mod)

        # The returned value is used further down to undo this registration in
        # the new graph signature.
        temp_registered_constants = _register_constants_as_buffers(
            mod, ep.state_dict, ep.graph_signature.non_persistent_buffers
        )

        fake_params_buffers = _fakify_params_buffers(fake_mode, mod)

        params_buffers_to_node_meta = _collect_param_buffer_metadata(mod)

        from torch._export.non_strict_utils import (
            _enable_graph_inputs_of_type_nn_module,
            _fakify_script_objects,
        )

        # Collect one retracing argument per placeholder: the placeholder's
        # "val" metadata, or the real script object for custom-object inputs.
        retracing_args = []
        for node in mod.graph.nodes:
            if node.op != "placeholder":
                continue
            val = node.meta["val"]
            if not isinstance(val, CustomObjArgument):
                retracing_args.append(val)
            elif val.fake_val is None:
                retracing_args.append(ep.constants[val.name])
            else:
                retracing_args.append(val.fake_val.real_obj)

        tx = TracingContext(fake_mode)

        with (
            fake_mode,
            _override_composite_implicit_decomp(
                cia_to_decomp,
            ),
            _enable_graph_inputs_of_type_nn_module(ep.example_inputs),
            tracing(tx),
        ):
            retracing_args_unwrapped = pytree.tree_unflatten(
                retracing_args, mod._in_spec
            )

            # Positional args and kwarg values are passed together as
            # positional arguments; the kwargs dict is left empty.
            with _fakify_script_objects(
                mod,
                (
                    *retracing_args_unwrapped[0],
                    *retracing_args_unwrapped[1].values(),
                ),
                {},
                fake_mode,
            ) as (
                patched_mod,
                new_fake_args,
                new_fake_kwargs,
                new_fake_constant_attrs,
                map_fake_to_real,
            ):
                aten_export_artifact = _export_to_aten_ir(
                    patched_mod,
                    new_fake_args,
                    new_fake_kwargs,
                    fake_params_buffers,
                    new_fake_constant_attrs,
                    decomp_table=python_decomp_table,
                    _check_autograd_state=False,
                    _prettify_placeholder_names=False,
                    decompose_custom_triton_ops=decompose_custom_triton_ops,
                )

                # Map fake script objects among the constants back to the real
                # objects.
                aten_export_artifact.constants = {
                    fqn: (
                        map_fake_to_real[obj]
                        if isinstance(obj, FakeScriptObject)
                        else obj
                    )
                    for fqn, obj in aten_export_artifact.constants.items()
                }

                gm = aten_export_artifact.gm
                new_graph_signature = aten_export_artifact.sig

                new_graph_signature = (
                    _override_graph_signature_for_temp_registered_constants(
                        new_graph_signature, temp_registered_constants
                    )
                )

                _populate_param_buffer_metadata_to_new_gm(
                    params_buffers_to_node_meta, gm, new_graph_signature
                )

                new_graph_signature = _overwrite_signature_for_non_persistent_buffers(
                    ep.graph_signature, new_graph_signature
                )

                constants = _materialize_and_lift_constants(
                    gm, new_graph_signature, new_fake_constant_attrs
                )

                placeholder_naming_pass(
                    gm,
                    new_graph_signature,
                    patched_mod,
                    new_fake_args,
                    new_fake_kwargs,
                    fake_params_buffers,
                    constants,
                )

        _verify_nn_module_stack(gm)
        _verify_stack_trace(gm)
        _verify_placeholder_names(gm, new_graph_signature)

        gm, new_graph_signature = _remove_unneccessary_copy_op_pass(
            gm, new_graph_signature
        )

        # Old state dict plus every tensor that exists only after unwrapping.
        new_state_dict = {
            **ep.state_dict,
            **{
                name: p
                for name, p in unwrapped_params_buffers.items()
                if name not in wrapped_params_buffers
            },
        }

        for name, p in wrapped_params_buffers.items():
            if name not in new_state_dict:
                # Only non-Parameter tensors may be absent from the state dict.
                assert not isinstance(p, torch.nn.Parameter)
            elif name not in unwrapped_params_buffers:
                # Present before unwrapping but gone afterwards: drop it.
                new_state_dict.pop(name)

        return gm, new_graph_signature, new_state_dict

    old_placeholders = [
        node for node in ep.graph_module.graph.nodes if node.op == "placeholder"
    ]
    fake_args = [node.meta["val"] for node in old_placeholders]

    # NOTE: this deletes every buffer attribute from ep.graph_module in place.
    buffers_to_remove = [name for name, _ in ep.graph_module.named_buffers()]
    for name in buffers_to_remove:
        delattr(ep.graph_module, name)

    fake_mode = detect_fake_mode(fake_args)
    fake_mode = contextlib.nullcontext() if fake_mode is None else fake_mode
    custom_triton_ops_decomposition_ctx = (
        contextlib.nullcontext
        if decompose_custom_triton_ops
        else _disable_custom_triton_op_functional_decomposition
    )
    with (
        _ignore_backend_decomps(),
        fake_mode,
        _override_composite_implicit_decomp(cia_to_decomp),
        custom_triton_ops_decomposition_ctx(),
    ):
        gm, graph_signature = aot_export_module(
            ep.graph_module,
            fake_args,
            decompositions=python_decomp_table,
            trace_joint=joint_loss_index is not None,
            output_loss_index=joint_loss_index,
        )
        gm.graph.eliminate_dead_code()

    def update_arg(old_arg, new_ph):
        """Return an argument spec of ``old_arg``'s kind named after ``new_ph``."""
        if isinstance(old_arg, ConstantArgument):
            return old_arg
        for arg_type in (
            TensorArgument,
            SymIntArgument,
            SymFloatArgument,
            SymBoolArgument,
        ):
            if isinstance(old_arg, arg_type):
                return arg_type(name=new_ph.name)
        raise RuntimeError(f"Type of old_arg not supported: {type(old_arg)}")

    new_placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
    new_outputs = list(gm.graph.nodes)[-1].args[0]

    # Give the new placeholders the names of the old ones.
    assert len(new_placeholders) == len(old_placeholders)
    for old_ph, new_ph in zip(old_placeholders, new_placeholders):
        new_ph.name = new_ph.target = old_ph.name

    # Rename the remaining nodes so they do not collide with placeholder names.
    name_map = {ph.name: ph.name for ph in new_placeholders}
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            continue
        node.name = _rename_without_collisions(name_map, node.name, node.name)

    _name_hoo_subgraph_placeholders(gm)

    from torch._export.passes._node_metadata_hook import (
        _node_metadata_hook,
        _set_node_metadata_hook,
    )
    from torch._functorch._aot_autograd.input_output_analysis import _graph_output_names

    if not torch._dynamo.config.do_not_emit_runtime_asserts:
        stack_trace = (
            'File "torch/fx/passes/runtime_assert.py", line 24, '
            "in insert_deferred_runtime_asserts"
        )
        shape_env = _get_shape_env(gm)
        if shape_env is not None:
            with _set_node_metadata_hook(
                gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace)
            ):
                insert_deferred_runtime_asserts(
                    gm,
                    shape_env,
                    f"exported program: {first_call_function_nn_module_stack(gm.graph)}",
                    export=True,
                )

    # Rename the output nodes to the names reported by _graph_output_names.
    gm.recompile()
    for i, name in enumerate(_graph_output_names(gm)):
        if isinstance(new_outputs[i], torch.fx.Node):
            new_outputs[i].name = name

    # Old placeholder name -> new placeholder name, by position.
    old_new_placeholder_map = {
        spec.arg.name: new_placeholders[i].name
        for i, spec in enumerate(ep.graph_signature.input_specs)
        if not isinstance(spec.arg, ConstantArgument)
    }

    input_specs = [
        InputSpec(
            spec.kind,
            update_arg(spec.arg, new_placeholders[i]),
            spec.target,
            spec.persistent,
        )
        for i, spec in enumerate(ep.graph_signature.input_specs)
    ]

    output_specs = []

    # Mutation outputs come first among the graph outputs (asserted below).
    user_inputs_index = {name: i for i, name in enumerate(graph_signature.user_inputs)}
    mutation_names = list(graph_signature.user_inputs_to_mutate.keys())
    assert mutation_names == [node.name for node in new_outputs[: len(mutation_names)]]
    for output_name, input_name in graph_signature.user_inputs_to_mutate.items():
        i = user_inputs_index[input_name]
        input_spec = ep.graph_signature.input_specs[i]
        assert input_spec.kind in (InputKind.USER_INPUT, InputKind.BUFFER)
        output_kind = (
            OutputKind.BUFFER_MUTATION
            if input_spec.kind == InputKind.BUFFER
            else OutputKind.USER_INPUT_MUTATION
        )
        target = (
            input_spec.target
            if input_spec.kind == InputKind.BUFFER
            else input_spec.arg.name
        )
        output_specs.append(
            OutputSpec(
                kind=output_kind,
                arg=TensorArgument(name=output_name),
                target=target,
            )
        )

    # Then the outputs described by the old signature, in order.
    for i, spec in enumerate(ep.graph_signature.output_specs):
        output_specs.append(
            OutputSpec(
                OutputKind.LOSS_OUTPUT if i == joint_loss_index else spec.kind,
                update_arg(spec.arg, new_outputs[len(mutation_names) + i]),
                old_new_placeholder_map.get(spec.target, spec.target),
            )
        )

    if joint_loss_index is not None:
        assert graph_signature.backward_signature is not None
        gradients = graph_signature.backward_signature.gradients_to_user_inputs
        assert len(graph_signature.user_inputs) == len(ep.graph_signature.input_specs)
        specs = {
            graph_signature.user_inputs[i]: spec
            for i, spec in enumerate(ep.graph_signature.input_specs)
            if isinstance(spec.arg, TensorArgument)
        }
        # Every output after the ones handled so far is treated as a gradient.
        # The slice is taken once, before output_specs grows inside the loop.
        for node in new_outputs[len(output_specs) :]:
            source = gradients[node.name]
            spec = specs[source]
            if spec.kind == InputKind.PARAMETER:
                kind = OutputKind.GRADIENT_TO_PARAMETER
                target = spec.target
            elif spec.kind == InputKind.USER_INPUT:
                kind = OutputKind.GRADIENT_TO_USER_INPUT
                target = source
            else:
                raise AssertionError(f"Unknown input kind: {spec.kind}")
            output_specs.append(
                OutputSpec(
                    kind,
                    TensorArgument(name=node.name),
                    target,
                )
            )

    new_graph_signature = ExportGraphSignature(
        input_specs=input_specs, output_specs=output_specs
    )

    for old_node, new_node in zip(old_placeholders, new_placeholders):
        # Non-tensor placeholders keep their original "val" metadata.
        if not isinstance(old_node.meta["val"], torch.Tensor):
            new_node.meta["val"] = old_node.meta["val"]

        # Parameter/buffer placeholders get all of the old node's metadata.
        if (
            new_node.target in new_graph_signature.inputs_to_parameters
            or new_node.target in new_graph_signature.inputs_to_buffers
        ):
            new_node.meta.update(old_node.meta)
    return gm, new_graph_signature, ep.state_dict
|
|
|
|
|
|
|
|
|
def _remove_unneccessary_copy_op_pass(
|
|
|
gm: torch.fx.GraphModule, new_graph_signature: ExportGraphSignature
|
|
|
) -> tuple[torch.fx.GraphModule, ExportGraphSignature]:
|
|
|
"""
|
|
|
Removes redundant copy_ node that was introduced due to mutated buffer.
|
|
|
"""
|
|
|
with gm._set_replace_hook(new_graph_signature.get_replace_hook()):
|
|
|
for node in gm.graph.nodes:
|
|
|
if node.op == "output":
|
|
|
args, _ = pytree.tree_flatten(node.args)
|
|
|
for out in args:
|
|
|
if (
|
|
|
isinstance(out, torch.fx.Node)
|
|
|
and out.name in new_graph_signature.buffers_to_mutate
|
|
|
):
|
|
|
if (
|
|
|
out.op == "call_function"
|
|
|
and out.target == torch.ops.aten.copy.default
|
|
|
):
|
|
|
out.replace_all_uses_with(out.args[1])
|
|
|
gm.graph.erase_node(out)
|
|
|
gm.recompile()
|
|
|
return gm, new_graph_signature
|
|
|
|
|
|
|
|
|
def _common_getitem_elimination_pass(
|
|
|
gm: torch.fx.GraphModule, graph_signature, module_call_graph
|
|
|
):
|
|
|
with gm._set_replace_hook(graph_signature.get_replace_hook()):
|
|
|
for module in gm.modules():
|
|
|
if not isinstance(module, torch.fx.GraphModule):
|
|
|
continue
|
|
|
|
|
|
node_id: dict[torch.fx.Node, str] = {}
|
|
|
getitems: dict[str, torch.fx.Node] = {}
|
|
|
for node in list(module.graph.nodes):
|
|
|
if node.op == "call_function" and node.target == operator.getitem:
|
|
|
source, idx = node.args
|
|
|
new_id = f"{node_id[source]}.{idx}"
|
|
|
if new_id in getitems:
|
|
|
node.replace_all_uses_with(getitems[new_id])
|
|
|
for entry in module_call_graph:
|
|
|
if entry.signature is not None:
|
|
|
entry.signature.replace_all_uses_with(
|
|
|
node, getitems[new_id]
|
|
|
)
|
|
|
module.graph.erase_node(node)
|
|
|
else:
|
|
|
getitems[new_id] = node
|
|
|
node_id[node] = new_id
|
|
|
else:
|
|
|
node_id[node] = node.name
|
|
|
|
|
|
|
|
|
def _get_updated_module_call_graph(
|
|
|
old_gm: torch.fx.GraphModule,
|
|
|
old_graph_signature: ExportGraphSignature,
|
|
|
gm: torch.fx.GraphModule,
|
|
|
graph_signature: ExportGraphSignature,
|
|
|
old_module_call_graph: list[ModuleCallEntry],
|
|
|
):
|
|
|
new_module_call_graph = copy.deepcopy(old_module_call_graph)
|
|
|
|
|
|
old_nodes = {node.name: node for node in old_gm.graph.nodes}
|
|
|
|
|
|
old_graph_params_buffers = {
|
|
|
**old_graph_signature.inputs_to_parameters,
|
|
|
**old_graph_signature.inputs_to_buffers,
|
|
|
}
|
|
|
new_graph_params_buffers = {
|
|
|
**graph_signature.inputs_to_parameters,
|
|
|
**graph_signature.inputs_to_buffers,
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
provenance: dict[str, str] = {}
|
|
|
|
|
|
user_input_counter = 0
|
|
|
old_user_input_names = [
|
|
|
node.target for node in old_gm.graph.nodes if node.op == "placeholder"
|
|
|
]
|
|
|
old_user_input_names = list(
|
|
|
filter(
|
|
|
lambda x: x not in old_graph_params_buffers
|
|
|
and x not in old_graph_signature.input_tokens,
|
|
|
old_user_input_names,
|
|
|
)
|
|
|
)
|
|
|
new_user_input_names = [
|
|
|
node.target for node in gm.graph.nodes if node.op == "placeholder"
|
|
|
]
|
|
|
|
|
|
for node in gm.graph.nodes:
|
|
|
if history := node.meta.get("from_node", []):
|
|
|
provenance[history[-1].name] = node.name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
elif node.op == "placeholder":
|
|
|
if not (
|
|
|
node.target in new_graph_params_buffers
|
|
|
or node.target in graph_signature.input_tokens
|
|
|
):
|
|
|
if node.target in new_user_input_names:
|
|
|
assert isinstance(node.name, str)
|
|
|
old_name = old_user_input_names[user_input_counter]
|
|
|
assert isinstance(old_name, str)
|
|
|
provenance[old_name] = node.name
|
|
|
user_input_counter += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
old_param_to_desugared = defaultdict(list)
|
|
|
for name, target in new_graph_params_buffers.items():
|
|
|
|
|
|
if not target.startswith("parametrizations."):
|
|
|
|
|
|
if name in old_graph_params_buffers:
|
|
|
provenance[name] = name
|
|
|
else:
|
|
|
old_target = ".".join(target.split(".")[1:-1])
|
|
|
old_param_to_desugared[old_target].append(name)
|
|
|
|
|
|
|
|
|
for entry in new_module_call_graph:
|
|
|
signature = entry.signature
|
|
|
if signature is None:
|
|
|
continue
|
|
|
for x in [*signature.inputs, *signature.outputs]:
|
|
|
|
|
|
|
|
|
if x.name in old_param_to_desugared:
|
|
|
raise ValueError(
|
|
|
f"It looks like {x.name} is a tensor subclass. "
|
|
|
f"Preserving submodule that takes subclass parameter is not supported"
|
|
|
f" in inference IR because we desugar them, resulting in more tensors"
|
|
|
)
|
|
|
|
|
|
if x.name in provenance:
|
|
|
x.name = provenance[x.name]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
elif x.name in old_nodes:
|
|
|
old_node = old_nodes[x.name]
|
|
|
if old_node.op == "call_function" and old_node.target in [
|
|
|
torch.ops.aten.to.dtype_layout,
|
|
|
torch.ops.aten.to.device,
|
|
|
torch.ops.aten.to.dtype,
|
|
|
]:
|
|
|
old_target = old_node.args[0].name
|
|
|
if old_target not in provenance:
|
|
|
raise ValueError(
|
|
|
f"It looks like {old_target} is a tensor subclass. "
|
|
|
f"Preserving submodule that takes subclass parameter is not supported"
|
|
|
f" in inference IR because we desugar them, resulting in more tensors"
|
|
|
)
|
|
|
|
|
|
x.name = provenance[old_target]
|
|
|
|
|
|
return new_module_call_graph
|
|
|
|
|
|
|
|
|
def _decompose_exported_program(
    ep,
    *,
    cia_to_decomp: dict[torch._ops.OperatorBase, Callable],
    python_decomp_table: dict[torch._ops.OperatorBase, Callable],
    joint_loss_index: Optional[int],
    decompose_custom_triton_ops: bool,
):
    """Build a new ``ExportedProgram`` from ``ep`` with decompositions applied.

    The graph module, graph signature and state dict come from
    ``_decompose_and_get_gm_with_new_signature_constants``; the module call
    graph and range constraints are recomputed for the new graph, and the old
    graph module's ``meta`` is merged into the new one. ``example_inputs`` and
    ``constants`` are taken from ``ep`` unchanged.
    """
    decomposed = _decompose_and_get_gm_with_new_signature_constants(
        ep,
        cia_to_decomp=cia_to_decomp,
        python_decomp_table=python_decomp_table,
        joint_loss_index=joint_loss_index,
        decompose_custom_triton_ops=decompose_custom_triton_ops,
    )
    gm, new_graph_signature, state_dict = decomposed

    # Rewrite the names used by the old call graph to the names in ``gm``.
    new_module_call_graph = _get_updated_module_call_graph(
        ep.graph_module,
        ep.graph_signature,
        gm,
        new_graph_signature,
        ep.module_call_graph,
    )

    gm.meta.update(ep.graph_module.meta)

    new_range_constraints = _get_updated_range_constraints(
        gm,
        ep.range_constraints,
    )

    return ExportedProgram(
        root=gm,
        graph=gm.graph,
        graph_signature=new_graph_signature,
        state_dict=state_dict,
        range_constraints=new_range_constraints,
        module_call_graph=new_module_call_graph,
        example_inputs=ep.example_inputs,
        constants=ep.constants,
    )
|
|
|
|
|
|
|
|
|
class ExportedProgram:
|
|
|
"""
|
|
|
Package of a program from :func:`export`. It contains
|
|
|
an :class:`torch.fx.Graph` that represents Tensor computation, a state_dict containing
|
|
|
tensor values of all lifted parameters and buffers, and various metadata.
|
|
|
|
|
|
You can call an ExportedProgram like the original callable traced by
|
|
|
:func:`export` with the same calling convention.
|
|
|
|
|
|
To perform transformations on the graph, use ``.module`` property to access
|
|
|
an :class:`torch.fx.GraphModule`. You can then use
|
|
|
`FX transformation <https://pytorch.org/docs/stable/fx.html#writing-transformations>`_
|
|
|
to rewrite the graph. Afterwards, you can simply use :func:`export`
|
|
|
again to construct a correct ExportedProgram.
|
|
|
"""
|
|
|
|
|
|
def __init__(
    self,
    root: Union[torch.nn.Module, dict[str, Any]],
    graph: torch.fx.Graph,
    graph_signature: ExportGraphSignature,
    state_dict: dict[str, Union[torch.Tensor, torch.nn.Parameter]],
    range_constraints: "dict[sympy.Symbol, Any]",
    module_call_graph: list[ModuleCallEntry],
    example_inputs: Optional[tuple[tuple[Any, ...], dict[str, Any]]] = None,
    constants: Optional[dict[str, _ConstantAttributeType]] = None,
    *,
    verifiers: Optional[list[type[Verifier]]] = None,
):
    """Assemble an exported program and validate it.

    Args:
        root: module (or attribute dict) the graph module is created from.
        graph: the FX graph; its ``_codegen`` is reset to a plain
            ``torch.fx.graph.CodeGen()``.
        graph_signature: input/output signature of ``graph``.
        state_dict: stored as the program's state dict.
        range_constraints: stored as the program's range constraints.
        module_call_graph: list of module call entries; must not be ``None``.
        example_inputs: optional example inputs, stored as given.
        constants: optional constants mapping; a falsy value becomes ``{}``.
        verifiers: optional list of ``Verifier`` subclasses; a falsy value
            becomes ``[Verifier]``.

    Raises:
        AssertionError: if ``module_call_graph`` is ``None`` or an entry of
            ``verifiers`` is not a ``Verifier`` subclass.
    """
    # Check before first use: the getitem-elimination pass below iterates it.
    assert module_call_graph is not None

    graph._codegen = torch.fx.graph.CodeGen()
    self._graph_module = _create_graph_module_for_export(root, graph)
    if isinstance(root, torch.fx.GraphModule):
        self._graph_module.meta.update(root.meta)

    _common_getitem_elimination_pass(
        self._graph_module, graph_signature, module_call_graph
    )
    self._graph_signature: ExportGraphSignature = graph_signature
    self._state_dict: dict[str, Any] = state_dict
    self._range_constraints: dict[sympy.Symbol, ValueRanges] = range_constraints
    self._module_call_graph: list[ModuleCallEntry] = module_call_graph
    self._example_inputs = example_inputs

    self._constants = constants or {}

    verifiers = verifiers or [Verifier]
    assert all(issubclass(v, Verifier) for v in verifiers)
    self._verifiers = verifiers

    # Runs last, once every attribute above has been set.
    self.validate()
|
|
|
|
|
|
@property
|
|
|
@compatibility(is_backward_compatible=False)
|
|
|
def graph_module(self):
|
|
|
return self._graph_module
|
|
|
|
|
|
@graph_module.setter
|
|
|
@compatibility(is_backward_compatible=False)
|
|
|
def graph_module(self, value):
|
|
|
raise RuntimeError("Unable to set ExportedProgram's graph_module attribute.")
|
|
|
|
|
|
@property
|
|
|
@compatibility(is_backward_compatible=False)
|
|
|
def graph(self):
|
|
|
return self.graph_module.graph
|
|
|
|
|
|
@graph.setter
|
|
|
@compatibility(is_backward_compatible=False)
|
|
|
def graph(self, value):
|
|
|
raise RuntimeError("Unable to set ExportedProgram's graph attribute.")
|
|
|
|
|
|
@property
|
|
|
@compatibility(is_backward_compatible=False)
|
|
|
def graph_signature(self):
|
|
|
return self._graph_signature
|
|
|
|
|
|
@graph_signature.setter
|
|
|
@compatibility(is_backward_compatible=False)
|
|
|
def graph_signature(self, value):
|
|
|
raise RuntimeError("Unable to set ExportedProgram's graph_signature attribute.")
|
|
|
|
|
|
@property
@compatibility(is_backward_compatible=False)
def state_dict(self):
    """Return the state dict stored on this program (read-only binding)."""
    return self._state_dict


@state_dict.setter
@compatibility(is_backward_compatible=False)
def state_dict(self, value):
    """Reject any assignment; always raises ``RuntimeError``."""
    raise RuntimeError("Unable to set ExportedProgram's state_dict attribute.")
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
def parameters(self) -> Iterator[torch.nn.Parameter]:
    """
    Yield each parameter of the original module, without its name.

    Values are taken, in order, from :meth:`named_parameters`.
    """
    for named_entry in self.named_parameters():
        yield named_entry[1]
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
def named_parameters(self) -> Iterator[tuple[str, torch.nn.Parameter]]:
    """
    Yield ``(name, parameter)`` pairs for the original module's parameters.

    Names come from ``self.graph_signature.parameters``; each value is
    looked up under that name in ``self.state_dict``.
    """
    weights = self.state_dict
    for name in self.graph_signature.parameters:
        yield name, weights[name]
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
def buffers(self) -> Iterator[torch.Tensor]:
    """
    Yield each buffer of the original module, without its name.

    Values are taken, in order, from :meth:`named_buffers`.
    """
    for named_entry in self.named_buffers():
        yield named_entry[1]
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
def named_buffers(self) -> Iterator[tuple[str, torch.Tensor]]:
    """
    Yield ``(name, buffer)`` pairs for the original module's buffers.

    Names come from ``self.graph_signature.buffers``. A buffer listed in
    ``self.graph_signature.non_persistent_buffers`` is read from
    ``self.constants``; every other buffer is read from ``self.state_dict``.
    """
    signature = self.graph_signature
    transient_names = set(signature.non_persistent_buffers)
    for name in signature.buffers:
        # Pick the container lazily so only the one actually needed is touched.
        store = self.constants if name in transient_names else self.state_dict
        yield name, store[name]
|
|
|
|
|
|
@property
@compatibility(is_backward_compatible=False)
def range_constraints(self):
    """Return the range-constraint mapping stored on this program."""
    return self._range_constraints


@range_constraints.setter
@compatibility(is_backward_compatible=False)
def range_constraints(self, value):
    """Reject any assignment; always raises ``RuntimeError``."""
    raise RuntimeError("Unable to set ExportedProgram's range_constraints attribute.")
|
|
|
|
|
|
@property
@compatibility(is_backward_compatible=False)
def module_call_graph(self):
    """Return the list of module call entries stored on this program."""
    return self._module_call_graph


@module_call_graph.setter
@compatibility(is_backward_compatible=False)
def module_call_graph(self, value):
    """Reject any assignment; always raises ``RuntimeError``."""
    raise RuntimeError("Unable to set ExportedProgram's module_call_graph attribute.")
|
|
|
|
|
|
@property
@compatibility(is_backward_compatible=False)
def example_inputs(self):
    """Return whatever example inputs are currently stored (may be ``None``)."""
    return self._example_inputs


@example_inputs.setter
@compatibility(is_backward_compatible=False)
def example_inputs(self, value):
    """
    Replace the stored example inputs after validating their shape.

    ``None`` is accepted as-is. Any other value must be a 2-tuple of
    ``(args_tuple, kwargs_dict)``, otherwise ``ValueError`` is raised. A
    well-formed value is then checked against ``self.call_spec.in_spec``
    with ``_check_inputs_match`` before being stored.
    """
    # Clearing the example inputs needs no validation.
    if value is None:
        self._example_inputs = value
        return

    well_formed = (
        isinstance(value, tuple)
        and len(value) == 2
        and isinstance(value[0], tuple)
        and isinstance(value[1], dict)
    )
    if not well_formed:
        raise ValueError(
            "Example inputs should be a tuple containing example arguments (as "
            "a tuple), and example kwargs (as a dictionary)."
        )

    from ._unlift import _check_inputs_match

    example_args, example_kwargs = value
    _check_inputs_match(example_args, example_kwargs, self.call_spec.in_spec)

    self._example_inputs = value
|
|
|
|
|
|
@property
@compatibility(is_backward_compatible=False)
def call_spec(self):
    """
    Return a ``CallSpec(in_spec, out_spec)`` named tuple for this program.

    Both fields are ``None`` when the module call graph is empty. Otherwise
    they are read from the signature of the first call-graph entry, whose
    ``fqn`` is asserted to be the empty string.
    """
    CallSpec = namedtuple("CallSpec", ["in_spec", "out_spec"])

    call_graph = self.module_call_graph
    if len(call_graph) == 0:
        return CallSpec(in_spec=None, out_spec=None)

    root_entry = call_graph[0]
    assert root_entry.fqn == ""
    root_signature = root_entry.signature
    return CallSpec(
        in_spec=root_signature.in_spec,
        out_spec=root_signature.out_spec,
    )


@call_spec.setter
@compatibility(is_backward_compatible=False)
def call_spec(self, value):
    """Reject any assignment; always raises ``RuntimeError``."""
    raise RuntimeError("Unable to set ExportedProgram's call_spec attribute.")
|
|
|
|
|
|
@property
@compatibility(is_backward_compatible=False)
def verifier(self) -> Any:
    """Return the first entry of the stored verifier list."""
    primary, *_ = self._verifiers[:1] or [self._verifiers[0]]
    return primary


@verifier.setter
@compatibility(is_backward_compatible=False)
def verifier(self, value):
    """Reject any assignment; always raises ``RuntimeError``."""
    raise RuntimeError("Unable to set ExportedProgram's verifier attribute.")
|
|
|
|
|
|
@property
@compatibility(is_backward_compatible=False)
def dialect(self) -> str:
    """Return the ``dialect`` attribute of the first stored verifier."""
    assert self._verifiers is not None
    primary_verifier = self._verifiers[0]
    return primary_verifier.dialect


@dialect.setter
@compatibility(is_backward_compatible=False)
def dialect(self, value):
    """Reject any assignment; always raises ``RuntimeError``."""
    raise RuntimeError("Unable to set ExportedProgram's dialect attribute.")
|
|
|
|
|
|
@property
@compatibility(is_backward_compatible=False)
def verifiers(self):
    """Return the full list of verifier classes stored on this program."""
    return self._verifiers


@verifiers.setter
@compatibility(is_backward_compatible=False)
def verifiers(self, value):
    """Reject any assignment; always raises ``RuntimeError``."""
    raise RuntimeError("Unable to set ExportedProgram's verifiers attribute.")
|
|
|
|
|
|
@property
@compatibility(is_backward_compatible=False)
def tensor_constants(self):
    """Return the stored constants mapping (same object as ``constants``)."""
    return self._constants


@tensor_constants.setter
@compatibility(is_backward_compatible=False)
def tensor_constants(self, value):
    """Reject any assignment; always raises ``RuntimeError``."""
    raise RuntimeError("Unable to set ExportedProgram's tensor_constants attribute.")
|
|
|
|
|
|
@property
@compatibility(is_backward_compatible=False)
def constants(self):
    """Return the constants mapping stored on this program."""
    return self._constants


@constants.setter
@compatibility(is_backward_compatible=False)
def constants(self, value):
    """Reject any assignment; always raises ``RuntimeError``."""
    raise RuntimeError("Unable to set ExportedProgram's constants attribute.")
|
|
|
|
|
|
def _get_flat_args_with_check(self, args, kwargs):
    """Flatten ``(args, kwargs)`` with pytree and check input constraints.

    Args:
        args: the original positional arguments passed to ``__call__``.
        kwargs: the original keyword arguments passed to ``__call__``.

    Returns:
        A ``(flat_args, received_spec)`` pair. ``flat_args`` is a tuple of
        the flattened leaves of ``(args, kwargs)``; ``received_spec`` is the
        pytree spec produced while flattening that pair.
    """
    expected_spec = self.call_spec.in_spec
    if expected_spec is not None:
        # Reorder the kwargs against the exported input spec before flattening.
        kwargs = reorder_kwargs(kwargs, expected_spec)

    leaves_with_path, received_spec = pytree.tree_flatten_with_path((args, kwargs))
    self._check_input_constraints(leaves_with_path)

    # Each entry is a (path, leaf) pair; callers only need the leaf.
    leaves = tuple(entry[1] for entry in leaves_with_path)
    return leaves, received_spec
|
|
|
|
|
|
def _graph_module_flat_inputs(self, args: Any, kwargs: Any) -> Any:
    """Turn the ``args``/``kwargs`` of ``__call__`` into graph-module inputs.

    ``self.graph_module`` takes entries of the state dict (and constants) as
    inputs, so those lifted values are placed in front of the flattened user
    arguments. The invariant for ``ep: ExportedProgram`` is::

        ep(args, kwargs) ==
            ep.postprocess(ep.graph_module(ep.graph_module_flat_inputs(args, kwargs)))

    Raises:
        ValueError: if the tree spec of the received inputs is not
            equivalent to the exported input tree spec.
    """
    expected_spec = self.call_spec.in_spec
    flat_args, received_spec = self._get_flat_args_with_check(args, kwargs)

    spec_mismatch = expected_spec is not None and not is_equivalent(
        received_spec, expected_spec, _fx_collection_equivalence_fn
    )
    if spec_mismatch:
        raise ValueError(
            "Trying to flatten user inputs with exported input tree spec: \n"
            f"{expected_spec}\n"
            "but actually got inputs with tree spec of: \n"
            f"{received_spec}"
        )

    lifted_inputs = []
    for spec in self.graph_signature.input_specs:
        kind = spec.kind
        if kind in (InputKind.PARAMETER, InputKind.BUFFER):
            # Non-persistent entries are stored in the constants table rather
            # than in the state dict.
            if spec.persistent is False:
                lifted_inputs.append(self.constants[spec.target])
            else:
                lifted_inputs.append(self.state_dict[spec.target])
        elif kind in (InputKind.CONSTANT_TENSOR, InputKind.CUSTOM_OBJ):
            lifted_inputs.append(self.constants[spec.target])
        # Every other kind (including USER_INPUT) contributes no lifted input.

    return tuple(lifted_inputs) + flat_args
|
|
|
|
|
|
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
|
|
raise RuntimeError(
|
|
|
"Unable to call ExportedProgram directly. "
|
|
|
"You should use `exported_program.module()` instead."
|
|
|
)
|
|
|
|
|
|
def _postprocess_graph_module_outputs(self, res, orig_args, orig_kwargs):
    """Write mutations back and unflatten the graph module's raw outputs.

    ``self.graph_module`` is functional, so mutated buffers and mutated user
    inputs come back as leading outputs and have to be written back after
    the graph module has executed.

    Args:
        res: flat outputs returned by the graph module.
        orig_args: positional arguments of the original call.
        orig_kwargs: keyword arguments of the original call.

    Returns:
        ``res`` unchanged when there is no output spec; otherwise the user
        outputs unflattened with ``self.call_spec.out_spec``.

    Raises:
        torch._export.error.InternalError: if the user outputs cannot be
            unflattened with the exported output tree spec.
    """
    import torch._export.error as error

    flat_args, _ = self._get_flat_args_with_check(orig_args, orig_kwargs)

    out_spec = self.call_spec.out_spec
    if out_spec is None:
        return res

    signature = self.graph_signature
    num_mutated = len(signature.buffers_to_mutate) + len(
        signature.user_inputs_to_mutate
    )
    mutated_values = res[:num_mutated]

    # Drop everything from the assertion dependency token onwards, if any.
    dep_token = signature.assertion_dep_token
    if dep_token is not None:
        dep_token_index = next(iter(dep_token.keys()))
        res = res[:dep_token_index]

    res = res[num_mutated:]
    try:
        res = pytree.tree_unflatten(res, out_spec)
    except Exception:
        _, received_spec = pytree.tree_flatten(res)
        raise error.InternalError(
            "Trying to flatten user outputs with exported output tree spec: \n"
            f"{out_spec}\n"
            "but actually got outputs with tree spec of: \n"
            f"{received_spec}"
        )
    finally:
        # The write-back sits in ``finally`` so it runs whether or not the
        # unflattening above succeeded.
        user_input_specs = [
            in_spec
            for in_spec in signature.input_specs
            if in_spec.kind == InputKind.USER_INPUT
        ]
        for position, value in enumerate(mutated_values):
            output_spec = signature.output_specs[position]
            if output_spec.kind == OutputKind.BUFFER_MUTATION:
                assert output_spec.target is not None
                self.state_dict[output_spec.target] = value
            elif output_spec.kind == OutputKind.USER_INPUT_MUTATION:
                assert output_spec.target is not None
                # Find the user input whose argument name matches the target.
                arg_index = next(
                    idx
                    for idx, in_spec in enumerate(user_input_specs)
                    if in_spec.arg.name == output_spec.target
                )
                flat_args[arg_index].copy_(value)
            else:
                raise AssertionError(f"Unexpected kind: {output_spec.kind}")
    return res
|
|
|
|
|
|
def __str__(self) -> str:
|
|
|
graph_module = self.graph_module.print_readable(
|
|
|
print_output=False, colored=False
|
|
|
).replace("\n", "\n ")
|
|
|
graph_signature = str(self.graph_signature).replace("\n", "\n ")
|
|
|
string = (
|
|
|
"ExportedProgram:\n"
|
|
|
f" {graph_module}\n"
|
|
|
f"Graph signature: {graph_signature}\n"
|
|
|
f"Range constraints: {self.range_constraints}\n"
|
|
|
)
|
|
|
return string
|
|
|
|
|
|
def module(self) -> torch.nn.Module:
    """
    Return a self-contained GraphModule with all parameters/buffers inlined.

    The module is produced by ``_unlift_exported_program_lifted_states``.
    Its ``train`` and ``eval`` methods are replaced with stubs that raise
    ``NotImplementedError``.
    """
    from ._unlift import _unlift_exported_program_lifted_states

    def _train(self, mode: bool = True):
        raise NotImplementedError("Calling train() is not supported yet.")

    def _eval(self, mode: bool = True):
        raise NotImplementedError("Calling eval() is not supported yet.")

    unlifted = _unlift_exported_program_lifted_states(self)

    # Bind the stubs to this instance only; the module's class is untouched.
    unlifted.train = types.MethodType(_train, unlifted)
    unlifted.eval = types.MethodType(_eval, unlifted)
    return unlifted
|
|
|
|
|
|
def _num_lifted_params_buffers(self):
|
|
|
return next(
|
|
|
(
|
|
|
i
|
|
|
for i, s in enumerate(self._graph_signature.input_specs)
|
|
|
if s.kind == InputKind.USER_INPUT
|
|
|
),
|
|
|
len(self._graph_signature.input_specs),
|
|
|
)
|
|
|
|
|
|
@_disable_prexisiting_fake_mode
def run_decompositions(
    self,
    decomp_table: Optional[dict[torch._ops.OperatorBase, Callable]] = None,
    decompose_custom_triton_ops: bool = False,
) -> "ExportedProgram":
    """
    Run a set of decompositions on the exported program and return a new
    exported program. By default the Core ATen decompositions are run, to
    get operators in the
    `Core ATen Operator Set <https://pytorch.org/docs/stable/torch.compiler_ir.html>`_.

    For now, we do not decompose joint graphs.

    Args:
        decomp_table:
            An optional argument that specifies decomp behaviour for Aten ops
            (1) If None, we decompose to core aten decompositions
            (2) If empty, we don't decompose any operator
        decompose_custom_triton_ops:
            Forwarded unchanged to ``_decompose_exported_program``.

    Some examples:

    If you don't want to decompose anything

    .. code-block:: python

        ep = torch.export.export(model, ...)
        ep = ep.run_decompositions(decomp_table={})

    If you want to get a core aten operator set except for certain operator, you can do following:

    .. code-block:: python

        ep = torch.export.export(model, ...)
        decomp_table = torch.export.default_decompositions()
        decomp_table[your_op] = your_custom_decomp
        ep = ep.run_decompositions(decomp_table=decomp_table)
    """
    if decomp_table is None:
        resolved_table = default_decompositions()
    else:
        # Work on a plain-dict copy of the caller's table.
        resolved_table = dict(decomp_table)

    if isinstance(resolved_table, CustomDecompTable):
        resolved_table = resolved_table.materialize()

    split_tables = _split_decomp_table_to_cia_and_python_decomp(resolved_table)
    cia_to_decomp, python_decomp_table = split_tables

    return _decompose_exported_program(
        self,
        cia_to_decomp=cia_to_decomp,
        python_decomp_table=python_decomp_table,
        joint_loss_index=None,
        decompose_custom_triton_ops=decompose_custom_triton_ops,
    )
|
|
|
|
|
|
def _transform_do_not_use(self, *passes: PassType) -> "ExportedProgram":
    """Run ``passes`` over the graph module and wrap the result in a new program.

    The passes are executed by a ``PassManager`` inside the
    ``_ignore_backend_decomps`` context. If the pass manager hands back the
    very same graph module and reports it as unmodified, ``self`` is
    returned. Otherwise a new ``ExportedProgram`` is built around the
    transformed graph module, with the user-facing argument names of the
    graph signature refreshed from the new graph and the range constraints
    recomputed.

    A ``None`` pass result is tolerated: the current graph module is then
    treated as the (possibly modified) result and a new program is built
    from it.

    Args:
        *passes: graph-module passes to run, in order.

    Returns:
        ``self`` when nothing changed, otherwise a new ``ExportedProgram``.
    """
    pm = PassManager(list(passes))

    from torch.export._trace import _ignore_backend_decomps

    with _ignore_backend_decomps():
        res = pm(self.graph_module)
    transformed_gm = res.graph_module if res is not None else self.graph_module
    assert transformed_gm is not None

    # ``res`` may be None (see the fallback above), so it must be checked
    # before ``res.modified`` is read.
    if (
        res is not None
        and transformed_gm is self.graph_module
        and not res.modified
    ):
        return self

    def _get_updated_graph_signature(
        old_signature: ExportGraphSignature,
        new_gm: torch.fx.GraphModule,
    ) -> ExportGraphSignature:
        """
        Update the graph signature's user_input/user_outputs.
        """
        new_input_specs = []
        for i, node in enumerate(new_gm.graph.nodes):
            # Only the leading run of placeholder nodes is considered.
            if node.op != "placeholder":
                break

            assert i < len(old_signature.input_specs), (
                "Number of inputs changed after transformation"
            )
            old_input_spec = old_signature.input_specs[i]
            # Constant and custom-object arguments are carried over as-is;
            # any other argument is rebuilt from the new node name.
            if isinstance(old_input_spec.arg, (ConstantArgument, CustomObjArgument)):
                arg = old_input_spec.arg
            else:
                arg = type(old_input_spec.arg)(node.name)
            new_input_specs.append(
                InputSpec(
                    old_input_spec.kind,
                    arg,
                    old_input_spec.target,
                    old_input_spec.persistent,
                )
            )

        output_node = list(new_gm.graph.nodes)[-1]
        assert output_node.op == "output"

        new_output_specs = []
        for i, node in enumerate(output_node.args[0]):
            assert i < len(old_signature.output_specs), (
                "Number of outputs changed after transformation"
            )
            old_output_spec = old_signature.output_specs[i]
            if isinstance(old_output_spec.arg, (ConstantArgument, CustomObjArgument)):
                arg = old_output_spec.arg
            else:
                arg = type(old_output_spec.arg)(node.name)
            new_output_specs.append(
                OutputSpec(old_output_spec.kind, arg, old_output_spec.target)
            )

        return ExportGraphSignature(
            input_specs=new_input_specs, output_specs=new_output_specs
        )

    transformed_ep = ExportedProgram(
        root=transformed_gm,
        graph=transformed_gm.graph,
        graph_signature=_get_updated_graph_signature(
            self.graph_signature, transformed_gm
        ),
        state_dict=self.state_dict,
        range_constraints=_get_updated_range_constraints(
            transformed_gm,
            self.range_constraints,
        ),
        module_call_graph=copy.deepcopy(self._module_call_graph),
        example_inputs=self.example_inputs,
        constants=self.constants,
        verifiers=self.verifiers,
    )
    transformed_ep.graph_module.meta.update(self.graph_module.meta)
    # ``transformed_gm`` is ``res.graph_module`` whenever ``res`` is not None,
    # so this matches the original update without dereferencing a None result.
    transformed_ep.graph_module.meta.update(transformed_gm.meta)
    return transformed_ep
|
|
|
|
|
|
def _check_input_constraints(self, flat_args_with_path):
    """Check the flattened user inputs against the stored range constraints.

    The graph's placeholder nodes are paired positionally with the
    signature's input specs; only placeholders whose spec kind is
    ``USER_INPUT`` are passed to ``_check_input_constraints_for_graph``.

    Args:
        flat_args_with_path: flattened inputs, as produced by
            ``pytree.tree_flatten_with_path``.
    """
    from torch._export.utils import _check_input_constraints_for_graph

    placeholder_nodes = [
        node for node in self.graph.nodes if node.op == "placeholder"
    ]
    user_placeholders = []
    for node, spec in zip(placeholder_nodes, self.graph_signature.input_specs):
        if spec.kind == InputKind.USER_INPUT:
            user_placeholders.append(node)

    _check_input_constraints_for_graph(
        user_placeholders, flat_args_with_path, self.range_constraints
    )
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
def validate(self):
    """Validate this program by delegating to ``self._validate()``."""
    self._validate()
|
|
|
|
|
|
|
|
|
@final
|
|
|
def _validate(self):
|
|
|
assert len(self.verifiers) > 0, (
|
|
|
"ExportedProgram must have at least one verifier."
|
|
|
)
|
|
|
for v in self.verifiers:
|
|
|
v().check(self)
|
|
|
|
|
|
|
|
|
def _update(
    self,
    graph_module,
    graph_signature,
    *,
    state_dict=None,
    constants=None,
    verifiers=None,
) -> "ExportedProgram":
    """Build a new ``ExportedProgram`` around a new graph module and signature.

    ``state_dict``, ``constants`` and ``verifiers`` fall back to this
    program's own values when left as ``None``. The range constraints and
    the module call graph are deep-copied; the example inputs are passed
    through as they are.
    """
    if state_dict is None:
        state_dict = self.state_dict
    if constants is None:
        constants = self.constants
    if verifiers is None:
        verifiers = self.verifiers

    return ExportedProgram(
        root=graph_module,
        graph=graph_module.graph,
        graph_signature=graph_signature,
        state_dict=state_dict,
        range_constraints=copy.deepcopy(self.range_constraints),
        module_call_graph=copy.deepcopy(self._module_call_graph),
        example_inputs=self.example_inputs,
        constants=constants,
        verifiers=verifiers,
    )
|
|
|
|
|
|
|
|
|
def _get_shape_env(gm):
|
|
|
vals = [
|
|
|
node.meta["val"]
|
|
|
for node in gm.graph.nodes
|
|
|
if node.meta.get("val", None) is not None
|
|
|
]
|
|
|
from torch._guards import detect_fake_mode
|
|
|
|
|
|
fake_mode = detect_fake_mode(vals)
|
|
|
if fake_mode is not None:
|
|
|
return fake_mode.shape_env
|
|
|
for v in vals:
|
|
|
if isinstance(v, torch.SymInt):
|
|
|
return v.node.shape_env
|
|
|
|
|
|
|
|
|
def _get_updated_range_constraints(
    gm: torch.fx.GraphModule,
    old_range_constraints: "Optional[dict[sympy.Symbol, Any]]" = None,
) -> "dict[sympy.Symbol, Any]":
    """Recompute the range constraints for a transformed graph module.

    Args:
        gm: graph module whose node metadata is used to locate the shape
            environment.
        old_range_constraints: constraints of the program before the
            transformation; must not be ``None`` despite the default.

    Returns:
        An empty dict when no shape environment is found. Otherwise the old
        constraints without symbols present in ``shape_env.replacements``,
        extended with every ``shape_env.var_to_range`` entry whose symbol is
        neither replaced nor already constrained.
    """
    assert old_range_constraints is not None

    shape_env = _get_shape_env(gm)
    if shape_env is None:
        return {}

    replaced_symbols = shape_env.replacements
    updated = {
        symbol: value_range
        for symbol, value_range in copy.copy(old_range_constraints).items()
        if symbol not in replaced_symbols
    }

    for symbol, value_range in shape_env.var_to_range.items():
        if symbol in replaced_symbols:
            continue
        # Constraints carried over from the old program take precedence.
        if symbol not in updated:
            updated[symbol] = value_range
    return updated
|
|
|
|
|
|
|
|
|
def _create_graph_module_for_export(root, graph):
|
|
|
try:
|
|
|
gm = torch.fx.GraphModule(root, graph)
|
|
|
except SyntaxError:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
warnings.warn(
|
|
|
"Unable to execute the generated python source code from "
|
|
|
"the graph. The graph module will no longer be directly callable, "
|
|
|
"but you can still run the ExportedProgram, and if needed, you can "
|
|
|
"run the graph module eagerly using torch.fx.Interpreter."
|
|
|
)
|
|
|
gm = torch.fx.GraphModule(root, torch.fx.Graph())
|
|
|
gm._graph = graph
|
|
|
|
|
|
return gm
|
|
|
|