| |
| import abc |
| import copy |
| import logging |
| import operator |
| import re |
| from collections import defaultdict |
| from contextlib import contextmanager |
| from copy import deepcopy |
| from dataclasses import dataclass |
| from enum import Enum |
| from typing import Any, Callable, cast, Optional, Union |
|
|
| import torch |
| import torch.fx._pytree as fx_pytree |
| import torch.utils._pytree as pytree |
| from torch._library.fake_class_registry import FakeScriptObject |
| from torch.export import ExportedProgram |
| from torch.export._tree_utils import reorder_kwargs |
| from torch.export.exported_program import ( |
| ConstantArgument, |
| ExportGraphSignature, |
| InputKind, |
| ModuleCallSignature, |
| SymBoolArgument, |
| SymFloatArgument, |
| SymIntArgument, |
| TensorArgument, |
| ) |
| from torch.fx._symbolic_trace import is_fx_symbolic_tracing |
| from torch.fx.graph_module import _get_attr, _get_attr_via_attr_list, _print_readable |
| from torch.utils._pytree import GetAttrKey, SequenceKey |
|
|
| from ._remove_effect_tokens_pass import _remove_effect_tokens |
|
|
|
|
| log = logging.getLogger(__name__) |
|
|
|
|
| __all__ = [ |
| "FlatArgsAdapter", |
| "InterpreterModule", |
| "InterpreterModuleDispatcher", |
| "UnflattenedModule", |
| "unflatten", |
| ] |
|
|
|
|
| class _AttrKind(Enum): |
| PARAMETER = "parameter" |
| BUFFER = "buffer" |
| CONSTANT = "constant" |
| MODULE = "module" |
|
|
|
|
| @dataclass(frozen=True) |
| class _TensorID: |
| """Custom tensor identifier containing storage, stride, and size information.""" |
|
|
| untyped_storage: torch.UntypedStorage |
| stride: tuple |
| size: tuple |
| storage_offset: int |
|
|
|
|
| RUN_WITH_INTERPRETER = True |
|
|
|
|
| @contextmanager |
| def _disable_interpreter(): |
| global RUN_WITH_INTERPRETER |
| old_flag = RUN_WITH_INTERPRETER |
| RUN_WITH_INTERPRETER = False |
| try: |
| yield |
| finally: |
| RUN_WITH_INTERPRETER = old_flag |
|
|
|
|
| |
| |
| def _assign_attr( |
| from_obj: Union[torch.Tensor, torch.ScriptObject, torch.nn.Module], |
| to_module: torch.nn.Module, |
| target: str, |
| attr_kind: _AttrKind, |
| persistent: bool = True, |
| ): |
| *prefix, field = target.split(".") |
| |
| |
| |
| |
| |
| |
| to_modules = {to_module} |
| for item in prefix: |
| ts: set[torch.nn.Module] = set() |
| for to_module in to_modules: |
| if not hasattr(to_module, item): |
| setattr(to_module, item, torch.nn.Module()) |
| ts.update( |
| t_call |
| for k, t_call in to_module._modules.items() |
| if _is_call_name(k, item) |
| ) |
| to_modules = ts |
|
|
| for to_module in to_modules: |
| if attr_kind == _AttrKind.PARAMETER: |
| assert isinstance(from_obj, torch.nn.Parameter) |
| to_module.register_parameter(field, from_obj) |
| elif attr_kind == _AttrKind.BUFFER: |
| assert isinstance(from_obj, torch.Tensor) |
| to_module.register_buffer(field, from_obj, persistent=persistent) |
| elif attr_kind == _AttrKind.CONSTANT: |
| assert not isinstance(from_obj, FakeScriptObject), ( |
| "FakeScriptObject should only exist during tracing." |
| ) |
| assert isinstance( |
| from_obj, |
| ( |
| torch.Tensor, |
| torch.ScriptObject, |
| ), |
| ) |
| setattr(to_module, field, from_obj) |
| elif attr_kind == _AttrKind.MODULE: |
| assert isinstance(from_obj, torch.nn.Module) |
| setattr(to_module, field, from_obj) |
|
|
|
|
| class _SubmoduleBase: |
| _ty: Optional[str] |
|
|
| def type_name(self) -> Optional[str]: |
| """ |
| Subclass of this class - InterpreterModule, InterpreterModuleDispatcher, represents |
| corresponding model in eager model. To get this type information for those modules |
| in eager model we need to use this method. |
| """ |
| return self._ty |
|
|
|
|
| class InterpreterModule(_SubmoduleBase, torch.nn.Module): |
| """A module that uses torch.fx.Interpreter to execute instead of the usual |
| codegen that GraphModule uses. This provides better stack trace information |
| and makes it easier to debug execution. |
| """ |
|
|
| graph_module: Optional[torch.fx.GraphModule] |
|
|
| def __init__( |
| self, |
| graph: torch.fx.Graph, |
| ty: Optional[str] = None, |
| ): |
| super().__init__() |
| self.graph = graph |
| self._ty = ty |
| self.graph.owning_module = self |
| self._run_with_interpreter = RUN_WITH_INTERPRETER |
|
|
| def forward(self, *args, **kwargs): |
| assert self.graph_module is not None, "Didn't finalize this InterpreterModule" |
| if not is_fx_symbolic_tracing() and ( |
| torch.compiler.is_dynamo_compiling() or not self._run_with_interpreter |
| ): |
| |
| |
| |
| |
| return type(self.graph_module).forward(self, *args, **kwargs) |
| else: |
| if kwargs: |
| |
| |
| |
| |
| arg_list = list(args) |
| kwarg_names = self.arg_names[len(arg_list) :] |
| arg_list.extend( |
| kwargs[kwarg_name] |
| for kwarg_name in kwarg_names |
| if kwarg_name in kwargs |
| ) |
|
|
| |
| |
| |
| assert len(kwarg_names) == len(kwargs) |
| assert len(arg_list) == len(self.arg_names) |
| args = tuple(arg_list) |
|
|
| return torch.fx.Interpreter(self, graph=self.graph).run( |
| *args, enable_io_processing=False |
| ) |
|
|
| def finalize(self): |
| |
| |
| |
| |
|
|
| |
| |
| self.__dict__["graph_module"] = torch.fx.GraphModule(self, self.graph) |
| self.graph.lint() |
|
|
| |
| self.arg_names = [] |
| for node in self.graph.nodes: |
| if node.op == "placeholder": |
| self.arg_names.append(node.target) |
|
|
| def print_readable( |
| self, |
| print_output=True, |
| include_stride=False, |
| include_device=False, |
| colored=False, |
| ): |
| return _print_readable( |
| self, |
| "InterpreterModule", |
| print_output, |
| include_stride, |
| include_device, |
| colored, |
| ) |
|
|
|
|
| class InterpreterModuleDispatcher(_SubmoduleBase, torch.nn.Module): |
| """ |
| A module that carries a sequence of InterpreterModules corresponding to |
| a sequence of calls of that module. Each call to the module dispatches |
| to the next InterpreterModule, and wraps back around after the last. |
| """ |
|
|
| def __init__(self, attrs: set[str], call_modules: list[InterpreterModule]): |
| super().__init__() |
| assert call_modules |
| self._modules = call_modules[0]._modules |
| for accessor in attrs: |
| setattr(self, accessor, getattr(call_modules[0], accessor)) |
| self._ty = call_modules[0]._ty |
| self._call_modules = call_modules |
| self._num_calls = 0 |
|
|
| def forward(self, *args, **kwargs): |
| call_module = self._call_modules[self._num_calls] |
| self._num_calls = (self._num_calls + 1) % len(self._call_modules) |
| try: |
| return call_module(*args, **kwargs) |
| except Exception: |
| self._num_calls = 0 |
| raise |
|
|
| def call_modules(self): |
| return self._call_modules |
|
|
| def print_readable( |
| self, |
| print_output=True, |
| include_stride=False, |
| include_device=False, |
| colored=False, |
| ): |
| outputs = [ |
| mod.print_readable( |
| print_output, |
| include_stride, |
| include_device, |
| colored, |
| ) |
| for mod in self._call_modules |
| ] |
| return "\n".join(outputs) |
|
|
|
|
| class FlatArgsAdapter(abc.ABC): |
| """ |
| Adapts input arguments with ``input_spec`` to align ``target_spec``. |
| """ |
|
|
| @abc.abstractmethod |
| def adapt( |
| self, |
| target_spec: pytree.TreeSpec, |
| input_spec: pytree.TreeSpec, |
| input_args: list[Any], |
| metadata: Optional[dict[str, Any]] = None, |
| obj: Optional[Any] = None, |
| ) -> list[Any]: |
| """NOTE: This adapter may mutate given ``input_args_with_path``.""" |
| ... |
|
|
| def get_flat_arg_paths(self) -> list[str]: |
| """Returns a list of paths that are used to access the flat args.""" |
| return [] |
|
|
|
|
| class UnflattenedModule(torch.nn.Module): |
| def __init__( |
| self, |
| export_module: ExportedProgram, |
| flat_args_adapter: Optional[FlatArgsAdapter] = None, |
| ): |
| super().__init__() |
| if export_module.graph_signature.backward_signature is not None: |
| raise ValueError("Unflattening on JointExportModule NYI") |
|
|
| def _id(obj): |
| """Returns _TensorID dataclass for tensors, otherwise id().""" |
| if isinstance(obj, torch.Tensor): |
| return _TensorID( |
| untyped_storage=obj.untyped_storage(), |
| stride=obj.stride(), |
| size=obj.size(), |
| storage_offset=obj.storage_offset(), |
| ) |
| return id(obj) |
|
|
| fqn_list = [entry.fqn for entry in export_module.module_call_graph] |
| assert fqn_list[0] == "" |
| export_graph = deepcopy(export_module.graph) |
| self.graph_signature = deepcopy(export_module.graph_signature) |
| self.graph = torch.fx.Graph() |
| self.graph.owning_module = self |
| self.module_call_graph = deepcopy(export_module.module_call_graph) |
| self.flat_args_adapter = flat_args_adapter |
|
|
| self.meta = export_module.graph_module.meta |
| self.meta["unflattened_module"] = self |
|
|
| |
| self.adapted = False |
| self._run_with_interpreter = RUN_WITH_INTERPRETER |
|
|
| _inplace_buffer_and_input_mutations(export_graph, self.graph_signature) |
| _fix_nn_module_stacks(export_graph) |
|
|
| self.ivals = _IVals() |
| |
| seen_modules, seen_attrs = _outline_submodules(export_graph, self) |
| |
| |
| self.ivals.update(seen_modules.values()) |
| |
| |
| _copy_graph_attrs(export_module._graph_module, self, seen_attrs) |
|
|
| self.range_constraints = export_module.range_constraints |
| self.equality_constraints: list = [] |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| state_dict = export_module.state_dict |
| assigned_params: set[str] = set() |
| id_to_param: dict[ |
| Union[int, _TensorID], torch.nn.Parameter |
| ] = {} |
| for name in self.graph_signature.parameters: |
| param = state_dict[name] |
| if _id(param) not in id_to_param: |
| id_to_param[_id(param)] = torch.nn.Parameter( |
| param.clone(), requires_grad=param.requires_grad |
| ) |
|
|
| _assign_attr( |
| id_to_param[_id(param)], |
| self, |
| name, |
| attr_kind=_AttrKind.PARAMETER, |
| ) |
| assigned_params.add(name) |
|
|
| non_persistent_buffers = set(self.graph_signature.non_persistent_buffers) |
| assigned_buffers: set[str] = set() |
| id_to_buffer: dict[Union[int, _TensorID], tuple[torch.nn.Parameter, bool]] = {} |
| for name in self.graph_signature.buffers: |
| if name in non_persistent_buffers: |
| persistent = False |
| buffer = export_module.constants[name] |
| else: |
| persistent = True |
| buffer = state_dict[name] |
|
|
| if _id(buffer) not in id_to_buffer: |
| id_to_buffer[_id(buffer)] = (buffer.clone(), persistent) |
|
|
| _assign_attr( |
| id_to_buffer[_id(buffer)][0], |
| self, |
| name, |
| attr_kind=_AttrKind.BUFFER, |
| persistent=persistent, |
| ) |
| assigned_buffers.add(name) |
|
|
| |
| |
| for name, tensor in state_dict.items(): |
| if name in assigned_params or name in assigned_buffers: |
| continue |
|
|
| is_buffer = False |
| if _id(tensor) in id_to_buffer or not isinstance( |
| tensor, torch.nn.Parameter |
| ): |
| is_buffer = True |
|
|
| if is_buffer: |
| if ( |
| _id(tensor) not in id_to_buffer |
| ): |
| id_to_buffer[_id(tensor)] = ( |
| tensor, |
| True, |
| ) |
| _assign_attr( |
| id_to_buffer[_id(tensor)][0], |
| self, |
| name, |
| attr_kind=_AttrKind.BUFFER, |
| persistent=True, |
| ) |
| else: |
| if _id(tensor) not in id_to_param: |
| id_to_param[_id(tensor)] = tensor |
| _assign_attr( |
| id_to_param[_id(tensor)], |
| self, |
| name, |
| attr_kind=_AttrKind.PARAMETER, |
| ) |
|
|
| |
| id_to_const: dict[ |
| Union[int, _TensorID], Union[torch.Tensor, torch._C.ScriptObject] |
| ] = {} |
| for fqn, constant in export_module.constants.items(): |
| if _id(constant) not in id_to_const: |
| if isinstance(constant, torch.Tensor): |
| constant = constant.clone() |
| id_to_const[_id(constant)] = constant |
| _constant = id_to_const[_id(constant)] |
| _assign_attr( |
| _constant, |
| self, |
| fqn, |
| attr_kind=_AttrKind.CONSTANT, |
| ) |
|
|
| |
| |
| consts_map: dict[Union[int, _TensorID], list[tuple[str, str]]] = defaultdict( |
| list |
| ) |
| consts_targets: set[str] = set() |
|
|
| def add_to_consts_map(obj_id, node_name, target_name): |
| name_list = consts_map[obj_id] |
| name_list.append((node_name, target_name)) |
|
|
| |
| |
| added_params_buffers: set[str] = set() |
| for s in self.graph_signature.input_specs: |
| if s.kind == InputKind.PARAMETER or ( |
| s.kind == InputKind.BUFFER and s.persistent |
| ): |
| assert hasattr(s.arg, "name") |
| assert isinstance(s.target, str) |
| add_to_consts_map( |
| _id(export_module.state_dict[s.target]), |
| s.arg.name, |
| s.target, |
| ) |
| consts_targets.add(s.target) |
| added_params_buffers.add(s.target) |
| elif ( |
| s.kind == InputKind.BUFFER |
| and not s.persistent |
| or s.kind == InputKind.CONSTANT_TENSOR |
| or s.kind == InputKind.CUSTOM_OBJ |
| ): |
| assert hasattr(s.arg, "name") |
| assert isinstance(s.target, str) |
| add_to_consts_map( |
| _id(export_module.constants[s.target]), |
| s.arg.name, |
| s.target, |
| ) |
| consts_targets.add(s.target) |
|
|
| |
| for const_name, const in export_module.constants.items(): |
| if const_name not in consts_targets: |
| const_id = _id(const) |
| assert const_id in consts_map |
| ph_name, _ = consts_map[const_id][0] |
| add_to_consts_map(const_id, ph_name, const_name) |
| added_params_buffers.add(s.target) |
|
|
| |
| for fqn, tensor in export_module.state_dict.items(): |
| if fqn not in added_params_buffers: |
| tensor_id = _id(tensor) |
| if tensor_id not in consts_map: |
| |
| |
| |
| continue |
| ph_name, _ = consts_map[tensor_id][0] |
| add_to_consts_map(tensor_id, ph_name, fqn) |
|
|
| |
| inputs_to_state: dict[str, list[str]] = {} |
| for node_target in consts_map.values(): |
| targets = [t[1] for t in node_target] |
| for n, _ in node_target: |
| inputs_to_state[n] = targets |
|
|
| _sink_params(self, inputs_to_state, []) |
|
|
| redirected_call_indices = _deduplicate_modules(seen_modules.values()) |
| fqn_list = [fqn for fqn in fqn_list if fqn not in redirected_call_indices] |
|
|
| self._dispatch_modules(redirected_call_indices, consts_targets) |
| fqn_list = [fqn for fqn in fqn_list if "@" not in fqn] |
|
|
| |
| |
| |
| self.input_placeholders = [ |
| node for node in self.graph.nodes if node.op == "placeholder" |
| ] |
| self.check_input_constraints = True |
| |
| fqn_order = {fqn: i for i, fqn in enumerate(fqn_list)} |
| |
| for name, _ in self.named_modules(remove_duplicate=False): |
| if name not in fqn_order: |
| fqn_order[name] = len(fqn_order) |
| _reorder_submodules(self, fqn_order) |
| self.graph.lint() |
| self.finalize() |
|
|
| def _print_graph(self): |
| for fqn, mod in self.named_modules(): |
| print(fqn + ":") |
| if hasattr(mod, "graph") and isinstance(mod.graph, torch.fx.Graph): |
| print(mod.graph) |
|
|
| def _adapt_flat_args(self, flat_args, in_spec, input): |
| signature = self.module_call_graph[0].signature |
| if in_spec == signature.in_spec: |
| return flat_args |
|
|
| if self.flat_args_adapter is None: |
| raise TypeError( |
| "There is no flat args adapter specified. " |
| "Are you sure you are calling this with the right arguments? " |
| ) |
| else: |
| flat_args = self.flat_args_adapter.adapt( |
| target_spec=signature.in_spec, |
| input_spec=in_spec, |
| input_args=flat_args, |
| metadata=self.meta, |
| obj=input, |
| ) |
|
|
| if len(flat_args) != signature.in_spec.num_leaves: |
| raise TypeError( |
| f"Flat args adaption failed, number of args mismatch " |
| f"Adatped: {len(flat_args)} \n" |
| f"Exported module: {signature.in_spec.num_leaves}" |
| ) |
| return flat_args |
|
|
| def process_forward_inputs(self, *args, **kwargs): |
| signature = self.module_call_graph[0].signature |
|
|
| reordered_kwargs = kwargs |
| if kwargs: |
| reordered_kwargs = reorder_kwargs(kwargs, signature.in_spec) |
|
|
| flat_args_with_path, in_spec = pytree.tree_flatten_with_path( |
| (args, reordered_kwargs) |
| ) |
| flat_args = [x[1] for x in flat_args_with_path] |
|
|
| if is_fx_symbolic_tracing(): |
| return flat_args |
|
|
| if in_spec != signature.in_spec: |
| if not self.adapted: |
| print( |
| "Input treespec does not match with exported module's: \n" |
| f"Input treespec: {in_spec}. ", |
| f"Exported module treespec: {signature.in_spec}", |
| ) |
| print("Adapting flat arg to match exported module's treespec") |
| flat_args = self._adapt_flat_args(flat_args, in_spec, args) |
| self.adapted = True |
|
|
| if self.check_input_constraints: |
| |
| |
| from torch._export.utils import _check_input_constraints_for_graph |
|
|
| if self.adapted is True: |
| flat_arg_paths = ( |
| self.flat_args_adapter.get_flat_arg_paths() |
| if self.flat_args_adapter |
| else [] |
| ) |
| assert not flat_arg_paths or len(flat_arg_paths) == len(flat_args) |
| new_flat_args_with_path = [ |
| ( |
| ( |
| SequenceKey(idx=idx), |
| GetAttrKey( |
| name=flat_arg_paths[idx] |
| if flat_arg_paths |
| else "<unknown location>" |
| ), |
| ), |
| arg, |
| ) |
| for idx, arg in enumerate(flat_args) |
| ] |
| else: |
| new_flat_args_with_path = flat_args_with_path |
|
|
| _check_input_constraints_for_graph( |
| self.input_placeholders, new_flat_args_with_path, self.range_constraints |
| ) |
|
|
| return flat_args |
|
|
| def forward(self, *args, **kwargs): |
| flat_args = self.process_forward_inputs(*args, **kwargs) |
| signature = self.module_call_graph[0].signature |
|
|
| if is_fx_symbolic_tracing(): |
| return_val = torch.fx.Interpreter(self, graph=self.graph).run( |
| *flat_args, enable_io_processing=False |
| ) |
| |
| if isinstance(return_val, tuple) and len(return_val) == 1: |
| return return_val[0] |
| return return_val |
|
|
| if torch.compiler.is_dynamo_compiling() or not self._run_with_interpreter: |
| tree_out = type(self.graph_module).forward(self, *flat_args) |
| else: |
| tree_out = torch.fx.Interpreter(self, graph=self.graph).run( |
| *flat_args, enable_io_processing=False |
| ) |
| return pytree.tree_unflatten(tree_out, signature.out_spec) |
|
|
| def finalize(self): |
| self.__dict__["graph_module"] = torch.fx.GraphModule(self, self.graph) |
| self.graph.lint() |
|
|
| def _dispatch_modules(self, redirected_call_indices, consts_targets): |
| """For a module whose call signatures are preserved, replace |
| multiple modules corresponding to multiple calls to that module |
| with a single dispatcher module that tracks which module to call. |
| """ |
|
|
| |
| |
| called_modules = defaultdict(list) |
| for entry in self.module_call_graph: |
| if entry.fqn and entry.signature: |
| |
| |
| fqn = entry.fqn |
| mod = _get_attr(self, redirected_call_indices.get(fqn, fqn)) |
| base, idx = fqn.split("@") if "@" in fqn else [fqn, "0"] |
| called_modules[base].append((int(idx), mod)) |
|
|
| attrs_map = defaultdict(set) |
| for target in consts_targets: |
| if "." in target: |
| orig_fqn, name = target.rsplit(".", 1) |
| attrs_map[orig_fqn].add(name) |
| else: |
| attrs_map[""].add(target) |
|
|
| |
| for orig_fqn, indexed_call_modules in called_modules.items(): |
| call_modules = [mod for _, mod in sorted(indexed_call_modules)] |
| if len(call_modules) > 1: |
| for i in range(len(call_modules)): |
| fqn = _call_name(orig_fqn, i + 1) |
| if fqn not in redirected_call_indices: |
| *prefix, name = fqn.split(".") |
| _get_attr_via_attr_list(self, prefix)._modules.pop(name) |
| self.set_submodule( |
| orig_fqn, |
| InterpreterModuleDispatcher(attrs_map[orig_fqn], call_modules), |
| ) |
|
|
| |
| |
| def elide_call_indices(prefix, graph): |
| for node in graph.nodes: |
| if node.op == "call_module": |
| fqn = node.target.split("@")[0] |
| path = f"{prefix}.{fqn}" if prefix else fqn |
| if path in called_modules: |
| node.target = fqn |
|
|
| for fqn, mod in self.named_modules(remove_duplicate=False): |
| if hasattr(mod, "graph"): |
| elide_call_indices(fqn, mod.graph) |
| elif hasattr(mod, "_call_modules"): |
| for mod_ in mod._call_modules: |
| assert hasattr(mod_, "graph") |
| elide_call_indices(fqn, mod_.graph) |
|
|
| def print_readable( |
| self, |
| print_output=True, |
| include_stride=False, |
| include_device=False, |
| colored=False, |
| ): |
| return _print_readable( |
| self, |
| "UnflattenedModule", |
| print_output, |
| include_stride, |
| include_device, |
| colored, |
| ) |
|
|
|
|
| def unflatten( |
| module: ExportedProgram, flat_args_adapter: Optional[FlatArgsAdapter] = None |
| ) -> UnflattenedModule: |
| """Unflatten an ExportedProgram, producing a module with the same module |
| hierarchy as the original eager module. This can be useful if you are trying |
| to use :mod:`torch.export` with another system that expects a module |
| hierarchy instead of the flat graph that :mod:`torch.export` usually produces. |
| |
| .. note:: The args/kwargs of unflattened modules will not necessarily match |
| the eager module, so doing a module swap (e.g. :code:`self.submod = |
| new_mod`) will not necessarily work. If you need to swap a module out, you |
| need to set the :code:`preserve_module_call_signature` parameter of |
| :func:`torch.export.export`. |
| |
| Args: |
| module (ExportedProgram): The ExportedProgram to unflatten. |
| flat_args_adapter (Optional[FlatArgsAdapter]): Adapt flat args if input TreeSpec does not match with exported module's. |
| |
| Returns: |
| An instance of :class:`UnflattenedModule`, which has the same module |
| hierarchy as the original eager module pre-export. |
| """ |
| module = _remove_effect_tokens(module) |
| m = UnflattenedModule(module, flat_args_adapter) |
|
|
| |
| |
| m.process_forward_inputs = torch._dynamo.disable( |
| m.process_forward_inputs, |
| reason="do not trace into preprocessing the inputs", |
| recursive=True, |
| ) |
|
|
| return m |
|
|
|
|
| def _inplace_buffer_and_input_mutations( |
| graph: torch.fx.Graph, |
| graph_signature: ExportGraphSignature, |
| ) -> None: |
| """Transform buffer and input mutations from their functionalized form |
| into copy_ nodes in the graph. |
| |
| Functionalization represents a buffer mutation by passing the buffer as |
| an input and output. For example, consider the eager code: |
| def forward(self, x): |
| self.buffer += x |
| return x * x |
| |
| This corresponds to a graph that looks like: |
| def forward(self, buffer, x): |
| mutated_buffer = aten.add(buffer, x) |
| mul = aten.mul(x, x) |
| return (mutated_buffer, mul) |
| |
| We want to inplace this into something that looks like the original |
| eager code: |
| def forward(self, buffer, x): |
| mutated_buffer = aten.add(buffer, x) |
| buffer.copy_(mutated_buffer) |
| mul = aten.mul(x, x) |
| return (mul,) |
| |
| Input mutations are handled similarly. |
| """ |
| output_node = next(iter(reversed(graph.nodes))) |
| assert output_node.op == "output" and len(output_node.args) == 1 |
| return_args = output_node.args[0] |
|
|
| input_name_to_node = { |
| node.name: node for node in graph.nodes if node.op == "placeholder" |
| } |
| mutation_name_to_input_name = {} |
|
|
| |
| buffer_fqn_to_input_name = { |
| buffer_fqn: k for k, buffer_fqn in graph_signature.inputs_to_buffers.items() |
| } |
| mutation_name_to_input_name = { |
| k: buffer_fqn_to_input_name[buffer_fqn] |
| for k, buffer_fqn in graph_signature.buffers_to_mutate.items() |
| } |
| |
| mutation_name_to_input_name.update(graph_signature.user_inputs_to_mutate) |
|
|
| num_mutations = len(mutation_name_to_input_name) |
|
|
| for mutation in return_args[:num_mutations]: |
| input_name = mutation_name_to_input_name[mutation.name] |
| input_node = input_name_to_node[input_name] |
|
|
| with graph.inserting_after(mutation): |
| |
| new_node = graph.create_node( |
| "call_function", torch.ops.aten.copy_.default, (input_node, mutation) |
| ) |
| for k, v in mutation.meta.items(): |
| new_node.meta[k] = v |
| |
| |
| mutation.replace_all_uses_with(new_node, lambda x: x is not new_node) |
|
|
| |
| |
| user_outputs = tuple(return_args[num_mutations:]) |
| output_node.args = ((user_outputs),) |
|
|
|
|
| def _fix_nn_module_stacks(graph): |
| |
| |
| |
| for node in graph.nodes: |
| if "nn_module_stack" not in node.meta: |
| continue |
|
|
| nn_module_stack = node.meta["nn_module_stack"] |
| fqns = [ |
| fqn.split("@")[0] if "@" in fqn else fqn |
| for fqn, _t in nn_module_stack.values() |
| ] |
|
|
| |
| prev_fqn, *next_fqns = fqns |
| num_valid_indices = 1 |
| for curr_fqn in next_fqns: |
| |
| if _is_prefix(prev_fqn, curr_fqn): |
| num_valid_indices += 1 |
| prev_fqn = curr_fqn |
| else: |
| |
| break |
|
|
| |
| if num_valid_indices < len(nn_module_stack): |
| log.warning( |
| "nn_module_stack fqns %s at node %s do not form a stack! dropping last %d entries", |
| fqns, |
| node, |
| len(nn_module_stack) - num_valid_indices, |
| ) |
| node.meta["nn_module_stack"] = dict( |
| list(nn_module_stack.items())[:num_valid_indices] |
| ) |
|
|
|
|
| def _is_prefix(candidate, target): |
| """Check whether `candidate` is a prefix of `target`.""" |
| return len(candidate) < len(target) and target[: len(candidate)] == candidate |
|
|
|
|
| def _compute_accessor(parent_fqn: str, child_fqn: str) -> str: |
| if parent_fqn == "": |
| |
| return child_fqn |
|
|
| parent_split = parent_fqn.split(".") |
| child_split = child_fqn.split(".") |
|
|
| |
| if child_split[: len(parent_split)] != parent_split: |
| raise RuntimeError( |
| f"Child module '{child_fqn}' is not a descendant of parent module '{parent_fqn}'." |
| "This is currently unsupported." |
| "Please try to make child module attach to parent module directly." |
| ) |
| return ".".join(child_split[len(parent_split) :]) |
|
|
|
|
| def _check_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module): |
| def graph_dump(graph: torch.fx.Graph) -> str: |
| ret = [] |
| nodes_idx: dict[int, int] = {} |
|
|
| def arg_dump(arg) -> str: |
| if isinstance(arg, torch.fx.Node): |
| return "%" + str(nodes_idx[id(arg)]) |
| return str(arg) |
|
|
| for i, node in enumerate(graph.nodes): |
| args_dump = [str(arg) for arg in pytree.tree_map(arg_dump, node.args)] |
| args_dump += [ |
| f"{key}={value}" |
| for key, value in pytree.tree_map(arg_dump, node.kwargs).items() |
| ] |
| target = node.target if node.op in ("call_function", "get_attr") else "" |
| ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})") |
| nodes_idx[id(node)] = i |
| return "\n".join(ret) |
|
|
| assert isinstance(x.graph, torch.fx.Graph) |
| assert isinstance(y.graph, torch.fx.Graph) |
| return graph_dump(x.graph) == graph_dump(y.graph) |
|
|
|
|
| def _add_spec(gm: torch.nn.Module, spec) -> str: |
| i = 0 |
| while hasattr(gm, f"_spec_{i}"): |
| i += 1 |
| name = f"_spec_{i}" |
| setattr(gm, name, spec) |
| return name |
|
|
|
|
| def _generate_flatten(gm: torch.fx.GraphModule, node) -> torch.fx.Node: |
| flatten = gm.graph.call_function(pytree.tree_flatten, (node,)) |
| getitem_0 = gm.graph.call_function(operator.getitem, (flatten, 0)) |
| return getitem_0 |
|
|
|
|
| def _generate_flatten_spec( |
| gm: Union[torch.fx.GraphModule, InterpreterModule, UnflattenedModule], node, spec |
| ) -> torch.fx.Node: |
| name = _add_spec(gm, spec) |
| spec_node = gm.graph.get_attr(name) |
| return gm.graph.call_function(fx_pytree.tree_flatten_spec, (node, spec_node)) |
|
|
|
|
| def _generate_unflatten( |
| gm: Union[torch.fx.GraphModule, InterpreterModule, UnflattenedModule], nodes, spec |
| ) -> torch.fx.Node: |
| name = _add_spec(gm, spec) |
| spec_node = gm.graph.get_attr(name) |
| return gm.graph.call_function(pytree.tree_unflatten, (nodes, spec_node)) |
|
|
|
|
| def _get_submodule(mod: torch.nn.Module, target: str): |
| *prefix, field = target.split(".") |
|
|
| for item in prefix: |
| submod = getattr(mod, item, None) |
|
|
| if submod is None: |
| return None |
|
|
| if not isinstance(submod, torch.nn.Module): |
| return None |
|
|
| mod = submod |
|
|
| return getattr(mod, field, None) |
|
|
|
|
| def _add_submodule( |
| mod: torch.nn.Module, |
| target: str, |
| module_to_add: torch.nn.Module, |
| create_module: Optional[Callable[[str], torch.nn.Module]] = None, |
| ): |
| *prefix, field = target.split(".") |
|
|
| for i, item in enumerate(prefix): |
| submod = getattr(mod, item, None) |
|
|
| if submod is None: |
| if create_module is not None: |
| submod = create_module(".".join(prefix[: i + 1])) |
| else: |
| submod = torch.nn.Module() |
| setattr(mod, item, submod) |
|
|
| if not isinstance(submod, torch.nn.Module): |
| return False |
|
|
| mod = submod |
|
|
| mod.add_module(field, module_to_add) |
|
|
|
|
| def _call_name(base: str, n: int) -> str: |
| |
| |
| return base if n == 1 else f"{base}@{n - 1}" |
|
|
|
|
| def _is_call_name(call_name: str, base: str) -> bool: |
| |
| return re.match(re.escape(base) + r"(@\d+)?$", call_name) is not None |
|
|
|
|
| class _ModuleFrame: |
| def __init__( |
| self, |
| flat_graph: torch.fx.Graph, |
| nodes: tuple[torch.fx.Node, ...], |
| seen_nodes, |
| seen_modules, |
| seen_attrs, |
| created_modules, |
| parent, |
| module_stack: list[tuple[str, Optional[str], int]], |
| module_id, |
| module_call_graph: dict[str, ModuleCallSignature], |
| module: Optional[Union[torch.fx.GraphModule, UnflattenedModule]] = None, |
| ): |
| self.flat_graph = flat_graph |
| self.nodes = nodes |
| self.seen_nodes = seen_nodes |
| self.seen_modules = seen_modules |
| self.seen_attrs = seen_attrs |
| self.created_modules = created_modules |
| self.parent = parent |
| self.module_stack = module_stack |
| self.module_id = module_id |
|
|
| self.module_call_graph = module_call_graph |
| self.verbose = False |
|
|
| self.fqn, ty, num_calls = self.module_stack[-1] |
| |
| self.child_fqn = _call_name(self.fqn, num_calls + 1) |
|
|
| self.module: Union[torch.fx.GraphModule, UnflattenedModule, InterpreterModule] |
| if module is not None: |
| self.module = module |
| self.ivals = module.ivals if hasattr(module, "ivals") else {} |
| else: |
| self.module = self.created_modules.get( |
| self.fqn, |
| InterpreterModule(torch.fx.Graph(), ty=ty), |
| ) |
| self.ivals = parent.ivals |
|
|
| self.graph = self.module.graph |
|
|
| |
| self.node_map: dict[torch.fx.Node, torch.fx.Node] = {} |
| self.node_to_placeholder = {} |
|
|
| self.parent_call_module: Optional[torch.fx.Node] = None |
| if parent is not None: |
| accessor = _compute_accessor(parent.fqn, self.child_fqn) |
|
|
| def create_module(fqn): |
| path = f"{parent.fqn}.{fqn}" if parent.fqn else fqn |
| if path in self.created_modules: |
| return self.created_modules[path] |
| submod = InterpreterModule(torch.fx.Graph(), ty=ty) |
| self.created_modules[path] = submod |
| return submod |
|
|
| _add_submodule(parent.module, accessor, self.module, create_module) |
| self.parent_call_module = parent.graph.call_module(accessor) |
| if self.seen_modules[self.module_id]: |
| base_module_frame = self.seen_modules[self.module_id][0] |
| self.module._modules = base_module_frame.module._modules |
| self.seen_modules[self.module_id].append( |
| _SubmoduleEntry( |
| parent_fqn=self.parent.fqn, |
| parent_module=self.parent.module, |
| parent_call_module=self.parent_call_module, |
| fqn=self.fqn, |
| call_idx=num_calls + 1, |
| module=self.module, |
| ) |
| ) |
|
|
| signature = module_call_graph.get(self.child_fqn) |
| if signature is not None and self.parent is not None: |
| assert signature.in_spec.num_children == 2 |
| args_spec = signature.in_spec.children_specs[0] |
| kwargs_spec = signature.in_spec.children_specs[1] |
| assert args_spec.context is None |
| assert kwargs_spec.context is not None |
|
|
| with self.graph.inserting_after(None): |
| arg_nodes = [ |
| self.graph.placeholder(f"_positional_arg_{idx}") |
| for idx in range(args_spec.num_children) |
| ] |
| kwarg_nodes = {} |
| for name in kwargs_spec.context: |
| kwarg_nodes[name] = self.graph.placeholder(name) |
| flat_args = _generate_flatten_spec( |
| self.module, |
| (tuple(arg_nodes), kwarg_nodes), |
| signature.in_spec, |
| ) |
| for idx, arg in enumerate(signature.inputs): |
| flat_arg_node = self.graph.create_node( |
| op="call_function", |
| target=operator.getitem, |
| args=(flat_args, idx), |
| name=( |
| arg.name |
| if not isinstance(arg, ConstantArgument) |
| else f"_constant_{idx}" |
| ), |
| ) |
| if isinstance(arg, ConstantArgument): |
| continue |
|
|
| if arg.name in self.seen_nodes: |
| flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta) |
| self.node_to_placeholder[self.seen_nodes[arg.name]] = ( |
| flat_arg_node |
| ) |
|
|
| with self.parent.graph.inserting_before(self.parent_call_module): |
| input_nodes: list[Optional[torch.fx.Node]] = [] |
| for input in signature.inputs: |
| if isinstance(input, ConstantArgument): |
| input_nodes.append(input.value) |
| elif input.name not in self.seen_nodes: |
| input_nodes.append(None) |
| else: |
| assert isinstance( |
| input, |
| ( |
| TensorArgument, |
| SymIntArgument, |
| SymBoolArgument, |
| SymFloatArgument, |
| ), |
| ) |
| input_nodes.append( |
| self.parent.remap_input(self.seen_nodes[input.name]) |
| ) |
|
|
| inputs_node = _generate_unflatten( |
| self.parent.module, |
| input_nodes, |
| signature.in_spec, |
| ) |
|
|
| args_node = self.parent.graph.call_function( |
| operator.getitem, (inputs_node, 0) |
| ) |
| kwargs_node = self.parent.graph.call_function( |
| operator.getitem, (inputs_node, 1) |
| ) |
| arg_nodes = [ |
| self.parent.graph.call_function(operator.getitem, (args_node, i)) |
| for i in range(args_spec.num_children) |
| ] |
| kwarg_nodes = { |
| k: self.parent.graph.call_function( |
| operator.getitem, (kwargs_node, k) |
| ) |
| for k in kwargs_spec.context |
| } |
| assert self.parent_call_module is not None |
| self.parent_call_module.args = tuple(arg_nodes) |
| self.parent_call_module.kwargs = kwarg_nodes |
|
|
| def add_placeholder(self, x): |
| assert self.fqn != "", f"Cannot add placeholder {x} to root module" |
| assert x.graph is self.flat_graph |
| |
| with self.graph.inserting_before(None): |
| placeholder_node = self.graph.placeholder(x.name, type_expr=x.type) |
| |
| |
| placeholder_node.meta = copy.copy(x.meta) |
| self.node_to_placeholder[x] = placeholder_node |
|
|
| def copy_sym_call_function(self, x): |
| |
| |
| |
| |
| |
| |
| args = pytree.tree_map_only(torch.fx.Node, self.remap_input, x.args) |
| kwargs = pytree.tree_map_only(torch.fx.Node, self.remap_input, x.kwargs) |
| node = self.graph.call_function(x.target, args, kwargs) |
| node.meta = copy.copy(x.meta) |
| self.node_map[x] = node |
| return node |
|
|
| def remap_input(self, x): |
| assert x.graph is self.flat_graph |
| if x in self.node_map: |
| return self.node_map[x] |
| self.print(f"remap_input({x})") |
| if x in self.node_to_placeholder: |
| return self.node_to_placeholder[x] |
| elif ( |
| x.op == "placeholder" or self.module_call_graph.get(self.fqn) is None |
| |
| ): |
| self.add_placeholder(x) |
| if self.parent_call_module is not None: |
| |
| |
| with self.parent.graph.inserting_before(self.parent_call_module): |
| self.parent_call_module.insert_arg(0, self.parent.remap_input(x)) |
| return self.node_to_placeholder[x] |
| elif x.op == "call_function" and ( |
| x.target |
| in ( |
| torch.ops.aten.sym_size.int, |
| torch.ops.aten.item.default, |
| torch.ops.aten.unbind.int, |
| torch.ops.aten.sum.dim_IntList, |
| torch.ops.aten.view.default, |
| torch.ops.aten.diff.default, |
| ) |
| or (hasattr(x.target, "__module__") and x.target.__module__ == "_operator") |
| ): |
| |
| |
| self.copy_sym_call_function(x) |
| return self.node_map[x] |
| elif self.module_call_graph.get(self.fqn) is not None: |
| |
| |
| return self.ivals.read(self, x) |
| else: |
| raise RuntimeError( |
| f"Could not run remap_input() on op type: {x.op} for node {x}" |
| ) |
|
|
| def finalize_outputs(self): |
| self.created_modules.pop(self.fqn, None) |
|
|
| orig_outputs = [] |
|
|
| signature = self.module_call_graph.get(self.child_fqn) |
| if signature is not None and self.parent is not None: |
| for output in signature.outputs: |
| if isinstance( |
| output, |
| ( |
| TensorArgument, |
| SymIntArgument, |
| SymBoolArgument, |
| SymFloatArgument, |
| ConstantArgument, |
| ), |
| ): |
| if output.name in self.seen_nodes: |
| orig_outputs.append(self.seen_nodes[output.name]) |
| else: |
| orig_outputs.append(None) |
| else: |
| raise RuntimeError( |
| f"Unsupported data type for output node: {output}" |
| ) |
|
|
| def get_actual_output_node(output): |
| if output is None: |
| return None |
|
|
| seen_node = self.seen_nodes[output.name] |
| if seen_node in self.node_map: |
| return self.node_map[seen_node] |
| elif seen_node in self.node_to_placeholder: |
| return self.node_to_placeholder[seen_node] |
| else: |
| raise RuntimeError( |
| f"Could not find output node {output}. Graph: {self.graph}" |
| ) |
|
|
| tree_out_node = _generate_unflatten( |
| self.module, |
| tuple(get_actual_output_node(output) for output in orig_outputs), |
| signature.out_spec, |
| ) |
| parent_out: Optional[torch.fx.Node] = _generate_flatten_spec( |
| self.parent.module, self.parent_call_module, signature.out_spec |
| ) |
| graph_outputs: Union[torch.fx.Node, list[torch.fx.Node]] = tree_out_node |
| else: |
| graph_outputs = [] |
| |
| for orig_node in self.node_map.keys(): |
| for user_node in orig_node.users: |
| if user_node.name not in self.seen_nodes: |
| |
| orig_outputs.append(orig_node) |
| graph_outputs.append(self.node_map[orig_node]) |
| break |
|
|
| parent_out = self.parent_call_module |
| if len(graph_outputs) == 1: |
| graph_outputs = graph_outputs[0] |
|
|
| assert isinstance(graph_outputs, (list, torch.fx.Node)) |
|
|
| self.graph.output(graph_outputs) |
|
|
| |
| if parent_out is None: |
| return |
|
|
| parent_out.meta["val"] = ( |
| graph_outputs.meta.get("val") |
| if isinstance(graph_outputs, torch.fx.Node) |
| else [o.meta.get("val") for o in graph_outputs] |
| ) |
|
|
| if len(orig_outputs) == 1 and signature is None: |
| self.parent.node_map[orig_outputs[0]] = parent_out |
| else: |
| for i, orig_output in enumerate(orig_outputs): |
| if orig_output is None: |
| continue |
| |
| proxy_out = torch.fx.Proxy(parent_out)[i].node |
| proxy_out.meta["val"] = orig_output.meta.get("val") |
| self.parent.node_map[orig_output] = proxy_out |
|
|
| def copy_node(self, node): |
| self.print("copying", node.format_node()) |
| self.node_map[node] = self.graph.node_copy(node, self.remap_input) |
| self.seen_nodes[node.name] = node |
|
|
| def run_outer(self): |
| for i, node in enumerate(self.flat_graph.nodes): |
| self.print(i, node.meta.get("nn_module_stack"), node.format_node()) |
|
|
| |
| node_idx: int = 0 |
| node = self.nodes[node_idx] |
| while node.op == "placeholder": |
| self.copy_node(node) |
| node_idx += 1 |
| node = self.nodes[node_idx] |
|
|
| self.run_from(node_idx) |
|
|
| |
| for node in self.flat_graph.nodes: |
| if node.op == "output": |
| self.copy_node(node) |
|
|
| def print(self, *args, **kwargs): |
| if self.verbose: |
| print(*args, **kwargs) |
|
|
| def run_from(self, node_idx): |
| module_idx = 0 |
| |
| while node_idx < len(self.nodes): |
| node = self.nodes[node_idx] |
| assert node.op != "placeholder" |
|
|
| self.print() |
| self.print("STEP", node_idx, node.format_node()) |
| self.print(self.module_stack) |
| depth = len(self.module_stack) |
| if node.op == "output": |
| if depth == 1: |
| |
| |
| |
| return node_idx |
|
|
| |
| self.finalize_outputs() |
| return node_idx |
|
|
| if len(node.meta.get("nn_module_stack", {})) == 0: |
| raise RuntimeError(f"Unable to find nn_module_stack for node {node}") |
|
|
| nn_module_stack = node.meta["nn_module_stack"] |
| from torch._export.passes._node_metadata_hook import ( |
| _EMPTY_NN_MODULE_STACK_KEY, |
| ) |
|
|
| if ( |
| len(nn_module_stack) == 1 |
| and _EMPTY_NN_MODULE_STACK_KEY in nn_module_stack |
| ): |
| |
| node_module_stack = self.module_stack |
| else: |
| node_module_stack = [ |
| ( |
| path, |
| ty if path else None, |
| int(k.split("@")[-1]) if "@" in k else 0, |
| ) |
| for k, (path, ty) in node.meta["nn_module_stack"].items() |
| ] |
|
|
| if node_module_stack[:depth] != self.module_stack: |
| |
| |
| |
| |
| |
| self.finalize_outputs() |
| self.print("outlining", self.fqn) |
| self.print(self.graph) |
| return node_idx |
|
|
| assert node_module_stack is not None |
|
|
| if _is_prefix(self.module_stack, node_module_stack): |
| |
| |
| next_module = node_module_stack[depth] |
| self.print("Creating new stack frame for", next_module) |
| |
| |
| next_module_key = list(node.meta["nn_module_stack"].keys())[depth] |
| node_idx = _ModuleFrame( |
| self.flat_graph, |
| self.nodes, |
| self.seen_nodes, |
| self.seen_modules, |
| self.seen_attrs, |
| self.created_modules, |
| self, |
| self.module_stack + [next_module], |
| next_module_key.split("@")[0], |
| self.module_call_graph, |
| ).run_from(node_idx) |
| module_idx += 1 |
| continue |
|
|
| |
| |
| assert node_module_stack == self.module_stack |
|
|
| if node.op == "get_attr": |
| |
| self.seen_attrs[self.child_fqn].add(node.target) |
|
|
| self.copy_node(node) |
| node_idx += 1 |
|
|
|
|
| @dataclass |
| class _SubmoduleEntry: |
| parent_fqn: str |
| parent_module: torch.nn.Module |
| parent_call_module: torch.fx.Node |
| fqn: str |
| call_idx: int |
| module: torch.nn.Module |
|
|
|
|
| def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModule): |
| seen_nodes: dict[str, torch.fx.Node] = {} |
| seen_modules: dict[int, list[_SubmoduleEntry]] = defaultdict(list) |
| seen_attrs: dict[str, set[str]] = defaultdict(set) |
| created_modules: dict[str, torch.nn.Module] = {} |
| _ModuleFrame( |
| orig_graph, |
| tuple(orig_graph.nodes), |
| seen_nodes, |
| seen_modules, |
| seen_attrs, |
| created_modules, |
| None, |
| [("", None, 0)], |
| "", |
| { |
| entry.fqn: entry.signature |
| for entry in root_module.module_call_graph |
| if entry.signature |
| }, |
| module=root_module, |
| ).run_outer() |
| return seen_modules, seen_attrs |
|
|
|
|
| def _reorder_submodules( |
| parent: torch.nn.Module, fqn_order: dict[str, int], prefix: str = "" |
| ): |
| |
| if prefix == "": |
| for fqn in list(fqn_order.keys())[1:]: |
| if _get_submodule(parent, fqn) is None: |
| _add_submodule(parent, fqn, torch.nn.Module()) |
|
|
| children = [] |
| for name, child in list(parent._modules.items()): |
| if child is None: |
| continue |
| fqn = prefix + name |
| _reorder_submodules(child, fqn_order, prefix=fqn.split("@")[0] + ".") |
| delattr(parent, name) |
| children.append((fqn_order[fqn], name, child)) |
| children.sort(key=operator.itemgetter(0)) |
| for _, name, child in children: |
| parent.register_module(name, child) |
|
|
|
|
| class _IVals: |
| """ |
| Collect the intermediate values of mutations in a graph. |
| |
| Example: in the following graph, suppose that buf_in and buf_out |
| are the input and output values of a buffer. |
| |
| buf_in = placeholder() |
| ... |
| ival1 = f0(buf_in, ...) # inside self.n0(...) |
| ... |
| ival2 = f1(ival1, ...) # inside self.n1(...) |
| ... |
| buf_out = f2(ival2, ...) # inside self.n2(...) |
| return buf_out, ... |
| |
| Here ival1 and ival2 are intermediate values created inside |
| calls to n0 and n1 respectively, and used inside calls to |
| n1 and n2 respectively. |
| """ |
|
|
| def __init__(self): |
| |
| self.node_names_by_fqn = defaultdict(set) |
|
|
| def _is_mutable(self, target): |
| if isinstance(target, torch._ops.OpOverload): |
| return target._schema.is_mutable |
| return False |
|
|
| def read(self, mf, node): |
| """ |
| Read state corresponding to a given intermediate value. |
| """ |
| |
| assert node.op == "call_function" |
| b = self._is_mutable(node.target) |
| print("Checking mutability", node.target, b) |
| if not b: |
| |
| |
| fqn, _ = next(reversed(node.meta["nn_module_stack"].values())) |
| self.node_names_by_fqn[fqn].add(node.name) |
| return mf.remap_input(node.args[0]) |
|
|
| def update(self, partitions): |
| """ |
| Update states corresponding to intermediate values that were read. |
| """ |
| for shared_submodules in partitions: |
| for entry in shared_submodules: |
| graph = entry.module.graph |
| node_names = self.node_names_by_fqn[entry.fqn] |
| nodes = [n for n in graph.nodes if n.name in node_names] |
| for node in nodes: |
| |
| |
| with graph.inserting_after(node): |
| new_node = graph.create_node( |
| "call_function", |
| torch.ops.aten.copy_.default, |
| (node.args[0], node), |
| ) |
| new_node.meta = copy.copy(node.meta) |
|
|
|
|
| def _copy_graph_attrs( |
| gm: torch.fx.GraphModule, |
| root_module: UnflattenedModule, |
| seen_attrs: dict[str, set[str]], |
| ): |
| for child_fqn, names in seen_attrs.items(): |
| module = _get_attr(root_module, child_fqn) if child_fqn else root_module |
| for name in names: |
| val = getattr(gm, name) |
| setattr(module, name, val) |
|
|
|
|
| def _deduplicate_modules(partitions): |
| redirected_call_indices = {} |
| for shared_submodules in partitions: |
| for i, entry in enumerate(shared_submodules): |
| child_fqn = _call_name(entry.fqn, entry.call_idx) |
| target = _compute_accessor(entry.parent_fqn, child_fqn) |
| deduplicated = False |
| |
| for seen in shared_submodules[:i]: |
| if _check_graph_equivalence(seen.module, entry.module): |
| parent = entry.parent_module |
| |
| |
| if seen.fqn == entry.fqn: |
| |
| |
| |
| |
| *prefix, name = target.split(".") |
| _get_attr_via_attr_list(parent, prefix)._modules.pop(name) |
| seen_child_fqn = _call_name(seen.fqn, seen.call_idx) |
| seen_target = _compute_accessor( |
| entry.parent_fqn, seen_child_fqn |
| ) |
| entry.parent_call_module.target = seen_target |
| redirected_call_indices[child_fqn] = seen_child_fqn |
| break |
| elif not deduplicated: |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| parent.set_submodule(target, seen.module) |
| deduplicated = True |
|
|
| return redirected_call_indices |
|
|
|
|
| def _sink_params( |
| module: torch.nn.Module, |
| inputs_to_state: dict[str, list[str]], |
| scope: list[str], |
| module_id_to_inputs_removed: Optional[dict[int, set[str]]] = None, |
| ): |
| """Sink params, buffers, and constants from graph inputs into get_attr nodes. |
| |
| Exported modules are purely functional, so they pass their parameters and |
| buffers in as inputs to the graph. |
| |
| To replicate eager's semantics, we need to get them from the module state |
| via get_attr instead. |
| |
| module: GraphModule, potentially containing nested submodules. |
| inputs_to_state: mapping graph input names to the corresponding key in the state_dict. |
| scope: tracks where we are in the module hierarchy, so that we can emit the |
| right `getattr(self, "foo.bar")` calls, etc. |
| module_id_to_inputs_removed: records inputs removed by child modules, mapping |
| the module object id to the list of placeholder node names in the child module |
| that were removed. |
| """ |
| if module_id_to_inputs_removed is None: |
| module_id_to_inputs_removed = defaultdict(set) |
|
|
| if id(module) in module_id_to_inputs_removed: |
| return {id(module): module_id_to_inputs_removed[id(module)]} |
|
|
| |
| |
| for name, submodule in module._modules.items(): |
| submod_id_to_inputs_removed = _sink_params( |
| cast("torch.nn.Module", submodule), |
| inputs_to_state, |
| scope + [name], |
| module_id_to_inputs_removed, |
| ) |
| for k, v in submod_id_to_inputs_removed.items(): |
| module_id_to_inputs_removed[k].update(v) |
|
|
| graph = getattr(module, "graph", None) |
| if graph is None or len(graph.nodes) == 0: |
| |
| return module_id_to_inputs_removed |
|
|
| assert isinstance(graph, torch.fx.Graph) |
|
|
| inputs = list(filter(lambda n: n.op == "placeholder", graph.nodes)) |
| the_last_input = None if len(inputs) == 0 else inputs[-1] |
|
|
| |
| call_module_nodes = filter(lambda n: n.op == "call_module", graph.nodes) |
| for node in call_module_nodes: |
| submodule = _get_attr(module, node.target) |
| |
| |
| if submodule is not None and id(submodule) in module_id_to_inputs_removed: |
| node.args = tuple( |
| filter( |
| lambda n: n.name not in module_id_to_inputs_removed[id(submodule)], |
| node.args, |
| ) |
| ) |
|
|
| |
| inputs_to_state_of_scope: dict[torch.fx.Node, list[str]] = {} |
| for node in inputs: |
| if node.name not in inputs_to_state: |
| continue |
|
|
| state_name = None |
| for sn in inputs_to_state[node.name]: |
| sn_split = sn.split(".") |
| if sn_split[: len(scope)] == [x.split("@")[0] for x in scope]: |
| state_name = sn_split |
| break |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| if state_name is None: |
| continue |
|
|
| inputs_to_state_of_scope[node] = state_name |
|
|
| |
| inputs_removed: set[str] = set() |
|
|
| for node, state_name in inputs_to_state_of_scope.items(): |
| if len(node.users) > 0: |
| attr_path = state_name[len(scope) :] |
| state_attr = _get_attr_via_attr_list(module, attr_path) |
| assert isinstance(state_attr, (torch.Tensor, torch.ScriptObject)) |
|
|
| |
| with graph.inserting_after(the_last_input): |
| new_node = graph.create_node("get_attr", ".".join(attr_path)) |
|
|
| node.replace_all_uses_with(new_node, propagate_meta=True) |
|
|
| graph.erase_node(node) |
| inputs_removed.add(node.name) |
|
|
| if isinstance(module, InterpreterModule): |
| module.finalize() |
|
|
| return {id(module): inputs_removed} |
|
|