|
|
import logging |
|
|
import operator |
|
|
import types |
|
|
from collections import defaultdict |
|
|
from typing import Optional |
|
|
|
|
|
import torch |
|
|
import torch.fx._pytree as fx_pytree |
|
|
import torch.utils._pytree as pytree |
|
|
from torch.export.exported_program import ( |
|
|
ConstantArgument, |
|
|
ExportedProgram, |
|
|
ModuleCallSignature, |
|
|
) |
|
|
from torch.fx.passes.tools_common import legalize_graph, NodeList |
|
|
from torch.fx.passes.utils.fuser_utils import erase_nodes, fuse_as_graphmodule |
|
|
|
|
|
|
|
|
log = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
def _get_getitem_users(node: torch.fx.Node) -> set[torch.fx.Node]: |
|
|
node_users = list(node.users.keys()) |
|
|
getitem_users = set() |
|
|
for user in node_users: |
|
|
if user.op == "output": |
|
|
continue |
|
|
|
|
|
assert user.op == "call_function" and user.target == operator.getitem, ( |
|
|
f"Expected getitem node as user for {node}, instead got {user}" |
|
|
) |
|
|
getitem_users.update(list(user.users.keys())) |
|
|
return getitem_users |
|
|
|
|
|
|
|
|
def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None:
    """
    Try to bypass the flatten/unflatten pair sitting between two module calls.

    Given a graph shaped like this:

        %foo = call_module[target=foo](args = (%getitem_1, %getitem_2))
        %tree_flatten_spec = call_function[target=torch.fx._pytree.tree_flatten_spec](args = (%foo, %_spec_1))
        %getitem_4 = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0))
        %tree_unflatten_1 = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%getitem_4], %_spec_2))
        %getitem_5 = call_function[target=operator.getitem](args = (%tree_unflatten_1, 0))
        %getitem_7 = call_function[target=operator.getitem](args = (%tree_unflatten_1, 1))
        %getitem_6 = call_function[target=operator.getitem](args = (%getitem_5, 0))
        %bar = call_module[target=bar](args = (%getitem_6,))

    and provided every output of `foo` feeds `bar`, `bar` is rewired to take
    `foo` directly:

        %foo = call_module[target=foo](args = (%getitem_1, %getitem_2))
        %bar = call_module[target=bar](args = (%foo,))

    This function only rewrites the args of the next module node; the
    now-unused pytree nodes are left in the graph for the caller to remove.

    The rewrite is applied only when all outputs of `foo` go straight into
    `bar`, in order, and `bar` has no other inputs. In every other case the
    reason is logged at debug level and the graph is left untouched.

    Args:
        curr_module_node: The ``call_module`` node playing the role of `foo`.

    Raises:
        AssertionError: If ``curr_module_node`` does not have exactly one
            user, or that user is not a ``tree_flatten_spec`` call.
    """

    log.debug("Trying to remove pytrees for module call %s", curr_module_node)

    module_users = list(curr_module_node.users)
    assert len(module_users) == 1, (
        f"Expected only one user for module node, instead got {list(module_users)}"
    )
    (flatten_node,) = module_users
    assert (
        flatten_node.op == "call_function"
        and flatten_node.target == fx_pytree.tree_flatten_spec
    )

    # Everything that consumes the flattened outputs (through getitem) must
    # collapse into a single node for the fusion to be possible.
    flatten_consumers = _get_getitem_users(flatten_node)
    if len(flatten_consumers) != 1:
        log.debug(
            "More than one user found for flatten node, %s: %s. "
            "Unable to fuse it with another unflatten call.",
            flatten_node,
            flatten_consumers,
        )
        return

    (unflatten_node,) = flatten_consumers
    if (
        unflatten_node.op != "call_function"
        or unflatten_node.target != pytree.tree_unflatten
    ):
        log.debug(
            "Flatten node %s's user is not a pytree.tree_unflatten. "
            "Instead it is: %s. Passing...",
            flatten_node,
            unflatten_node,
        )
        return

    # The unflatten inputs must be exactly the flatten outputs, in order.
    unflatten_inputs = unflatten_node.args[0]
    for position, leaf in enumerate(unflatten_inputs):
        if leaf not in flatten_node.users:
            log.debug(
                "Module %s's outputs are not all directly used as inputs to "
                "the subsequent module. Unable to fuse the connecting "
                "flatten/unflatten. The inputs to the subsequent module are: %s. ",
                curr_module_node,
                unflatten_node.args[0],
            )
            return

        if (
            leaf.op != "call_function"
            or leaf.target != operator.getitem
            or leaf.args[1] != position
        ):
            log.debug(
                "Module %s's outputs are not all directly used in the same "
                "order as outputted. Unable to fuse the connecting "
                "flatten/unflatten. The inputs to the "
                "subsequent module are: %s. ",
                curr_module_node,
                unflatten_node.args[0],
            )
            return

    # The unflatten result is read through two getitem levels (see the
    # %getitem_5 -> %getitem_6 chain above), so look two hops down.
    second_level_users = {
        consumer
        for first_level_user in _get_getitem_users(unflatten_node)
        for consumer in first_level_user.users
    }

    if len(second_level_users) != 1:
        log.debug(
            "More than one user found for unflatten node, %s: %s. "
            "Unable to fuse it with another flatten call.",
            unflatten_node,
            second_level_users,
        )
        return

    (next_module_node,) = second_level_users
    if next_module_node.op != "call_module":
        log.debug(
            "Unflatten node %s's user is not a call_module. "
            "Instead it is: %s. Passing...",
            unflatten_node,
            next_module_node,
        )
        return

    # Feed the current module's output straight into the next module.
    next_module_node.args = (curr_module_node,)
|
|
|
|
|
|
|
|
def _remove_extraneous_pytrees(gm: torch.fx.GraphModule) -> None:
    """
    Strip pytree flatten/unflatten calls that are not needed.

    Planned optimizations:
        1. Remove pytree flatten/unflatten calls between modules (done here)
        2. TODO: Remove module's in_spec + initial unflatten call
        3. TODO: Remove module's out_spec + final flatten call

    Args:
        gm: Graph module whose graph is rewritten in place. Dead code is
            eliminated from the graph afterwards; the module is not recompiled
            by this function.
    """

    for candidate in gm.graph.nodes:
        # Only real submodule calls are considered; "_guards_fn" is excluded.
        if candidate.op != "call_module" or candidate.target == "_guards_fn":
            continue
        _try_remove_connecting_pytrees(candidate)

    # Drop whatever nodes became unused after the rewiring above.
    gm.graph.eliminate_dead_code()
|
|
|
|
|
|
|
|
def _construct_inputs(
    gm: torch.fx.GraphModule,
    signature: ModuleCallSignature,
    node_name_map: dict[str, torch.fx.Node],
) -> tuple[list[torch.fx.Node], dict[str, torch.fx.Node]]:
    """
    Emit the nodes that rebuild a module's (args, kwargs) from flat graph nodes.

    The flat inputs listed in ``signature.inputs`` are looked up by name in
    ``node_name_map``, passed through an unflatten call built from
    ``signature.in_spec``, and then picked apart with ``operator.getitem``
    nodes: index 0 of the unflatten result holds the positional arguments,
    index 1 the keyword arguments.

    Args:
        gm: Graph module whose graph receives the new nodes.
        signature: Call signature of the module being swapped in.
        node_name_map: Maps node names to the nodes currently carrying them.

    Returns:
        A pair ``(args_nodes, kwargs_nodes)``: one getitem node per positional
        argument, and a dict from keyword name to its getitem node.

    Raises:
        AssertionError: If ``in_spec`` does not have exactly two children, if
            the args spec has a context, or if the kwargs spec has none.
    """

    def _flat_input_for(input_spec) -> Optional[torch.fx.Node]:
        # A None constant has no node behind it; an input whose name is
        # missing from the map also contributes None.
        if isinstance(input_spec, ConstantArgument) and input_spec.value is None:
            return None
        return node_name_map.get(input_spec.name)

    tree_unflatten_args: list[Optional[torch.fx.Node]] = [
        _flat_input_for(input_spec) for input_spec in signature.inputs
    ]

    # Imported locally, as in the rest of this file.
    from .unflatten import _generate_unflatten

    unflatten_node = _generate_unflatten(gm, tree_unflatten_args, signature.in_spec)

    def _emit_getitem(container: torch.fx.Node, key) -> torch.fx.Node:
        return gm.graph.call_function(operator.getitem, (container, key))

    in_spec = signature.in_spec
    assert in_spec.num_children == 2

    # Positional arguments live at index 0 of the unflattened structure.
    positional_spec = in_spec.children_specs[0]
    assert positional_spec.context is None
    positional_container = _emit_getitem(unflatten_node, 0)
    args_nodes = [
        _emit_getitem(positional_container, index)
        for index in range(positional_spec.num_children)
    ]

    # Keyword arguments live at index 1; the spec context supplies the keys.
    keyword_spec = in_spec.children_specs[1]
    assert keyword_spec.context is not None
    keyword_container = _emit_getitem(unflatten_node, 1)
    kwargs_nodes = {
        key: _emit_getitem(keyword_container, key) for key in keyword_spec.context
    }
    return args_nodes, kwargs_nodes
|
|
|
|
|
|
|
|
def _insert_call_module(
    gm: torch.fx.GraphModule,
    args_nodes: list[torch.fx.Node],
    kwargs_nodes: dict[str, torch.fx.Node],
    module_to_swap: torch.nn.Module,
    name: str,
) -> torch.fx.Node:
    """
    Add a ``call_module`` node for ``module_to_swap`` to ``gm``'s graph.

    ``_assign_attr`` is invoked with ``_AttrKind.MODULE`` first (presumably
    attaching ``module_to_swap`` to ``gm`` under ``name`` -- confirm against
    ``unflatten._assign_attr``), then the call node is created.

    Args:
        gm: Graph module receiving the module and the new node.
        args_nodes: Nodes passed as positional arguments to the call.
        kwargs_nodes: Nodes passed as keyword arguments to the call.
        module_to_swap: The module the new node should invoke.
        name: Target of the ``call_module`` node.

    Returns:
        The newly created ``call_module`` node.
    """
    from .unflatten import _assign_attr, _AttrKind

    _assign_attr(module_to_swap, gm, name, _AttrKind.MODULE)
    return gm.graph.call_module(name, args=tuple(args_nodes), kwargs=kwargs_nodes)
|
|
|
|
|
|
|
|
def _deconstruct_outputs(
    gm: torch.fx.GraphModule,
    signature: ModuleCallSignature,
    module_node: torch.fx.Node,
    node_name_map: dict[str, torch.fx.Node],
    orig_outputs: tuple[torch.fx.Node, ...],
) -> None:
    """
    Route the swapped module's outputs to the users of the original outputs.

    A flatten-spec node is generated for ``module_node`` using
    ``signature.out_spec``. The i-th original output is then replaced
    everywhere by item ``i`` of that flatten node, and ``node_name_map`` is
    updated so the original output's name points at the replacement.

    Args:
        gm: Graph module being rewritten.
        signature: Call signature providing the output spec.
        module_node: The ``call_module`` node of the swapped module.
        node_name_map: Name-to-node map, mutated in place.
        orig_outputs: Output nodes of the subgraph that is being replaced.
    """
    from .unflatten import _generate_flatten_spec

    flatten_node = _generate_flatten_spec(gm, module_node, signature.out_spec)
    # Indexing a Proxy records the getitem access as a new graph node.
    flat_proxy = torch.fx.Proxy(flatten_node)

    for index, stale_output in enumerate(orig_outputs):
        replacement = flat_proxy[index].node
        stale_output.replace_all_uses_with(replacement, propagate_meta=True)

        # Later lookups by the old name must resolve to the new node.
        node_name_map[stale_output.name] = replacement
|
|
|
|
|
|
|
|
def _swap_module_helper(
    gm: torch.fx.GraphModule,
    modules_to_swap: dict[str, torch.nn.Module],
    module_call_graph: dict[str, ModuleCallSignature],
) -> torch.fx.GraphModule:
    """
    Replace the traced subgraphs of the given modules with calls to new modules.

    Nodes are grouped by the first entry of their ``nn_module_stack`` metadata
    whose path is a key of ``modules_to_swap``. Each group is fused out of the
    graph and replaced by an unflatten + ``call_module`` + flatten sequence.
    Afterwards the graph is legalized, unnecessary pytree calls are removed,
    and the module is recompiled.

    Args:
        gm: Graph module to rewrite in place.
        modules_to_swap: Maps module path to the module to call instead.
        module_call_graph: Maps module path to its call signature.

    Returns:
        The same ``gm``, after rewriting and recompilation.
    """
    log.debug("Starting graph:")
    log.debug(gm.graph)

    legalize_graph(gm)

    partitions: dict[str, NodeList] = defaultdict(list)
    node_name_map: dict[str, torch.fx.Node] = {}

    for graph_node in gm.graph.nodes:
        node_name_map[graph_node.name] = graph_node

        module_stack = graph_node.meta.get("nn_module_stack")
        if not module_stack:
            continue
        # A node belongs to the first module on its stack that is being swapped.
        owner = next(
            (path for path, _ in module_stack.values() if path in modules_to_swap),
            None,
        )
        if owner is not None:
            partitions[owner].append(graph_node)

    for name, nodes in partitions.items():
        # Example: swapping out the submodule "foo" in
        #
        #   %x = placeholder[target=x]
        #   %y = placeholder[target=y]
        #   %add = call_function[target=torch.ops.aten.add.Tensor](args = (%y, %x))   # nn_module_stack: foo
        #   %sub = call_function[target=torch.ops.aten.sub.Tensor](args = (%y, %add)) # nn_module_stack: bar
        #   return (sub,)
        #
        # First foo's subgraph (here just %add) is partitioned out. It is then
        # replaced by an unflatten + call_module + flatten sequence:
        #
        #   %_spec_0 = get_attr[target=_spec_0]
        #   %tree_unflatten = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%x, %y], %_spec_0))
        #   %getitem = call_function[target=operator.getitem](args = (%tree_unflatten, 0))
        #   %getitem_1 = call_function[target=operator.getitem](args = (%getitem, 0))
        #   %getitem_2 = call_function[target=operator.getitem](args = (%getitem, 1))
        #   %getitem_3 = call_function[target=operator.getitem](args = (%tree_unflatten, 1))
        #   %foo = call_module[target=foo](args = (%getitem_1, %getitem_2))
        #   %_spec_1 = get_attr[target=_spec_1]
        #   %tree_flatten_spec = call_function[target=torch.fx._pytree.tree_flatten_spec](args = (%foo, %_spec_1))
        #   %getitem_4 = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0))
        #   %sub = call_function[target=torch.ops.aten.sub.Tensor](args = (%y, %getitem_4))
        #   return (%sub,)
        #
        # `tree_unflatten` arranges the tensor inputs into the input format the
        # swapped eager module expects, `call_module` references the swapped
        # torch.nn.Module, and `tree_flatten_spec` breaks the eager outputs
        # back down into tensors.

        submod_name = name.replace(".", "_")
        sub_gm, _orig_inputs, orig_outputs = fuse_as_graphmodule(
            gm, nodes, f"fused_{submod_name}"
        )

        log.debug("Fused subgraph nodes:")
        log.debug(sub_gm.graph)

        signature: ModuleCallSignature = module_call_graph[name]

        args_nodes, kwargs_nodes = _construct_inputs(gm, signature, node_name_map)
        replacement_module = modules_to_swap[name]
        module_node = _insert_call_module(
            gm, args_nodes, kwargs_nodes, replacement_module, name
        )
        _deconstruct_outputs(gm, signature, module_node, node_name_map, orig_outputs)

        # The original traced nodes are no longer needed.
        erase_nodes(gm, nodes)

        log.debug("Swapped graph:")
        log.debug(gm.graph)

    legalize_graph(gm)

    log.debug("Before removing extraneous pytrees:")
    log.debug(gm.graph)

    _remove_extraneous_pytrees(gm)
    log.debug("After removing extraneous pytrees:")
    log.debug(gm.graph)

    gm.recompile()

    return gm
|
|
|
|
|
|
|
|
def _fix_input_output_signature( |
|
|
gm: torch.fx.GraphModule, signature: ModuleCallSignature |
|
|
) -> None: |
|
|
""" |
|
|
Given the unlifted module from calling ep.module(), we want to remove the |
|
|
pytree processing from the graph module's PyTreeCodeGen and instead make it |
|
|
nodes inside of the graph. This allows us to do some optimizations, like |
|
|
remove these pytree calls if it is unnecessary, and makes the PyTree part |
|
|
more obvious to graph passes. |
|
|
""" |
|
|
from torch.export.unflatten import _generate_flatten, _generate_unflatten |
|
|
|
|
|
|
|
|
|
|
|
gm.graph._codegen = torch.fx.graph.CodeGen() |
|
|
|
|
|
old_placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"] |
|
|
|
|
|
new_placeholders = [] |
|
|
forward_arg_names = signature.forward_arg_names |
|
|
if forward_arg_names is None: |
|
|
forward_arg_names = [] |
|
|
assert signature.in_spec.num_children == 2 |
|
|
arg_spec = signature.in_spec.children_specs[0] |
|
|
kwarg_spec = signature.in_spec.children_specs[1] |
|
|
assert arg_spec.type == tuple |
|
|
assert kwarg_spec.type == dict |
|
|
for i in range(arg_spec.num_children): |
|
|
forward_arg_names.append(f"arg_{i}") |
|
|
forward_arg_names.extend(kwarg_spec.context) |
|
|
|
|
|
for arg in forward_arg_names: |
|
|
with gm.graph.inserting_before(old_placeholders[0]): |
|
|
new_placeholders.append(gm.graph.placeholder(arg)) |
|
|
|
|
|
|
|
|
with gm.graph.inserting_before(old_placeholders[0]): |
|
|
flat_node = _generate_flatten(gm, tuple(new_placeholders)) |
|
|
for i, old_placeholder in enumerate(old_placeholders): |
|
|
old_placeholder.op = "call_function" |
|
|
old_placeholder.target = operator.getitem |
|
|
old_placeholder.args = (flat_node, i) |
|
|
|
|
|
|
|
|
output_node = next(node for node in gm.graph.nodes if node.op == "output") |
|
|
with gm.graph.inserting_before(output_node): |
|
|
unflat = _generate_unflatten(gm, output_node.args[0], signature.out_spec) |
|
|
output_node.args = (unflat,) |
|
|
|
|
|
gm.recompile() |
|
|
|
|
|
|
|
|
def _swap_modules(
    ep: ExportedProgram, modules_to_swap: dict[str, torch.nn.Module]
) -> torch.fx.GraphModule:
    """
    Unlift an ExportedProgram and swap traced submodules for eager ones.

    The program is turned into an fx.GraphModule via ``ep.module()``, its
    input/output pytree handling is moved into the graph, and then the
    subgraphs of the requested modules are replaced with calls to the
    supplied eager modules.

    Args:
        ep (ExportedProgram): Exported program to modify.
        modules_to_swap (Dict[str, torch.nn.Module]): Mapping from module fqn
            to the eager module to swap in. The fqn should also have been
            passed in the `preserve_module_call_signature` argument to
            torch.export, so that the calling convention of the module can be
            restored.

    Returns:
        The rewritten fx.GraphModule.
    """
    # Only entries that actually carry a signature are usable for swapping.
    signatures_by_fqn = {}
    for entry in ep.module_call_graph:
        if entry.signature:
            signatures_by_fqn[entry.fqn] = entry.signature

    gm = ep.module()
    gm.validate_inputs = False
    gm.graph.eliminate_dead_code()
    assert isinstance(gm, torch.fx.GraphModule)
    # The first module_call_graph entry is used as the root signature.
    root_signature = ep.module_call_graph[0].signature
    _fix_input_output_signature(gm, root_signature)

    gm.module_call_graph = ep.module_call_graph
    # Bind the train/eval implementations defined on gm's class directly onto
    # the instance. NOTE(review): presumably this overrides instance-level
    # replacements installed by ep.module() -- confirm.
    gm_class = type(gm)
    gm.train = types.MethodType(gm_class.train, gm)
    gm.eval = types.MethodType(gm_class.eval, gm)

    assert isinstance(gm, torch.fx.GraphModule)
    return _swap_module_helper(gm, modules_to_swap, signatures_by_fqn)
|
|
|