|
|
"""
|
|
|
This module implements graph deduplication functionality for TorchDynamo's optimization pipeline.
|
|
|
Graph deduplication identifies identical subgraphs in the computational graph and merges them
|
|
|
to reduce redundancy and improve performance. The process involves analyzing regions of the graph,
|
|
|
identifying structurally equivalent regions, and replacing them with a single shared implementation.
|
|
|
This optimization is particularly effective for models with repeated patterns or similar computational
|
|
|
structures across different parts of the network.
|
|
|
"""
|
|
|
|
|
|
import logging
|
|
|
import operator
|
|
|
from collections import defaultdict
|
|
|
from collections.abc import Generator, Iterable
|
|
|
from typing import Optional
|
|
|
|
|
|
import torch
|
|
|
import torch.fx
|
|
|
from torch._dynamo import config
|
|
|
from torch.multiprocessing.reductions import StorageWeakRef
|
|
|
from torch.utils._ordered_set import OrderedSet
|
|
|
|
|
|
from .graph_region_tracker import Node, Region
|
|
|
from .graph_utils import _detect_cycles, _get_flat_args, _get_flat_args_unique
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# A (node index within a region, index into that node's flattened args) pair;
# it identifies one use of an external input inside a region.
UsageIndex = tuple[int, int]
|
|
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
|
|
|
|
|
|
# Additional-dependency mapping built by the most recent call to
# apply_graph_deduplication; None until that function has run once.
last_node_to_additional_deps: Optional[dict[Node, OrderedSet[Node]]] = None
|
|
|
|
|
|
|
|
|
def apply_graph_deduplication(output_graph) -> dict[str, torch.fx.GraphModule]:
    """
    Main entry point for applying the graph deduplication pass.

    Deduplication occurs in two phases:

    1. Subgraph creation:
        One representative region (the first one) is taken from each region
        group and a subgraph is created from it, which is then used to replace
        all regions in the group. All nodes of the region are copied to the new
        subgraph, and a placeholder is created for every input that is not
        within the region. For the outputs, all regions in a region group are
        scanned to ensure the largest set of outputs is found, and an output
        node is created which returns a tuple of all outputs.

    2. Graph replacement:
        To replace each region with the extracted subgraph, the node index in
        the region and the argument index within the node's flattened args and
        kwargs are recorded once during subgraph creation. This determines which
        (external to the region) nodes are passed as inputs, and in which order.
        For the outputs, getitem nodes are created for each output, and all
        nodes in the region with external users are replaced by the proper
        getitem node. Finally, all original nodes are erased.

    Region groups whose subgraph takes no external inputs are skipped.

    The pass mutates ``output_graph`` in place, re-sorts its graph
    topologically, and stores the final additional-dependency mapping in the
    module-level ``last_node_to_additional_deps``.

    Args:
        output_graph: object providing ``graph``, ``region_tracker``,
            ``nn_modules`` and ``install_subgraph``.

    Returns:
        A dict mapping the name under which each subgraph was installed on
        ``output_graph`` to the corresponding ``torch.fx.GraphModule``.
    """
    global last_node_to_additional_deps

    duplicated_region_groups = output_graph.region_tracker.get_identical_regions(
        output_graph.graph
    )
    node_to_mutated_arg_positions = (
        output_graph.region_tracker.node_to_mutated_arg_positions
    )
    node_to_additional_deps = _populate_additional_deps(
        output_graph.graph, node_to_mutated_arg_positions
    )

    sub_gms: dict[str, torch.fx.GraphModule] = {}

    for region_group in duplicated_region_groups:
        inds_with_external_users = _get_all_output_indices(region_group)
        # The first region of the group serves as the template for the subgraph.
        representative = region_group[0]
        subgraph, external_node_usages = _create_subgraph(
            representative, inds_with_external_users
        )

        # Skip groups whose subgraph would take no inputs.
        if not external_node_usages:
            continue

        # NOTE: the subgraph is installed before any region is replaced, so it
        # stays installed even if every replacement below bails out.
        sub_gm = torch.fx.GraphModule(output_graph.nn_modules, subgraph)
        subgraph_name = output_graph.install_subgraph("subgraph", sub_gm)
        sub_gms[subgraph_name] = sub_gm
        with output_graph.graph.inserting_before():
            get_subgraph_node = output_graph.graph.create_node(
                "get_attr", subgraph_name, (), {}
            )

            for region in region_group:
                _replace_region_with_subgraph(
                    output_graph.graph,
                    region,
                    get_subgraph_node,
                    external_node_usages,
                    inds_with_external_users,
                    subgraph_name,
                    node_to_additional_deps,
                    node_to_mutated_arg_positions,
                )

    # Publish the (possibly rewritten) dependency mapping at module level.
    last_node_to_additional_deps = node_to_additional_deps

    # New nodes were appended during replacement; restore a valid node order.
    _stable_topological_sort(
        output_graph.graph,
        node_to_additional_deps,
    )
    return sub_gms
|
|
|
|
|
|
|
|
|
def _replace_region_with_subgraph(
    graph: torch.fx.Graph,
    region: Region,
    get_subgraph_node: Node,
    external_node_usages: Iterable[OrderedSet[UsageIndex]],
    inds_with_external_users: list[int],
    subgraph_name: str,
    node_to_additional_deps: dict[Node, OrderedSet[Node]],
    node_to_mutated_arg_positions: dict[Node, OrderedSet[int]],
) -> None:
    """
    Replace the nodes of ``region`` in ``graph`` with one invoke_subgraph call.

    The replacement is abandoned (the graph is left untouched and only a debug
    message is logged) when an external input is used at a mutated argument
    position, or when ``_has_aliasing`` reports aliasing among the region's
    inputs and outputs.

    Args:
        graph: graph containing the region; mutated in place.
        region: nodes to replace.
        get_subgraph_node: ``get_attr`` node referring to the shared subgraph.
        external_node_usages: for each subgraph input, the (node index, flat
            arg index) pairs at which that input is used inside the region.
        inds_with_external_users: region indices of nodes whose values become
            subgraph outputs, in output order.
        subgraph_name: name passed as the second invoke_subgraph argument.
        node_to_additional_deps: extra ordering dependencies; entries that
            refer to erased nodes are redirected to the new call node.
        node_to_mutated_arg_positions: per node, the flat arg positions it
            mutates.
    """
    subgraph_inputs = []
    for usages in external_node_usages:
        # The first recorded usage tells us where to read the actual input node.
        first_user_ind, first_arg_ind = next(iter(usages))
        first_user = region[first_user_ind]
        first_user_flat_args = _get_flat_args(first_user, {})
        for user_ind, arg_ind in usages:
            user = region[user_ind]
            if (
                user in node_to_mutated_arg_positions
                and arg_ind in node_to_mutated_arg_positions[user]
            ):
                log.debug(
                    "NYI: Failed to substitute region %s due to mutation", region
                )
                return
        subgraph_inputs.append(first_user_flat_args[first_arg_ind])

    # The check uses the region's own nodes and the inputs gathered above.
    if _has_aliasing(region, subgraph_inputs, inds_with_external_users):
        return

    invoke_subgraph_node = graph.create_node(
        "call_function",
        torch.ops.higher_order.invoke_subgraph,
        (get_subgraph_node, subgraph_name, *subgraph_inputs),
        {},
    )

    # One getitem per subgraph output takes over the users of the region node.
    for output_pos, region_ind in enumerate(inds_with_external_users):
        replaced = region[region_ind]
        getitem_node = graph.create_node(
            "call_function", operator.getitem, (invoke_subgraph_node, output_pos), {}
        )
        replaced.replace_all_uses_with(getitem_node, propagate_meta=True)

    # Erase back to front so users are removed before the nodes they use.
    for erased in reversed(region):
        graph.erase_node(erased)
        # Drop the erased node's own entry, then point every remaining
        # dependency on it at the invoke_subgraph call instead.
        node_to_additional_deps.pop(erased, None)
        for deps in node_to_additional_deps.values():
            if erased in deps:
                deps.remove(erased)
                deps.add(invoke_subgraph_node)

    if config.graph_deduplication_lint:
        print(_detect_cycles(graph, node_to_additional_deps))
        _stable_topological_sort(graph, node_to_additional_deps)
        graph.lint()
|
|
|
|
|
|
|
|
|
def _get_external_inputs(
    region: Region,
) -> dict[Node, OrderedSet[UsageIndex]]:
    """
    Collect every node used by ``region`` that is not itself part of it.

    Returns:
        A mapping from each external input node to the ordered set of
        (index of the using node in ``region``, index in that node's flat args)
        pairs at which it is used. All usages are recorded, not just the first.
    """
    usages_by_input: defaultdict[Node, OrderedSet[UsageIndex]] = defaultdict(
        OrderedSet
    )
    members = set(region)
    for node_pos, member in enumerate(region):
        for arg_pos, arg in enumerate(_get_flat_args(member, {})):
            # Only nodes coming from outside the region count as inputs.
            if not isinstance(arg, Node) or arg in members:
                continue
            usages_by_input[arg].add((node_pos, arg_pos))

    return usages_by_input
|
|
|
|
|
|
|
|
|
def _get_all_output_indices(regions: list[Region]) -> list[int]:
    """
    Return the sorted union, over all ``regions``, of the region indices whose
    node has at least one user outside its own region.
    """
    output_inds: set[int] = set()
    for candidate in regions:
        _get_inds_with_external_users(candidate, output_inds)

    ordered = list(output_inds)
    ordered.sort()
    return ordered
|
|
|
|
|
|
|
|
|
def _get_inds_with_external_users(region: Region, inds_unique: set[int]) -> None:
    """
    Add to ``inds_unique`` the index of every node in ``region`` that has at
    least one user outside of ``region``.

    Args:
        region: nodes to inspect; each must expose ``users`` and be hashable.
        inds_unique: accumulator set, updated in place.
    """
    # Build the membership set once instead of scanning the list per user.
    region_nodes = set(region)
    for ind, node in enumerate(region):
        if any(user not in region_nodes for user in node.users):
            inds_unique.add(ind)
|
|
|
|
|
|
|
|
|
def _copy_nodes_and_remap_inputs(
    subgraph: torch.fx.Graph, region: Region
) -> list[OrderedSet[UsageIndex]]:
    """
    Copy the nodes of ``region`` into ``subgraph``, feeding them from
    placeholders created for the region's external inputs.

    Returns:
        The usage sets of the external inputs, in the same order as the
        placeholders were created.
    """
    usages_per_placeholder: list[OrderedSet[UsageIndex]] = []
    # Maps a node of the original graph to its counterpart in the subgraph.
    original_to_copy = {}
    for external_input, usages in _get_external_inputs(region).items():
        new_placeholder = subgraph.placeholder(f"subgraph_input_{external_input.name}")
        original_to_copy[external_input] = new_placeholder
        usages_per_placeholder.append(usages)

    def remap(arg: Node) -> Node:
        # Nodes without a counterpart are passed through unchanged.
        return original_to_copy.get(arg, arg)

    for original in region:
        original_to_copy[original] = subgraph.node_copy(original, remap)

    return usages_per_placeholder
|
|
|
|
|
|
|
|
|
def _create_subgraph_outputs(
    subgraph: torch.fx.Graph, inds_to_output: list[int]
) -> None:
    """
    Append an output node to ``subgraph`` returning a tuple of selected nodes.

    ``inds_to_output`` indexes into the subgraph's nodes with placeholder and
    output nodes left out, in graph order.
    """
    body_nodes = []
    for candidate in subgraph.nodes:
        if candidate.op == "placeholder" or candidate.op == "output":
            continue
        body_nodes.append(candidate)

    selected = tuple(body_nodes[ind] for ind in inds_to_output)
    subgraph.output(selected)
|
|
|
|
|
|
|
|
|
def _create_subgraph(
    region: Region,
    inds_with_external_users: list[int],
) -> tuple[torch.fx.Graph, list[OrderedSet[UsageIndex]]]:
    """
    Build a fresh graph from ``region``.

    Returns:
        The new graph (region nodes copied in, an output tuple selected by
        ``inds_with_external_users``) together with the usage sets of its
        inputs, one per placeholder.
    """
    new_graph: torch.fx.Graph = torch.fx.Graph()
    input_usages = _copy_nodes_and_remap_inputs(new_graph, region)
    _create_subgraph_outputs(new_graph, inds_with_external_users)
    return new_graph, input_usages
|
|
|
|
|
|
|
|
|
def _stable_topological_sort(
    graph: torch.fx.Graph,
    node_to_additional_deps: dict[Node, OrderedSet[Node]],
) -> None:
    """
    Reorder the nodes of ``graph`` in place so that every node comes after all
    of its flat args and after its entries in ``node_to_additional_deps``.

    Nodes are visited in their current order and only moved when needed.
    Output nodes are set aside and never repositioned.
    """
    # Every node is in exactly one of the four collections below.

    # Not yet examined; popped from the end, so stored in reverse graph order.
    pending = list(reversed(graph.nodes))

    # Already placed in their final order.
    ready: OrderedSet[Node] = OrderedSet()

    # dependency -> nodes parked until that dependency has been placed.
    waiting = defaultdict(list)

    # Output nodes, kept apart from the ordering logic.
    outputs: OrderedSet[Node] = OrderedSet()

    # Last node that was placed; the next placed node must directly follow it.
    cursor = None
    while pending:
        node = pending.pop()

        if node.target == "output":
            outputs.add(node)
            assert not node.users, "output nodes should have no users"
            continue

        unmet = [
            dep
            for dep in _get_flat_args_unique(node, node_to_additional_deps)
            if dep not in ready
        ]
        if unmet:
            # Park the node on its last unmet dependency; it is re-examined
            # once that dependency has been placed.
            waiting[unmet[-1]].append(node)
            continue

        ready.add(node)
        if cursor and cursor.next is not node:
            cursor.append(node)
        cursor = node

        # Nodes parked on this one get another look, in their parked order.
        pending.extend(reversed(waiting.pop(node, ())))

    ready.update(outputs)
    assert not waiting and len(ready) == len(graph.nodes)
|
|
|
|
|
|
|
|
|
def _populate_additional_deps(
    graph: torch.fx.Graph, node_to_mutated_arg_positions: dict[Node, OrderedSet[int]]
) -> dict[Node, OrderedSet[Node]]:
    """
    Build the node -> extra-dependency mapping for ``graph``.

    Mutation-derived dependencies are added first, then global-state
    dependencies. The returned mapping is a ``defaultdict`` that creates an
    empty ``OrderedSet`` for missing keys.
    """
    deps_by_node: dict[Node, OrderedSet[Node]] = defaultdict(OrderedSet)
    _add_mutation_dependencies(node_to_mutated_arg_positions, deps_by_node)
    _add_global_state_dependencies(graph, deps_by_node)
    return deps_by_node
|
|
|
|
|
|
|
|
|
def _add_global_state_dependencies(
    graph: torch.fx.Graph, node_to_additional_deps: dict[Node, OrderedSet[Node]]
) -> None:
    """
    Pin autocast enter/exit nodes to their relative position in ``graph``.

    Walking the nodes in graph order:
    - every node gains a dependency on each autocast enter/exit node seen
      before it (unless that node is already one of its flat args);
    - every autocast enter/exit node gains a dependency on all nodes seen
      before it (again skipping its flat args).

    ``node_to_additional_deps`` is updated in place and is indexed with
    ``[]`` for nodes it may not contain yet, so it must create missing entries
    on access (as a ``defaultdict`` does).
    """
    import torch.amp

    # Targets whose nodes must keep their place relative to everything else.
    pinned_targets = {torch.amp._enter_autocast, torch.amp._exit_autocast}
    # Pinned nodes encountered so far.
    barriers: list[Node] = []
    # All nodes encountered so far, in graph order.
    earlier_nodes: list[Node] = []

    # Iterate over a snapshot of the node list.
    for cur_node in list(graph.nodes):
        direct_args = _get_flat_args_unique(cur_node, {})

        barrier_deps = [b for b in barriers if b not in direct_args]
        if barrier_deps:
            node_to_additional_deps[cur_node].update(barrier_deps)

        if cur_node.target in pinned_targets:
            node_to_additional_deps[cur_node].update(
                n for n in earlier_nodes if n not in direct_args
            )
            barriers.append(cur_node)

        earlier_nodes.append(cur_node)
|
|
|
|
|
|
|
|
|
def _add_mutation_dependencies(
    node_to_mutated_arg_positions: dict[Node, OrderedSet[int]],
    node_to_additional_deps: dict[Node, OrderedSet[Node]],
) -> None:
    """
    Add ordering dependencies between each mutating node and the other users
    of the arguments it mutates.

    For every other user of a mutated argument: if the user compares less than
    the mutating node, the mutating node depends on it; if it compares greater,
    the user depends on the mutating node.
    NOTE(review): this presumes ``<`` / ``>`` on nodes reflect graph position.

    ``node_to_additional_deps`` is updated in place and is indexed with ``[]``
    for possibly-missing keys, so it must create entries on access.
    """
    for mutating_node, mutated_positions in node_to_mutated_arg_positions.items():
        flat_args = _get_flat_args(mutating_node, {})

        for position in mutated_positions:
            for user in flat_args[position].users:
                if user is mutating_node:
                    continue
                if user < mutating_node:
                    node_to_additional_deps[mutating_node].add(user)
                elif user > mutating_node:
                    node_to_additional_deps[user].add(mutating_node)
|
|
|
|
|
|
|
|
|
def _has_aliasing(
    region: Region, inputs: list[Node], inds_with_external_users: list[int]
) -> bool:
    """
    Report whether replacing ``region`` would involve aliased tensors.

    The check compares the storages of the ``"example_value"`` tensors found in
    node metadata; non-tensor example values are ignored. Each detected case
    is logged at debug level.

    Args:
        region: nodes of the region under consideration.
        inputs: nodes that would be passed to the subgraph.
        inds_with_external_users: indices into ``region`` of the nodes whose
            values would be returned by the subgraph.

    Returns:
        True if two inputs share a storage, two outputs share a storage, or an
        input shares a storage with an output; False otherwise.

    Raises:
        KeyError: if an inspected node has no ``"example_value"`` metadata.
    """
    input_storages: dict[StorageWeakRef, Node] = {}
    for node in inputs:
        example_value = node.meta["example_value"]
        if isinstance(example_value, torch.Tensor):
            storage = StorageWeakRef(example_value._typed_storage())
            if storage in input_storages:
                # Two distinct inputs are backed by the same storage.
                log.debug(
                    "NYI: Failed to substitute region %s due to input-input aliasing detected at nodes %s, %s",
                    region,
                    input_storages[storage],
                    node,
                )
                return True
            input_storages[storage] = node

    output_storages: dict[StorageWeakRef, Node] = {}
    for i in inds_with_external_users:
        out_node = region[i]
        # Falsy region entries are skipped, as before.
        if not out_node:
            continue
        example_value = out_node.meta["example_value"]
        assert not isinstance(example_value, list)
        if isinstance(example_value, torch.Tensor):
            storage = StorageWeakRef(example_value._typed_storage())
            if storage in output_storages:
                # Two distinct outputs are backed by the same storage.
                log.debug(
                    "NYI: Failed to substitute region %s due to output-output aliasing detected at nodes %s, %s",
                    region,
                    output_storages[storage],
                    out_node,
                )
                return True
            output_storages[storage] = out_node

    shared_storages = input_storages.keys() & output_storages.keys()
    if shared_storages:
        # An input and an output are backed by the same storage.
        aliased_pairs = ", ".join(
            f"{input_storages[s]} and {output_storages[s]}" for s in shared_storages
        )
        log.debug(
            "NYI: Failed to substitute region %s due to input-output aliasing detected at nodes %s",
            region,
            aliased_pairs,
        )
        return True

    return False
|
|
|
|