|
|
|
|
|
import ast |
|
|
import copy |
|
|
import dataclasses |
|
|
import functools |
|
|
import inspect |
|
|
import json |
|
|
import math |
|
|
import operator |
|
|
import re |
|
|
from collections import defaultdict |
|
|
from collections.abc import Iterable |
|
|
from contextlib import contextmanager |
|
|
from inspect import ismethod, Parameter |
|
|
from typing import Any, Callable, Optional, TYPE_CHECKING, Union |
|
|
|
|
|
import torch |
|
|
from torch._guards import detect_fake_mode |
|
|
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode |
|
|
from torch._subclasses.functional_tensor import FunctionalTensor |
|
|
from torch.fx._utils import first_call_function_nn_module_stack |
|
|
from torch.fx.experimental.proxy_tensor import PreDispatchTorchFunctionMode |
|
|
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts |
|
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from torch._export.passes.lift_constants_pass import ConstantAttrMap |
|
|
from torch._ops import OperatorBase |
|
|
from torch.export import ExportedProgram |
|
|
from torch.export.graph_signature import ExportGraphSignature |
|
|
|
|
|
from torch.export.graph_signature import CustomObjArgument, InputKind, OutputKind |
|
|
from torch.fx._pytree import ( |
|
|
_deregister_pytree_flatten_spec, |
|
|
register_pytree_flatten_spec, |
|
|
) |
|
|
from torch.utils._pytree import ( |
|
|
_deregister_pytree_node, |
|
|
_register_pytree_node, |
|
|
Context, |
|
|
FlattenFunc, |
|
|
FromDumpableContextFn, |
|
|
GetAttrKey, |
|
|
KeyPath, |
|
|
keystr, |
|
|
MappingKey, |
|
|
SequenceKey, |
|
|
ToDumpableContextFn, |
|
|
tree_flatten_with_path, |
|
|
UnflattenFunc, |
|
|
) |
|
|
|
|
|
|
|
|
placeholder_prefixes = { |
|
|
InputKind.USER_INPUT: "", |
|
|
InputKind.PARAMETER: "p_", |
|
|
InputKind.BUFFER: "b_", |
|
|
InputKind.CONSTANT_TENSOR: "c_", |
|
|
InputKind.CUSTOM_OBJ: "obj_", |
|
|
InputKind.TOKEN: "token", |
|
|
} |
|
|
|
|
|
_DISABLE_ATEN_TO_ASSERTION_PASS = False |
|
|
|
|
|
|
|
|
def _collect_and_set_constant_attrs( |
|
|
graph_signature, constants, mod |
|
|
) -> "ConstantAttrMap": |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from torch._export.passes.lift_constants_pass import ConstantAttrMap |
|
|
|
|
|
constant_attrs = ConstantAttrMap() |
|
|
non_persistent_buffers = { |
|
|
spec.target |
|
|
for spec in graph_signature.input_specs |
|
|
if spec.kind == InputKind.BUFFER and not spec.persistent |
|
|
} |
|
|
for name, value in constants.items(): |
|
|
if name in non_persistent_buffers: |
|
|
continue |
|
|
|
|
|
_mod = mod |
|
|
*atoms, attr = name.split(".") |
|
|
for atom in atoms: |
|
|
_mod = getattr(_mod, atom) |
|
|
|
|
|
_mod._buffers.pop(attr, None) |
|
|
setattr(_mod, attr, value) |
|
|
constant_attrs.add(value, name) |
|
|
return constant_attrs |
|
|
|
|
|
|
|
|
def _register_constants_as_buffers( |
|
|
mod: torch.fx.GraphModule, state_dict, non_persistent_buffers |
|
|
): |
|
|
|
|
|
from torch.export.unflatten import _assign_attr, _AttrKind |
|
|
|
|
|
temp_registered_constants = set() |
|
|
|
|
|
for node in mod.graph.nodes: |
|
|
if node.op == "get_attr": |
|
|
target = torch.fx.graph_module._get_attr(mod, node.target) |
|
|
if isinstance(target, torch.Tensor): |
|
|
|
|
|
|
|
|
if (node.target not in state_dict) and ( |
|
|
node.target not in non_persistent_buffers |
|
|
): |
|
|
torch.fx.graph_module._del_attr(mod, node.target) |
|
|
_assign_attr(target, mod, node.target, _AttrKind.BUFFER, False) |
|
|
temp_registered_constants.add(node.target) |
|
|
|
|
|
mod.recompile() |
|
|
|
|
|
return temp_registered_constants |
|
|
|
|
|
|
|
|
def _override_graph_signature_for_temp_registered_constants( |
|
|
sig: "ExportGraphSignature", temp_registered_constants |
|
|
): |
|
|
for spec in sig.input_specs: |
|
|
if spec.target in temp_registered_constants: |
|
|
spec.kind = InputKind.CONSTANT_TENSOR |
|
|
spec.persistent = None |
|
|
|
|
|
for spec in sig.output_specs: |
|
|
if ( |
|
|
spec.kind == OutputKind.BUFFER_MUTATION |
|
|
and spec.target in temp_registered_constants |
|
|
): |
|
|
raise RuntimeError( |
|
|
f"Constant {spec.target} is mutated in the forward method. Pls register it as buffer" |
|
|
) |
|
|
|
|
|
return sig |
|
|
|
|
|
|
|
|
def _overwrite_signature_for_non_persistent_buffers( |
|
|
old_sig: "ExportGraphSignature", new_sig: "ExportGraphSignature" |
|
|
): |
|
|
|
|
|
non_persistent_buffers = { |
|
|
spec.target |
|
|
for spec in old_sig.input_specs |
|
|
if spec.kind == InputKind.BUFFER and not spec.persistent |
|
|
} |
|
|
|
|
|
for spec in new_sig.input_specs: |
|
|
if spec.kind == InputKind.BUFFER and spec.target in non_persistent_buffers: |
|
|
spec.persistent = False |
|
|
return new_sig |
|
|
|
|
|
|
|
|
def _collect_param_buffer_metadata(mod: torch.fx.GraphModule) -> dict[str, Any]: |
|
|
""" |
|
|
    Param/buffer metadata needs to be saved before lowering to aten IR, because aten IR
    lifts params/buffers into placeholders; as a result, automatic metadata preservation
    doesn't work. This is intended to be called on the strict-mode traced graph right
    before lowering to aten IR or the run_decomposition pass.
|
|
""" |
|
|
params_buffers_to_node_meta = {} |
|
|
|
|
|
def _getattr(model: torch.fx.GraphModule, attr_name: str): |
|
|
*prefix, field = attr_name.split(".") |
|
|
t = model |
|
|
for item in prefix: |
|
|
t = getattr(t, item, None) |
|
|
assert t is not None |
|
|
|
|
|
return getattr(t, field) |
|
|
|
|
|
for node in mod.graph.nodes: |
|
|
target = node.target |
|
|
meta = node.meta |
|
|
if node.op == "call_module": |
|
|
submodule = _getattr(mod, target) |
|
|
if isinstance(submodule, torch.nn.Module): |
|
|
for name, _ in submodule.named_parameters( |
|
|
recurse=True, remove_duplicate=False |
|
|
): |
|
|
params_buffers_to_node_meta[target + "." + name] = meta |
|
|
|
|
|
for name, _ in submodule.named_buffers( |
|
|
recurse=True, remove_duplicate=False |
|
|
): |
|
|
params_buffers_to_node_meta[target + "." + name] = meta |
|
|
|
|
|
if node.op == "get_attr": |
|
|
submodule = _getattr(mod, target) |
|
|
if not isinstance(submodule, torch.fx.GraphModule): |
|
|
params_buffers_to_node_meta[target] = meta |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if node.op == "call_function" and not isinstance( |
|
|
node.target, torch._ops.HigherOrderOperator |
|
|
): |
|
|
for arg in node._input_nodes: |
|
|
if arg.op == "get_attr": |
|
|
for entry in torch.fx.proxy._COPY_META_FIELDS: |
|
|
|
|
|
if entry == "custom": |
|
|
continue |
|
|
if entry in meta: |
|
|
params_buffers_to_node_meta[arg.target][entry] = meta[entry] |
|
|
|
|
|
return params_buffers_to_node_meta |
|
|
|
|
|
|
|
|
def _maybe_find_pre_dispatch_tf_mode_for_export(): |
|
|
if not torch._C._is_torch_function_mode_enabled(): |
|
|
return None |
|
|
|
|
|
torch_function_mode_stack = torch.overrides._get_current_function_mode_stack() |
|
|
|
|
|
pre_dispatch_tf_modes = [ |
|
|
mode |
|
|
for mode in torch_function_mode_stack |
|
|
if isinstance(mode, PreDispatchTorchFunctionMode) |
|
|
] |
|
|
|
|
|
assert len(pre_dispatch_tf_modes) <= 1, ( |
|
|
f"Expected only one PreDispatchTorchFunctionMode, found {len(pre_dispatch_tf_modes)}" |
|
|
) |
|
|
|
|
|
if len(pre_dispatch_tf_modes) == 0: |
|
|
return None |
|
|
|
|
|
mode = pre_dispatch_tf_modes[0] |
|
|
return mode |
|
|
|
|
|
|
|
|
def _populate_param_buffer_metadata_to_new_gm( |
|
|
params_buffers_to_node_meta: dict[str, Any], |
|
|
gm: torch.fx.GraphModule, |
|
|
new_sig: "ExportGraphSignature", |
|
|
) -> None: |
|
|
""" |
|
|
    Given the param/buffer metadata collected earlier, put it back into the
    newly traced graph module.
|
|
""" |
|
|
|
|
|
for metadata in params_buffers_to_node_meta.values(): |
|
|
metadata.pop("nn_module_stack", None) |
|
|
metadata.pop("stack_trace", None) |
|
|
|
|
|
for node in gm.graph.nodes: |
|
|
if node.op == "placeholder": |
|
|
if node.target in new_sig.inputs_to_parameters: |
|
|
param_name = new_sig.inputs_to_parameters[node.target] |
|
|
if param_name in params_buffers_to_node_meta: |
|
|
for k, v in params_buffers_to_node_meta[param_name].items(): |
|
|
node.meta[k] = v |
|
|
if node.target in new_sig.inputs_to_buffers: |
|
|
buffer_name = new_sig.inputs_to_buffers[node.target] |
|
|
if buffer_name in params_buffers_to_node_meta: |
|
|
for k, v in params_buffers_to_node_meta[buffer_name].items(): |
|
|
node.meta[k] = v |
|
|
|
|
|
|
|
|
def _get_shape_env_from_gm(gm: torch.fx.GraphModule): |
|
|
vals = [ |
|
|
node.meta["val"] |
|
|
for node in gm.graph.nodes |
|
|
if node.meta.get("val", None) is not None |
|
|
] |
|
|
|
|
|
fake_mode = _detect_fake_mode_from_gm(gm) |
|
|
if fake_mode is not None: |
|
|
return fake_mode.shape_env |
|
|
for v in vals: |
|
|
if isinstance(v, torch.SymInt): |
|
|
return v.node.shape_env |
|
|
|
|
|
|
|
|
def _rename_without_collisions( |
|
|
name_map: dict[str, str], |
|
|
find_available: dict[str, int], |
|
|
used_names: set[str], |
|
|
orig_name: str, |
|
|
name: str, |
|
|
is_placeholder: bool = False, |
|
|
): |
|
|
""" |
|
|
Renames nodes to avoid name collisions, with suffixing. |
|
|
name_map: map from original name to new name |
|
|
find_available: map prefix to available suffix |
|
|
used_names: cache of used names |
|
|
orig_name: mapping key |
|
|
name: candidate name (potentially suffixed, e.g. mul_2) |
|
|
is_placeholder: if the node is a placeholder, avoid detecting suffix |
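
    Example (hypothetical values, shown only to illustrate the suffixing behavior)::

        name_map: dict[str, str] = {}
        find_available: dict[str, int] = defaultdict(int)
        used_names: set[str] = set()
        _rename_without_collisions(name_map, find_available, used_names, "mul", "mul")    # -> "mul"
        _rename_without_collisions(name_map, find_available, used_names, "mul_3", "mul")  # -> "mul_1"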
|
|
""" |
|
|
match = re.match(r"(.*)_(\d+)", name) |
|
|
key = name |
|
|
|
|
|
if match and not is_placeholder: |
|
|
prefix, n = match.group(1), match.group(2) |
|
|
key = prefix |
|
|
|
|
|
new_name = name |
|
|
if new_name in used_names: |
|
|
new_name = f"{key}_{find_available[key] + 1}" |
|
|
|
|
|
match = re.match(r"(.*)_(\d+)", new_name) |
|
|
if match: |
|
|
prefix, n = match.group(1), match.group(2) |
|
|
if int(n) > find_available[prefix]: |
|
|
find_available[prefix] = int(n) |
|
|
|
|
|
name_map[orig_name] = new_name |
|
|
used_names.add(new_name) |
|
|
|
|
|
return name_map[orig_name] |
|
|
|
|
|
|
|
|
def get_keystr(key_path: KeyPath) -> str: |
|
|
"""For a given index into the flat_args, return a human readable string |
|
|
describing how to access it, e.g. "*args["foo"][0].bar" |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
args_kwargs_key_path = key_path[0] |
|
|
assert isinstance(args_kwargs_key_path, SequenceKey) |
|
|
if args_kwargs_key_path.idx == 0: |
|
|
return f"*args{keystr(key_path[1:])}" |
|
|
else: |
|
|
kwarg_key = key_path[1] |
|
|
assert isinstance(kwarg_key, (GetAttrKey, MappingKey)) |
|
|
name = str(kwarg_key)[1:-1] |
|
|
return f"{name}{keystr(key_path[2:])}" |
|
|
|
|
|
|
|
|
def _check_symint( |
|
|
symint: Union[int, torch.SymInt], |
|
|
arg: int, |
|
|
range_constraints, |
|
|
unification_map, |
|
|
keypath: KeyPath, |
|
|
i: Optional[int] = None, |
|
|
) -> None: |
|
|
from torch.export.dynamic_shapes import _IntWrapper |
|
|
|
|
|
if ( |
|
|
isinstance(arg, torch.SymInt) |
|
|
and not arg.node.expr.is_number |
|
|
or isinstance(arg, _IntWrapper) |
|
|
): |
|
|
|
|
|
|
|
|
|
|
|
return |
|
|
|
|
|
import sympy |
|
|
|
|
|
from torch._export.passes.add_runtime_assertions_for_constraints_pass import ( |
|
|
_convert_range_to_int, |
|
|
) |
|
|
from torch.utils._sympy.solve import try_solve |
|
|
|
|
|
if isinstance(symint, torch.SymInt) and len(symint.node.expr.free_symbols) == 1: |
|
|
symbol = next(iter(symint.node.expr.free_symbols)) |
|
|
if symbol in unification_map: |
|
|
existing_dim = symint.node.expr.subs(unification_map) |
|
|
if arg != existing_dim: |
|
|
path = get_keystr(keypath) |
|
|
if i is not None: |
|
|
path += f".shape[{i}]" |
|
|
raise RuntimeError( |
|
|
f"Expected input at {path} to be equal to {existing_dim}, but got {arg}", |
|
|
) |
|
|
else: |
|
|
if isinstance(symint.node.expr, sympy.Symbol): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
unification_map[symbol] = int(arg) |
|
|
else: |
|
|
solution = try_solve(sympy.Eq(symint.node.expr, arg), symbol) |
|
|
if solution is None: |
|
|
path = get_keystr(keypath) |
|
|
if i is not None: |
|
|
path += f".shape[{i}]" |
|
|
raise RuntimeError( |
|
|
f"Expected input {path} = {arg} to be " |
|
|
f"of the form {symint.node.expr}, where {symbol} is an integer" |
|
|
) |
|
|
else: |
|
|
unification_map[symbol] = int(solution[1]) |
|
|
|
|
|
if symint.node.expr in range_constraints: |
|
|
min_val, max_val = _convert_range_to_int( |
|
|
range_constraints[symint.node.expr] |
|
|
) |
|
|
|
|
|
if min_val > 2: |
|
|
if arg < min_val: |
|
|
path = get_keystr(keypath) |
|
|
if i is not None: |
|
|
path += f".shape[{i}]" |
|
|
raise RuntimeError( |
|
|
f"Expected input at {path} to be >= {min_val}, but got {arg}", |
|
|
) |
|
|
if max_val < math.inf: |
|
|
if arg > max_val: |
|
|
path = get_keystr(keypath) |
|
|
if i is not None: |
|
|
path += f".shape[{i}]" |
|
|
raise RuntimeError( |
|
|
f"Expected input at {path} to be <= {max_val}, but got {arg}", |
|
|
) |
|
|
elif isinstance(symint, torch.SymInt) and not symint.node.expr.is_number: |
|
|
|
|
|
|
|
|
pass |
|
|
elif arg != int(symint): |
|
|
path = get_keystr(keypath) |
|
|
if i is not None: |
|
|
path += f".shape[{i}]" |
|
|
raise RuntimeError( |
|
|
f"Expected input at {path} to be equal to {symint}, but got {arg}. " |
|
|
"If you meant for this dimension to be dynamic, please re-export and specify dynamic_shapes " |
|
|
"(e.g. with Dim.DYNAMIC)" |
|
|
) |
|
|
|
|
|
|
|
|
def _check_input_constraints_for_graph( |
|
|
input_placeholders: list[torch.fx.Node], flat_args_with_path, range_constraints |
|
|
) -> None: |
|
|
import sympy |
|
|
|
|
|
if len(flat_args_with_path) != len(input_placeholders): |
|
|
raise RuntimeError( |
|
|
"Unexpected number of inputs " |
|
|
f"(expected {len(input_placeholders)}, got {len(flat_args_with_path)})" |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
unification_map: dict[sympy.Symbol, Any] = {} |
|
|
for (key_path, arg), node in zip(flat_args_with_path, input_placeholders): |
|
|
node_val = node.meta.get("val") |
|
|
if isinstance(node_val, FakeTensor): |
|
|
if not isinstance(arg, torch.Tensor): |
|
|
raise RuntimeError( |
|
|
f"Expected input at {get_keystr(key_path)} to be a tensor, but got {type(arg)}", |
|
|
) |
|
|
|
|
|
if len(node_val.shape) != len(arg.shape): |
|
|
raise RuntimeError( |
|
|
f"Unexpected number of dimensions in input at {get_keystr(key_path)}.shape " |
|
|
f"(expected {node_val.shape}, got {arg.shape})" |
|
|
) |
|
|
|
|
|
for j, (arg_dim, node_dim) in enumerate(zip(arg.shape, node_val.shape)): |
|
|
_check_symint( |
|
|
node_dim, arg_dim, range_constraints, unification_map, key_path, j |
|
|
) |
|
|
|
|
|
elif isinstance(node_val, (int, float, str)): |
|
|
if type(arg) != type(node_val) or arg != node_val: |
|
|
raise RuntimeError( |
|
|
f"Expected input at {get_keystr(key_path)} to be equal to {node_val}, but got {arg}", |
|
|
) |
|
|
elif isinstance(node_val, torch.SymInt): |
|
|
_check_symint( |
|
|
node_val, arg, range_constraints, unification_map, key_path, None |
|
|
) |
|
|
|
|
|
|
|
|
def register_dataclass_as_pytree_node( |
|
|
cls: type[Any], |
|
|
flatten_fn: Optional[FlattenFunc] = None, |
|
|
unflatten_fn: Optional[UnflattenFunc] = None, |
|
|
*, |
|
|
serialized_type_name: Optional[str] = None, |
|
|
to_dumpable_context: Optional[ToDumpableContextFn] = None, |
|
|
from_dumpable_context: Optional[FromDumpableContextFn] = None, |
|
|
return_none_fields: bool = False, |
|
|
) -> None: |
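    """
    Registers a dataclass so that it can be flattened and unflattened as a pytree,
    allowing instances of it to be used as inputs/outputs of :func:`torch.export.export`.

    Example (a minimal sketch; ``InputDataClass`` is an illustrative assumption)::

        import dataclasses

        import torch


        @dataclasses.dataclass
        class InputDataClass:
            x: torch.Tensor
            y: torch.Tensor


        torch._export.utils.register_dataclass_as_pytree_node(
            InputDataClass, serialized_type_name="test.InputDataClass"
        )
    """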
|
|
assert dataclasses.is_dataclass(cls), ( |
|
|
f"Only dataclasses can be registered with this function: {cls}" |
|
|
) |
|
|
|
|
|
def default_flatten_fn(obj: Any) -> tuple[list[Any], Context]: |
|
|
flattened = [] |
|
|
flat_names = [] |
|
|
none_names = [] |
|
|
for f in dataclasses.fields(obj): |
|
|
name, val = f.name, getattr(obj, f.name) |
|
|
if val is not None or return_none_fields: |
|
|
flattened.append(val) |
|
|
flat_names.append(name) |
|
|
else: |
|
|
none_names.append(name) |
|
|
return flattened, [flat_names, none_names] |
|
|
|
|
|
def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any: |
|
|
flat_names, none_names = context |
|
|
return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names)) |
|
|
|
|
|
def default_flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]: |
|
|
flattened, (flat_names, _none_names) = flatten_fn(obj) |
|
|
return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], flat_names |
|
|
|
|
|
flatten_fn = flatten_fn if flatten_fn is not None else default_flatten_fn |
|
|
unflatten_fn = unflatten_fn if unflatten_fn is not None else default_unflatten_fn |
|
|
|
|
|
if (to_dumpable_context is None) ^ (from_dumpable_context is None): |
|
|
raise ValueError( |
|
|
f"Both to_dumpable_context and from_dumpable_context for {cls} must " |
|
|
"be None or registered." |
|
|
) |
|
|
|
|
|
_register_pytree_node( |
|
|
cls, |
|
|
flatten_fn, |
|
|
unflatten_fn, |
|
|
serialized_type_name=serialized_type_name, |
|
|
flatten_with_keys_fn=default_flatten_fn_with_keys, |
|
|
to_dumpable_context=to_dumpable_context, |
|
|
from_dumpable_context=from_dumpable_context, |
|
|
) |
|
|
|
|
|
|
|
|
def is_param(program: "ExportedProgram", node: torch.fx.Node) -> bool: |
|
|
""" |
|
|
Checks if the given node is a parameter within the exported program |
|
|
""" |
|
|
|
|
|
return node.name in program.graph_signature.inputs_to_parameters |
|
|
|
|
|
|
|
|
def get_param( |
|
|
program: "ExportedProgram", |
|
|
node: torch.fx.Node, |
|
|
) -> Optional[torch.nn.Parameter]: |
|
|
""" |
|
|
Returns the parameter associated with the given node in the exported program. |
|
|
Returns None if the node is not a parameter within the exported program |
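
    Example (illustrative; assumes ``ep`` is an already-exported :class:`ExportedProgram`)::

        for node in ep.graph.nodes:
            if node.op == "placeholder" and is_param(ep, node):
                weight = get_param(ep, node)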
|
|
""" |
|
|
|
|
|
if is_param(program, node): |
|
|
parameter_name = program.graph_signature.inputs_to_parameters[node.name] |
|
|
return program.state_dict[parameter_name] |
|
|
|
|
|
return None |
|
|
|
|
|
|
|
|
def is_buffer(program: "ExportedProgram", node: torch.fx.Node) -> bool: |
|
|
""" |
|
|
Checks if the given node is a buffer within the exported program |
|
|
""" |
|
|
|
|
|
return node.name in program.graph_signature.inputs_to_buffers |
|
|
|
|
|
|
|
|
def get_buffer( |
|
|
program: "ExportedProgram", |
|
|
node: torch.fx.Node, |
|
|
) -> Optional[torch.Tensor]: |
|
|
""" |
|
|
Returns the buffer associated with the given node in the exported program. |
|
|
Returns None if the node is not a buffer within the exported program |
|
|
""" |
|
|
|
|
|
if is_buffer(program, node): |
|
|
buffer_name = program.graph_signature.inputs_to_buffers[node.name] |
|
|
if buffer_name in program.graph_signature.non_persistent_buffers: |
|
|
return program.constants[buffer_name] |
|
|
else: |
|
|
return program.state_dict[buffer_name] |
|
|
|
|
|
return None |
|
|
|
|
|
|
|
|
def is_lifted_tensor_constant( |
|
|
program: "ExportedProgram", |
|
|
node: torch.fx.Node, |
|
|
) -> bool: |
|
|
""" |
|
|
Checks if the given node is a lifted tensor constant within the exported program |
|
|
""" |
|
|
|
|
|
return node.name in program.graph_signature.inputs_to_lifted_tensor_constants |
|
|
|
|
|
|
|
|
def get_lifted_tensor_constant( |
|
|
program: "ExportedProgram", |
|
|
node: torch.fx.Node, |
|
|
) -> Optional[torch.Tensor]: |
|
|
""" |
|
|
Returns the lifted tensor constant associated with the given node in the exported program. |
|
|
Returns None if the node is not a lifted tensor constant within the exported program |
|
|
""" |
|
|
|
|
|
if is_lifted_tensor_constant(program, node): |
|
|
lifted_tensor_name = program.graph_signature.inputs_to_lifted_tensor_constants[ |
|
|
node.name |
|
|
] |
|
|
return program.constants[lifted_tensor_name] |
|
|
|
|
|
return None |
|
|
|
|
|
|
|
|
def sequential_split( |
|
|
gm: torch.fx.GraphModule, |
|
|
node_call_back: Callable[[torch.fx.Node], Union[torch.fx.Node, bool]], |
|
|
) -> torch.fx.GraphModule: |
|
|
""" |
|
|
    sequential_split creates a new graph module that splits the input graph module into
    multiple submodules based on node_call_back. It doesn't mutate the input graph module.
    node_call_back should return True if the node is a delimiter; a delimiter becomes the
    first node of the next submodule.
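
    Example (a minimal sketch; ``gm`` and the delimiter predicate below are
    illustrative assumptions, not something this module defines)::

        # Start a new submodule at every aten.add.Tensor call.
        new_gm = sequential_split(
            gm, lambda node: node.target == torch.ops.aten.add.Tensor
        )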
|
|
""" |
|
|
from torch.fx.passes.split_module import split_module |
|
|
|
|
|
split_map = {} |
|
|
split_id = 0 |
|
|
for node in gm.graph.nodes: |
|
|
if node_call_back(node): |
|
|
split_id += 1 |
|
|
split_map[node] = split_id |
|
|
|
|
|
new_gm = split_module( |
|
|
gm, |
|
|
gm, |
|
|
lambda node: split_map[node], |
|
|
keep_original_order=True, |
|
|
keep_original_node_name=True, |
|
|
) |
|
|
|
|
|
new_gm.graph._codegen = gm.graph._codegen |
|
|
new_gm.recompile() |
|
|
return new_gm |
|
|
|
|
|
|
|
|
def nodes_filter(nodes: list[torch.fx.Node], node_call_back) -> list[torch.fx.Node]: |
|
|
"""Returns the nodes that match the node_call_back as a list.""" |
|
|
return [node for node in nodes if node_call_back(node)] |
|
|
|
|
|
|
|
|
@contextmanager |
|
|
def _disable_aten_to_metadata_assertions(): |
|
|
global _DISABLE_ATEN_TO_ASSERTION_PASS |
|
|
orig_val = _DISABLE_ATEN_TO_ASSERTION_PASS |
|
|
_DISABLE_ATEN_TO_ASSERTION_PASS = True |
|
|
try: |
|
|
yield |
|
|
finally: |
|
|
_DISABLE_ATEN_TO_ASSERTION_PASS = orig_val |
|
|
|
|
|
|
|
|
def _insert_aten_to_metadata_assert_pass(gm: torch.fx.GraphModule) -> None: |
|
|
from torch._export.passes._node_metadata_hook import ( |
|
|
_node_metadata_hook, |
|
|
_set_node_metadata_hook, |
|
|
) |
|
|
|
|
|
if _DISABLE_ATEN_TO_ASSERTION_PASS: |
|
|
return |
|
|
|
|
|
aten_to_variants = [ |
|
|
torch.ops.aten.to.device, |
|
|
torch.ops.aten.to.dtype, |
|
|
torch.ops.aten.to.dtype_layout, |
|
|
] |
|
|
for node in gm.graph.nodes: |
|
|
if node.target in aten_to_variants: |
|
|
if ( |
|
|
node.prev.target == torch.ops.aten._assert_tensor_metadata.default |
|
|
and node.args[0] == node.prev.args[0] |
|
|
): |
|
|
|
|
|
continue |
|
|
|
|
|
if (tensor_val := node.args[0].meta.get("val")) is not None: |
|
|
with ( |
|
|
gm.graph.inserting_before(node), |
|
|
_set_node_metadata_hook( |
|
|
gm, |
|
|
functools.partial( |
|
|
_node_metadata_hook, |
|
|
metadata={ |
|
|
"stack_trace": node.meta.get("stack_trace"), |
|
|
"nn_module_stack": node.meta.get("nn_module_stack"), |
|
|
}, |
|
|
), |
|
|
), |
|
|
): |
|
|
gm.graph.call_function( |
|
|
torch.ops.aten._assert_tensor_metadata.default, |
|
|
args=(node.args[0],), |
|
|
kwargs={ |
|
|
"dtype": tensor_val.dtype, |
|
|
"device": tensor_val.device, |
|
|
"layout": tensor_val.layout, |
|
|
}, |
|
|
) |
|
|
|
|
|
|
|
|
def apply_runtime_assertion_pass(gm: torch.fx.GraphModule, graph_signature): |
|
|
from torch._export.passes._node_metadata_hook import ( |
|
|
_node_metadata_hook, |
|
|
_set_node_metadata_hook, |
|
|
) |
|
|
from torch._functorch._aot_autograd.input_output_analysis import _graph_output_names |
|
|
|
|
|
if not torch._dynamo.config.do_not_emit_runtime_asserts: |
|
|
stack_trace = ( |
|
|
'File "torch/fx/passes/runtime_assert.py", line 24, ' |
|
|
"in insert_deferred_runtime_asserts" |
|
|
) |
|
|
with _set_node_metadata_hook( |
|
|
gm, |
|
|
functools.partial( |
|
|
_node_metadata_hook, metadata={"stack_trace": stack_trace} |
|
|
), |
|
|
): |
|
|
shape_env = _get_shape_env_from_gm(gm) |
|
|
if shape_env: |
|
|
insert_deferred_runtime_asserts( |
|
|
gm, |
|
|
shape_env, |
|
|
f"exported program: {first_call_function_nn_module_stack(gm.graph)}", |
|
|
export=True, |
|
|
) |
|
|
|
|
|
|
|
|
_insert_aten_to_metadata_assert_pass(gm) |
|
|
|
|
|
|
|
|
gm.recompile() |
|
|
graph_signature.user_outputs = _graph_output_names(gm) |
|
|
return gm, graph_signature |
|
|
|
|
|
|
|
|
def nodes_first( |
|
|
nodes: list[torch.fx.Node], node_call_back=None |
|
|
) -> Optional[torch.fx.Node]: |
|
|
""" |
|
|
Returns the first node that matches the node_call_back. If no node matches, returns None. |
|
|
When node_call_back is None, returns the first node in the node list. |
|
|
""" |
|
|
ret = nodes_filter(nodes, node_call_back if node_call_back else lambda node: True) |
|
|
if len(ret) > 0: |
|
|
return ret[0] |
|
|
return None |
|
|
|
|
|
|
|
|
def nodes_count(nodes: list[torch.fx.Node], node_call_back) -> int: |
|
|
"""Returns the number of nodes that match the node_call_back.""" |
|
|
return len(nodes_filter(nodes, node_call_back)) |
|
|
|
|
|
|
|
|
def nodes_map(nodes: list[torch.fx.Node], node_call_back) -> list[torch.fx.Node]: |
|
|
""" |
|
|
Sequentially visit the nodes list and invoke node_call_back on each element. |
|
|
Returns the nodes list after the node_call_back is invoked on each element. |
|
|
""" |
|
|
for node in nodes: |
|
|
node_call_back(node) |
|
|
return nodes |
|
|
|
|
|
|
|
|
def node_replace_(old_node: torch.fx.Node, new_node: torch.fx.Node) -> None: |
|
|
""" |
|
|
Replace all uses of old_node with new_node. |
|
|
""" |
|
|
old_node.replace_all_uses_with(new_node) |
|
|
old_node.users.clear() |
|
|
old_node.graph.erase_node(old_node) |
|
|
|
|
|
|
|
|
def _update_gm_meta_if_possible(gm: torch.fx.GraphModule, mod: torch.nn.Module) -> None: |
|
|
if ( |
|
|
isinstance(mod, torch.fx.GraphModule) |
|
|
and hasattr(mod, "meta") |
|
|
and "custom" in mod.meta |
|
|
): |
|
|
gm.meta.update({"custom": mod.meta["custom"]}) |
|
|
|
|
|
|
|
|
def node_inline_(call_mod_node: torch.fx.Node) -> Optional[torch.fx.GraphModule]: |
|
|
""" |
|
|
Inline the submodule of the given node into the parent module. |
|
|
    Note: we only support the case where the submodule takes tensor inputs.
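
    Example (a minimal sketch; assumes ``gm`` was produced by :func:`sequential_split`,
    so its ``call_module`` nodes take tensor inputs)::

        for n in list(gm.graph.nodes):
            if n.op == "call_module":
                node_inline_(n)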
|
|
""" |
|
|
assert call_mod_node.op == "call_module" |
|
|
gm = call_mod_node.graph.owning_module |
|
|
assert gm is not None |
|
|
|
|
|
assert isinstance(call_mod_node.target, str) |
|
|
sub_gm = getattr(gm, call_mod_node.target) |
|
|
|
|
|
phs = (node for node in sub_gm.graph.nodes if node.op == "placeholder") |
|
|
body = ( |
|
|
node for node in sub_gm.graph.nodes if node.op not in ("placeholder", "output") |
|
|
) |
|
|
output = [node for node in sub_gm.graph.nodes if node.op == "output"] |
|
|
|
|
|
for ph, arg in zip(phs, call_mod_node.args): |
|
|
assert isinstance(arg, torch.fx.Node) |
|
|
node_replace_(ph, arg) |
|
|
|
|
|
with gm.graph.inserting_before(call_mod_node): |
|
|
for node in body: |
|
|
new_node = gm.graph.node_copy(node) |
|
|
if node.op == "get_attr": |
|
|
new_target_name = new_node.target |
|
|
if hasattr(gm, new_target_name): |
|
|
|
|
|
i = 1 |
|
|
new_target_name = f"submod_{i}" |
|
|
while hasattr(gm, new_target_name): |
|
|
i += 1 |
|
|
new_target_name = f"submod_{i}" |
|
|
new_node.target = new_target_name |
|
|
setattr(gm, new_node.target, getattr(sub_gm, node.target)) |
|
|
node_replace_(node, new_node) |
|
|
|
|
|
if len(output) > 0: |
|
|
assert len(output) == 1 and len(output[0].args) == 1 |
|
|
new_output = output[0].args[0] |
|
|
|
|
|
if isinstance(new_output, torch.fx.Node): |
|
|
|
|
|
|
|
|
new_output.users.clear() |
|
|
node_replace_(call_mod_node, new_output) |
|
|
elif isinstance(new_output, (list, tuple)): |
|
|
|
|
|
for node in new_output: |
|
|
node.users.pop(output[0]) |
|
|
|
|
|
|
|
|
get_item_users = nodes_filter( |
|
|
list(call_mod_node.users.keys()), |
|
|
lambda node: node.op == "call_function" |
|
|
and node.target == operator.getitem, |
|
|
) |
|
|
|
|
|
nodes_map( |
|
|
get_item_users, |
|
|
lambda get_item_node: node_replace_( |
|
|
get_item_node, |
|
|
new_output[get_item_node.args[1]], |
|
|
), |
|
|
) |
|
|
call_mod_node.graph.erase_node(call_mod_node) |
|
|
else: |
|
|
raise NotImplementedError( |
|
|
f"Unsupported output type {type(new_output)}. Expect it to be a Node or a list/tuple of Nodes." |
|
|
) |
|
|
else: |
|
|
call_mod_node.graph.erase_node(call_mod_node) |
|
|
|
|
|
gm.delete_all_unused_submodules() |
|
|
gm.recompile() |
|
|
return gm |
|
|
|
|
|
|
|
|
def _get_torch_jit_trace_forward_signature(mod: torch.nn.Module) -> inspect.Signature: |
|
|
""" |
|
|
    Get the source code and parse argument names using the AST. The function returns
    the signature of the forward() function.
|
|
|
|
|
# TODO: Directly provide inspect.signature compatible TS-d module. |
|
|
""" |
|
|
ast_mod = ast.parse(mod.code) |
|
|
ast_func_def: ast.FunctionDef = ast_mod.body[0] |
|
|
|
|
|
|
|
|
arg_type_map = {"args": Parameter.POSITIONAL_OR_KEYWORD} |
|
|
|
|
|
|
|
|
param_list = [] |
|
|
for arg_type, param_type in arg_type_map.items(): |
|
|
arg_name_list = [a.arg for a in getattr(ast_func_def.args, arg_type)] |
|
|
for arg_name in arg_name_list: |
|
|
if arg_name == "self": |
|
|
continue |
|
|
param_list.append(inspect.Parameter(arg_name, param_type)) |
|
|
|
|
|
return inspect.Signature(parameters=param_list) |
|
|
|
|
|
|
|
|
def _bind_signature_to_inputs(mod, fake_args, fake_kwargs): |
|
|
if isinstance(mod, (torch.jit.ScriptModule, torch.jit.TracedModule)): |
|
|
sig = _get_torch_jit_trace_forward_signature(mod) |
|
|
|
|
|
|
|
|
assert len(sig.parameters) == len(fake_args) + len(fake_kwargs), ( |
|
|
"Arguments other than POSITIONAL_OR_KEYWORD kinds in forward() " |
|
|
"are not supported in _get_torch_jit_trace_forward_signature" |
|
|
) |
|
|
else: |
|
|
sig = inspect.signature(mod.forward) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return {**sig.bind_partial(*fake_args).arguments, **fake_kwargs} |
|
|
|
|
|
|
|
|
def _build_cache(name, find_available, used_names): |
|
|
used_names.add(name) |
|
|
match = re.match(r"(.*)_(\d+)", name) |
|
|
if match: |
|
|
prefix, n = match.group(1), match.group(2) |
|
|
if int(n) > find_available[prefix]: |
|
|
find_available[prefix] = int(n) |
|
|
|
|
|
|
|
|
def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None: |
|
|
""" |
|
|
Propagate placeholder names from the top-level graph into HigherOrderOp subgraphs, |
|
|
and handle collisions with non-placeholders by count suffixing. |
|
|
Different HOO subgraph types have different input schemas, so we first enumerate them |
|
|
and gather the top-level named placeholder nodes. |
|
|
""" |
|
|
|
|
|
|
|
|
subgraph_ph_tuples: list[tuple[torch.fx.GraphModule, list[torch.fx.Node]]] = [] |
|
|
for node in gm.graph.nodes: |
|
|
if node.op == "call_function" and isinstance( |
|
|
node.target, torch._ops.HigherOrderOperator |
|
|
): |
|
|
|
|
|
if node.target._name == "cond": |
|
|
_, true_graph, false_graph, cond_args = node._args |
|
|
subgraph_ph_tuples.append((getattr(gm, true_graph.target), cond_args)) |
|
|
subgraph_ph_tuples.append((getattr(gm, false_graph.target), cond_args)) |
|
|
elif node.target._name == "wrap_with_set_grad_enabled": |
|
|
subgraph, phs = node._args[1], node._args[2:] |
|
|
subgraph_ph_tuples.append((getattr(gm, subgraph.target), phs)) |
|
|
elif node.target._name == "map_impl": |
|
|
body_graph, array, args = node._args |
|
|
subgraph_ph_tuples.append( |
|
|
(getattr(gm, body_graph.target), array + args) |
|
|
) |
|
|
|
|
|
|
|
|
for subgraph, hoo_phs in subgraph_ph_tuples: |
|
|
name_map: dict[str, str] = {} |
|
|
find_available: dict[str, int] = defaultdict(int) |
|
|
used_names: set[str] = set() |
|
|
for i, node in enumerate(subgraph.graph.nodes): |
|
|
if i < len(hoo_phs): |
|
|
name_map[node.name] = hoo_phs[i].name |
|
|
node.name = node.target = hoo_phs[i].name |
|
|
_build_cache(node.name, find_available, used_names) |
|
|
else: |
|
|
node.name = _rename_without_collisions( |
|
|
name_map, find_available, used_names, node.name, node.name |
|
|
) |
|
|
|
|
|
|
|
|
_name_hoo_subgraph_placeholders(subgraph) |
|
|
subgraph.recompile() |
|
|
|
|
|
|
|
|
def placeholder_naming_pass( |
|
|
gm: torch.fx.GraphModule, |
|
|
export_graph_signature: "ExportGraphSignature", |
|
|
mod: torch.nn.Module, |
|
|
fake_args, |
|
|
fake_kwargs, |
|
|
fake_params_buffers, |
|
|
constants: dict[str, Any], |
|
|
) -> None: |
|
|
""" |
|
|
This pass is run at the end of _export_non_strict() to assign better placeholder node names: |
|
|
- User inputs: |
|
|
These follow the signature of mod.forward(), e.g. forward(x, y) produces nodes x, y. |
|
|
For nested inputs from dictionaries, lists, tuples, or dataclasses, |
|
|
the names are a concatenation of the path to the tensor. |
|
|
e.g. x = { |
|
|
'a': torch.randn(), |
|
|
'b': [torch.randn(), torch.randn()] |
|
|
} |
|
|
produces nodes x_a, x_b_0, x_b_1. |
|
|
- Parameters/buffers/constants/custom objects: |
|
|
These follow the FQN of the object, prefixed by "p", "b", "c", "obj" respectively. |
|
|
e.g. self.bar.l0.weight produces "p_bar_l0_weight". |
|
|
- Effect tokens: |
|
|
These are named token, token_1, ... |
|
|
""" |
|
|
|
|
|
custom_meta: dict[str, Any] = {} |
|
|
if isinstance(mod, torch.fx.GraphModule): |
|
|
for node in mod.graph.nodes: |
|
|
if "custom" in node.meta: |
|
|
custom_meta[node.name] = node.meta["custom"] |
|
|
|
|
|
def _strip_name(x): |
|
|
if x.startswith("L__self___"): |
|
|
x = x[len("L__self___") :] |
|
|
elif x.startswith("self_"): |
|
|
x = x[len("self_") :] |
|
|
x = re.sub(r"[^a-zA-Z0-9]", "_", x) |
|
|
return x |
|
|
|
|
|
def _extract_pytree_key(x): |
|
|
if isinstance(x, MappingKey): |
|
|
x = re.sub(r"[^a-zA-Z0-9]", "_", str(x.key)) |
|
|
return x |
|
|
elif isinstance(x, SequenceKey): |
|
|
return str(x.idx) |
|
|
elif isinstance(x, GetAttrKey): |
|
|
return x.name |
|
|
else: |
|
|
raise RuntimeError(f"Pytree key of type {type(x)} not handled for {x}") |
|
|
|
|
|
name_map: dict[str, str] = {} |
|
|
find_available: dict[str, int] = defaultdict(int) |
|
|
used_names: set[str] = set() |
|
|
|
|
|
|
|
|
combined_args = _bind_signature_to_inputs(mod, fake_args, fake_kwargs) |
|
|
|
|
|
flat_args_with_path, _ = tree_flatten_with_path(combined_args) |
|
|
user_input_names = [ |
|
|
spec.arg.name |
|
|
for spec in export_graph_signature.input_specs |
|
|
if spec.kind == InputKind.USER_INPUT |
|
|
] |
|
|
|
|
|
|
|
|
for (arg_path, _arg), user_input_name in zip(flat_args_with_path, user_input_names): |
|
|
if user_input_name: |
|
|
_rename_without_collisions( |
|
|
name_map, |
|
|
find_available, |
|
|
used_names, |
|
|
user_input_name, |
|
|
placeholder_prefixes[InputKind.USER_INPUT] |
|
|
+ "_".join(_extract_pytree_key(x).lower() for x in arg_path), |
|
|
is_placeholder=True, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
for spec in export_graph_signature.input_specs: |
|
|
if spec.kind == InputKind.USER_INPUT: |
|
|
continue |
|
|
if spec.kind == InputKind.TOKEN: |
|
|
base_name = "" |
|
|
else: |
|
|
base_name = _strip_name(spec.target).lower() |
|
|
base_name = re.sub(r"[^a-zA-Z0-9]", "_", base_name) |
|
|
|
|
|
_rename_without_collisions( |
|
|
name_map, |
|
|
find_available, |
|
|
used_names, |
|
|
spec.arg.name, |
|
|
placeholder_prefixes[spec.kind] + base_name, |
|
|
is_placeholder=True, |
|
|
) |
|
|
if base_name in custom_meta: |
|
|
|
|
|
|
|
|
|
|
|
custom_meta[name_map[spec.arg.name]] = custom_meta[base_name] |
|
|
del custom_meta[base_name] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for node in gm.graph.nodes: |
|
|
if node.op == "placeholder": |
|
|
continue |
|
|
_rename_without_collisions( |
|
|
name_map, find_available, used_names, node.name, node.name |
|
|
) |
|
|
|
|
|
|
|
|
for node in gm.graph.nodes: |
|
|
if node.op == "placeholder": |
|
|
assert node.name in name_map |
|
|
node.name = node.target = name_map[node.name] |
|
|
if node.name in custom_meta: |
|
|
if node.meta.get("custom") is None: |
|
|
node.meta["custom"] = custom_meta[node.name] |
|
|
else: |
|
|
assert node.meta["custom"] == custom_meta[node.name] |
|
|
|
|
|
|
|
|
if isinstance(node.meta["val"], CustomObjArgument): |
|
|
node.meta["val"].name = node.name |
|
|
elif node.name in name_map: |
|
|
node.name = name_map[node.name] |
|
|
|
|
|
|
|
|
_name_hoo_subgraph_placeholders(gm) |
|
|
|
|
|
|
|
|
gm.recompile() |
|
|
|
|
|
|
|
|
for spec in export_graph_signature.input_specs: |
|
|
assert spec.arg.name in name_map |
|
|
spec.arg.name = name_map[spec.arg.name] |
|
|
if ( |
|
|
spec.kind == InputKind.CUSTOM_OBJ and spec.target in name_map |
|
|
): |
|
|
spec.target = name_map[spec.target][4:] |
|
|
|
|
|
for spec in export_graph_signature.output_specs: |
|
|
if spec.arg.name in name_map: |
|
|
spec.arg.name = name_map[spec.arg.name] |
|
|
if spec.kind == OutputKind.USER_INPUT_MUTATION and spec.target in name_map: |
|
|
spec.target = name_map[spec.target] |
|
|
|
|
|
|
|
|
for name in list(constants.keys()): |
|
|
constant = constants[name] |
|
|
if name in name_map and not isinstance( |
|
|
constant, torch.Tensor |
|
|
): |
|
|
new_name = name_map[name] |
|
|
if ( |
|
|
new_name != name |
|
|
and re.match(r"arg(\d+)_1", name) |
|
|
and new_name != placeholder_prefixes[InputKind.CUSTOM_OBJ] + name |
|
|
): |
|
|
constants[new_name] = constant |
|
|
del constants[name] |
|
|
|
|
|
|
|
|
def remove_proxy_from_state_dict(state_dict: dict, in_place: bool) -> dict: |
|
|
""" |
|
|
    If `in_place` is false, return a new copy of `state_dict` with "proxy" removed
    from `v.__dict__`, where `v` denotes the values in the dictionary.
    If `in_place` is true, modify `state_dict` in place.
|
|
""" |
|
|
if in_place: |
|
|
for k, v in state_dict.items(): |
|
|
if hasattr(v, "proxy"): |
|
|
delattr(state_dict[k], "proxy") |
|
|
return state_dict |
|
|
else: |
|
|
new_state_dict = {} |
|
|
for k, v in state_dict.items(): |
|
|
if hasattr(v, "proxy"): |
|
|
new_state_dict[k] = v.detach().clone() |
|
|
else: |
|
|
new_state_dict[k] = v |
|
|
return new_state_dict |
|
|
|
|
|
|
|
|
def _detect_fake_mode_from_gm( |
|
|
gm: torch.fx.GraphModule, |
|
|
) -> Optional[torch._subclasses.fake_tensor.FakeTensorMode]: |
|
|
""" |
|
|
For a given graph module, we look at the "val" of placeholder nodes to find the fake inputs. |
|
|
Additionally, if gm doesn't have placeholders, we further look at the "example_value" or "val" of other nodes. |
|
|
If no fake mode is found, we return None for fake_mode. |
|
|
""" |
|
|
|
|
|
fake_inps: list[torch.Tensor] = [] |
|
|
fake_vals: list[torch.Tensor] = [] |
|
|
for node in gm.graph.nodes: |
|
|
if node.op == "placeholder" and "val" in node.meta: |
|
|
fake_val = node.meta["val"] |
|
|
if fake_val is not None and isinstance(fake_val, torch.Tensor): |
|
|
fake_inps.append(fake_val) |
|
|
elif len(fake_inps) == 0 and ( |
|
|
"example_value" in node.meta or "val" in node.meta |
|
|
): |
|
|
fake_val = None |
|
|
if "example_value" in node.meta: |
|
|
fake_val = node.meta["example_value"] |
|
|
elif "val" in node.meta: |
|
|
fake_val = node.meta["val"] |
|
|
if fake_val is not None and isinstance(fake_val, torch.Tensor): |
|
|
fake_vals.append(fake_val) |
|
|
|
|
|
return detect_fake_mode(fake_inps + fake_vals) |
|
|
|
|
|
|
|
|
@contextmanager |
|
|
def _disable_load_state_dict_hooks(mod: torch.nn.Module): |
|
|
state_dict_hooks: dict[int, Callable] = dict(mod._state_dict_hooks) |
|
|
state_dict_pre_hooks: dict[int, Callable] = dict(mod._state_dict_pre_hooks) |
|
|
mod._state_dict_hooks.clear() |
|
|
mod._state_dict_pre_hooks.clear() |
|
|
try: |
|
|
yield |
|
|
finally: |
|
|
mod._state_dict_hooks = state_dict_hooks |
|
|
mod._state_dict_pre_hooks = state_dict_pre_hooks |
|
|
|
|
|
|
|
|
def _is_cia_op(op: "OperatorBase") -> bool: |
|
|
return ( |
|
|
torch._C._dispatch_has_kernel_for_dispatch_key( |
|
|
op.name(), torch._C.DispatchKey.CompositeImplicitAutograd |
|
|
) |
|
|
or torch._C.DispatchKey.CompositeImplicitAutograd in op.py_kernels |
|
|
) |
|
|
|
|
|
|
|
|
def _is_preservable_cia_op(op: "OperatorBase") -> bool: |
|
|
return _check_valid_to_preserve(op) and _is_cia_op(op) |
|
|
|
|
|
|
|
|
def _is_aten_op(op: "OperatorBase") -> bool: |
|
|
return op.name().split("::")[0] == "aten" |
|
|
|
|
|
|
|
|
def _is_custom_op(op: "OperatorBase") -> bool: |
|
|
return not _is_aten_op(op) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _materialize_cpp_cia_ops() -> None: |
|
|
""" |
|
|
    Utility function to query the C++ dispatcher for all possible CIA ops and
    populate them into the torch.ops namespace.
|
|
""" |
|
|
cia_ops = torch._C._dispatch_get_registrations_for_dispatch_key( |
|
|
"CompositeImplicitAutograd" |
|
|
) |
|
|
|
|
|
|
|
|
for op in cia_ops: |
|
|
namespace, op_name = tuple(op.split("::")) |
|
|
split_list = op_name.split(".") |
|
|
|
|
|
assert len(split_list) == 1 or len(split_list) == 2 |
|
|
op_name = split_list[0] |
|
|
op_overload_name = "default" |
|
|
if len(split_list) == 2: |
|
|
op_overload_name = split_list[1] |
|
|
|
|
|
_ = getattr(getattr(getattr(torch.ops, namespace), op_name), op_overload_name) |
|
|
|
|
|
|
|
|
def _special_op_to_preserve_cia(*args, **kwargs): |
|
|
""" |
|
|
    This is a special marker that tells our infra that this op shouldn't be decomposed.
|
|
""" |
|
|
return NotImplemented |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _check_valid_to_preserve(op_overload: "OperatorBase"): |
|
|
from torch._decomp import _should_decompose_because_unsafe_op |
|
|
|
|
|
if _should_decompose_because_unsafe_op(op_overload): |
|
|
return False |
|
|
if op_overload in FunctionalTensor.metadata_fns: |
|
|
return False |
|
|
|
|
|
if not hasattr(op_overload, "_schema"): |
|
|
return False |
|
|
|
|
|
alias_info = len( |
|
|
[i for i in op_overload._schema.arguments if i.alias_info is not None] |
|
|
) |
|
|
|
|
|
is_mutating_or_aliasing = alias_info != 0 or op_overload._schema.is_mutable |
|
|
|
|
|
if is_mutating_or_aliasing: |
|
|
return False |
|
|
|
|
|
if not torch._C._dispatch_has_kernel(op_overload.name()): |
|
|
return False |
|
|
|
|
|
return True |
|
|
|
|
|
|
|
|
@functools.lru_cache(maxsize=1) |
|
|
def _collect_all_valid_cia_ops_for_aten_namespace() -> set["OperatorBase"]: |
|
|
return _collect_all_valid_cia_ops_for_namespace(torch.ops.aten) |
|
|
|
|
|
|
|
|
def _collect_all_valid_cia_ops_for_namespace( |
|
|
op_namespace: torch._ops._OpNamespace, |
|
|
) -> set["OperatorBase"]: |
|
|
|
|
|
_materialize_cpp_cia_ops() |
|
|
|
|
|
|
|
|
cia_ops = set() |
|
|
for op in op_namespace: |
|
|
op_packet = getattr(op_namespace, op) |
|
|
for overload in op_packet.overloads(): |
|
|
op_overload = getattr(op_packet, overload) |
|
|
if _is_preservable_cia_op(op_overload): |
|
|
cia_ops.add(op_overload) |
|
|
return cia_ops |
|
|
|
|
|
|
|
|
def _collect_all_valid_cia_ops() -> set["OperatorBase"]: |
|
|
""" |
|
|
    This is a utility function that gets all functional CIA ops.

    The algorithm is in 2 steps:
    1. We first query the C++ dispatcher to get the list of CIA ops
       and then call getattr on torch.ops.aten to lazily populate
       them.

    2. Sometimes, a handful of ops have CIA registered in the Python dispatcher
       but not on the C++ side; these can't be caught by the first step,
       so we walk the namespaces again to get the final list.

    Note that the output of this function should never be modified.
|
|
""" |
|
|
cia_ops = set() |
|
|
for op_namespace_name in torch.ops._dir: |
|
|
|
|
|
if op_namespace_name != "aten": |
|
|
assert hasattr(torch.ops, op_namespace_name) |
|
|
op_namespace = getattr(torch.ops, op_namespace_name) |
|
|
if isinstance(op_namespace, torch._ops._OpNamespace): |
|
|
cia_ops |= _collect_all_valid_cia_ops_for_namespace(op_namespace) |
|
|
else: |
|
|
cia_ops |= _collect_all_valid_cia_ops_for_aten_namespace() |
|
|
return cia_ops |
|
|
|
|
|
|
|
|
def _get_decomp_for_cia(op: "OperatorBase"): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dk = torch._C.DispatchKey.CompositeImplicitAutograd |
|
|
if dk in op.py_kernels and not isinstance(op.py_kernels[dk], torch._C.DispatchKey): |
|
|
return op.py_kernels[dk] |
|
|
|
|
|
def _special_op_to_decompose_cia(*args, **kwargs): |
|
|
kernel = kwargs["kernel"] |
|
|
del kwargs["kernel"] |
|
|
|
|
|
|
|
|
dk = torch._C.DispatchKey.CompositeImplicitAutograd |
|
|
if torch._C._dispatch_has_kernel_for_dispatch_key( |
|
|
kernel.name(), torch._C.DispatchKey.CompositeImplicitAutograd |
|
|
): |
|
|
return kernel._op_dk(dk, *args, **kwargs) |
|
|
else: |
|
|
raise AssertionError( |
|
|
f"Expected {kernel} to have CompositeImplicitAutograd kernel" |
|
|
) |
|
|
|
|
|
return functools.partial(_special_op_to_decompose_cia, kernel=op) |
|
|
|
|
|
|
|
|
@contextmanager |
|
|
def _compiling_state_context(): |
|
|
old_compiling_flag = torch.compiler._is_compiling_flag |
|
|
old_exporting_flag = torch.compiler._is_exporting_flag |
|
|
try: |
|
|
torch.compiler._is_compiling_flag = True |
|
|
torch.compiler._is_exporting_flag = True |
|
|
yield |
|
|
finally: |
|
|
torch.compiler._is_compiling_flag = old_compiling_flag |
|
|
torch.compiler._is_exporting_flag = old_exporting_flag |
|
|
|
|
|
|
|
|
def _fakify_params_buffers( |
|
|
fake_mode: FakeTensorMode, |
|
|
mod: torch.nn.Module, |
|
|
) -> dict[str, Union[torch.Tensor, torch.nn.Parameter]]: |
|
|
params_buffers = { |
|
|
**dict(mod.named_parameters(remove_duplicate=False)), |
|
|
**dict(mod.named_buffers(remove_duplicate=False)), |
|
|
} |
|
|
|
|
|
faked_params_buffers = {} |
|
|
memo: dict[int, FakeTensor] = {} |
|
|
for key, value in params_buffers.items(): |
|
|
if id(value) in memo: |
|
|
fake_tensor = memo[id(value)] |
|
|
else: |
|
|
fake_tensor = fake_mode.from_tensor(value, static_shapes=True) |
|
|
memo[id(value)] = fake_tensor |
|
|
faked_params_buffers[key] = fake_tensor |
|
|
return faked_params_buffers |
|
|
|
|
|
|
|
|
def register_module_as_pytree_input_node(cls: type[torch.nn.Module]) -> None: |
|
|
""" |
|
|
    Registers a module class as a valid input type for :func:`torch.export.export`.

    Args:
        cls: the :class:`torch.nn.Module` subclass to register. Instances of this
            class can then be passed as inputs to export. The serialized type name
            used when serializing the pytree TreeSpec is derived from the class's
            module and qualified name.
|
|
|
|
|
Example:: |
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
class Module(torch.nn.Module): |
|
|
def __init__(self): |
|
|
super().__init__() |
|
|
self.linear = torch.nn.Linear(3, 3) |
|
|
|
|
|
def forward(self, x): |
|
|
return self.linear(x) |
|
|
|
|
|
|
|
|
        torch._export.utils.register_module_as_pytree_input_node(Module)
|
|
|
|
|
|
|
|
class Mod(torch.nn.Module): |
|
|
def forward(self, x, m): |
|
|
return m(x) + x |
|
|
|
|
|
|
|
|
ep = torch.export.export(Mod(), (torch.randn(3), Module())) |
|
|
print(ep) |
|
|
|
|
|
""" |
|
|
assert issubclass(cls, torch.nn.Module) |
|
|
|
|
|
import weakref |
|
|
|
|
|
class PrototypeModule(weakref.ref): |
|
|
def __init__(self, m, *args, **kwargs): |
|
|
super().__init__(m, *args, **kwargs) |
|
|
assert isinstance(m, torch.nn.Module) |
|
|
assert not hasattr(self, "_proto_cls") |
|
|
self._proto_cls = cls |
|
|
|
|
|
def __eq__(self, other): |
|
|
return self._proto_cls == other._proto_cls |
|
|
|
|
|
def __deepcopy__(self, memo): |
|
|
return PrototypeModule(self()) |
|
|
|
|
|
def default_flatten_fn(obj: Any) -> tuple[list[Any], Context]: |
|
|
named_parameters = dict(obj.named_parameters()) |
|
|
named_buffers = dict(obj.named_buffers()) |
|
|
params_buffers = {**named_parameters, **named_buffers} |
|
|
return list(params_buffers.values()), [ |
|
|
list(params_buffers.keys()), |
|
|
PrototypeModule(obj), |
|
|
] |
|
|
|
|
|
def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any: |
|
|
flat_names, ref = context |
|
|
if ref is None or ref() is None: |
|
|
raise RuntimeError("Module has been garbage collected") |
|
|
obj = ref() |
|
|
assert flatten_fn is not None |
|
|
flattened, _ = flatten_fn(obj) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def copy_module(mod: torch.nn.Module): |
|
|
ret = copy.copy(mod) |
|
|
ret.__dict__ = {copy.copy(k): copy.copy(v) for k, v in mod.__dict__.items()} |
|
|
for name, child in ret.named_children(): |
|
|
setattr(ret, name, copy_module(child)) |
|
|
return ret |
|
|
|
|
|
if any(v is not o for v, o in zip(values, flattened)): |
|
|
with torch.nn.utils.stateless._reparametrize_module( |
|
|
obj, dict(zip(flat_names, values)), tie_weights=True, strict=True |
|
|
): |
|
|
ret = copy_module(obj) |
|
|
else: |
|
|
ret = obj |
|
|
return ret |
|
|
|
|
|
def default_flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]: |
|
|
flattened, [flat_names, *args] = flatten_fn(obj) |
|
|
return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], [ |
|
|
flat_names, |
|
|
*args, |
|
|
] |
|
|
|
|
|
flatten_fn = default_flatten_fn |
|
|
unflatten_fn = default_unflatten_fn |
|
|
|
|
|
serialized_type_name = cls.__module__ + "." + cls.__qualname__ |
|
|
|
|
|
def to_dumpable_context(context): |
|
|
keys, *_ = context |
|
|
return json.dumps([keys, *([None] * len(_))]) |
|
|
|
|
|
def from_dumpable_context(dumpable): |
|
|
s = json.loads(dumpable) |
|
|
s[1] = PrototypeModule(torch.nn.Module()) |
|
|
return s |
|
|
|
|
|
_register_pytree_node( |
|
|
cls, |
|
|
flatten_fn, |
|
|
unflatten_fn, |
|
|
serialized_type_name=serialized_type_name, |
|
|
flatten_with_keys_fn=default_flatten_fn_with_keys, |
|
|
to_dumpable_context=to_dumpable_context, |
|
|
from_dumpable_context=from_dumpable_context, |
|
|
) |
|
|
|
|
|
def default_flatten_fn_spec(obj, spec) -> list[Any]: |
|
|
flats, context = flatten_fn(obj) |
|
|
assert context == spec.context |
|
|
return flats |
|
|
|
|
|
register_pytree_flatten_spec( |
|
|
cls, |
|
|
default_flatten_fn_spec, |
|
|
) |
|
|
|
|
|
|
|
|
def deregister_module_as_pytree_input_node(cls: type[torch.nn.Module]) -> None: |
|
|
_deregister_pytree_node(cls) |
|
|
_deregister_pytree_flatten_spec(cls) |
|
|
|
|
|
|
|
|
def _sync_state(src, dst): |
|
|
assert isinstance( |
|
|
src, |
|
|
torch.nn.Module, |
|
|
), f"Expected {src} to be a nn.Module" |
|
|
assert isinstance( |
|
|
dst, |
|
|
torch.nn.Module, |
|
|
), f"Expected {dst} to be a nn.Module" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dst._parameters = src._parameters |
|
|
dst._buffers = src._buffers |
|
|
|
|
|
|
|
|
def sync_state(*wrapped_method_modules): |
|
|
""" |
|
|
Sync state between exported modules corresponding to wrapped methods. |
|
|
This might be necessary after serializing/deserializing due to copying. |
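
    Example (illustrative; assumes ``ep_encode`` and ``ep_decode`` were exported from
    wrapped methods of the same module and then deserialized)::

        enc = ep_encode.module()
        dec = ep_decode.module()
        sync_state(enc, dec)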
|
|
""" |
|
|
if wrapped_method_modules: |
|
|
m, *other_ms = wrapped_method_modules |
|
|
for other_m in other_ms: |
|
|
_sync_state(m, other_m) |
|
|
|
|
|
|
|
|
class _WrappedMethod(torch.nn.Module): |
|
|
def __init__(self, method): |
|
|
super().__init__() |
|
|
|
|
|
_sync_state(method.__self__, self) |
|
|
|
|
|
self.forward = method |
|
|
|
|
|
|
|
|
def wrap_method(method): |
|
|
""" |
|
|
Wrap a method as a module so that it can be exported. |
|
|
The wrapped module's forward points to the method, and |
|
|
the method's original module state is shared. |
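
    Example (a minimal sketch; ``M`` and its ``encode`` method are illustrative
    assumptions)::

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)

            def encode(self, x):
                return self.linear(x)

        m = M()
        ep = torch.export.export(wrap_method(m.encode), (torch.randn(3),))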
|
|
""" |
|
|
assert ismethod( |
|
|
method, |
|
|
), f"Expected {method} to be a method" |
|
|
return _WrappedMethod(method) |
|
|
|