|
|
|
|
|
import collections
|
|
|
import itertools
|
|
|
import logging
|
|
|
import operator
|
|
|
from collections.abc import Iterable, Sequence
|
|
|
from typing import Optional
|
|
|
|
|
|
from torch.fx.graph_module import GraphModule
|
|
|
from torch.fx.node import _get_qualified_name, Node
|
|
|
from torch.fx.passes.operator_support import OperatorSupportBase
|
|
|
from torch.fx.passes.utils.fuser_utils import fuse_by_partitions
|
|
|
|
|
|
|
|
|
# Module-level logger for the partitioner passes below.
logger = logging.getLogger(__name__)
# Default to WARNING so the verbose logger.debug tracing in
# propose_partitions stays silent unless a caller lowers the level.
logger.setLevel(logging.WARNING)
|
|
|
|
|
|
|
|
|
class Partition:
    """An ordered group of graph nodes.

    Nodes are stored as the keys of a dict (every value is ``None``) so the
    collection preserves insertion order while still giving O(1) membership
    tests and removals.
    """

    def __init__(
        self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None
    ):
        self.id = id
        if nodes is None:
            self.nodes = {}
        else:
            self.nodes = dict.fromkeys(nodes)

    def __repr__(self) -> str:
        return str(self.nodes)

    def add_node(self, node: Node):
        """Append *node* to the partition (no-op if already present)."""
        self.nodes[node] = None

    def remove_node(self, node: Node):
        """Drop *node* from the partition; raises KeyError if absent."""
        del self.nodes[node]

    def size(self):
        """Return the number of nodes currently in the partition."""
        return len(self.nodes)
|
|
|
|
|
|
|
|
|
class _DependencyViewer:
|
|
|
def __init__(self, graph_module: GraphModule):
|
|
|
self.downstreams = collections.defaultdict(set)
|
|
|
|
|
|
for node in reversed(graph_module.graph.nodes):
|
|
|
for output_node in node.users:
|
|
|
|
|
|
self.downstreams[node].add(output_node)
|
|
|
self.downstreams[node].update(self.downstreams[output_node])
|
|
|
|
|
|
def downstreams_of(self, node: Node) -> set[Node]:
|
|
|
return self.downstreams[node]
|
|
|
|
|
|
|
|
|
class CapabilityBasedPartitioner:
    """Greedy partitioner that groups operator-supported nodes for fusion.

    Nodes accepted by ``operator_support`` are placed into :class:`Partition`
    objects; partitions are then merged whenever the merge would not create a
    dependency cycle between partitions. The result can be materialized into
    fused submodules via :meth:`fuse_partitions`.
    """

    def __init__(
        self,
        graph_module: GraphModule,
        operator_support: OperatorSupportBase,
        allows_single_node_partition: bool = False,
        non_compute_ops: Optional[Sequence[str]] = None,
        allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
    ) -> None:
        """
        Args:
            graph_module: module whose graph will be partitioned.
            operator_support: oracle deciding, per node, whether the node is
                supported (see ``_is_node_supported``).
            allows_single_node_partition: when False, partitions containing
                at most one "compute" node are filtered out at the end of
                ``propose_partitions``.
            non_compute_ops: extra qualified op names treated as non-compute
                (in addition to built-in defaults) when counting compute
                nodes and when trimming partition bookends.
            allowed_single_node_partition_ops: qualified op names whose
                presence alone justifies keeping a single-node partition.
        """
        self.graph_module = graph_module
        self.operator_support = operator_support
        self.allows_single_node_partition = allows_single_node_partition
        self.non_compute_ops = non_compute_ops if non_compute_ops is not None else []
        self.allowed_single_node_partition_ops = (
            allowed_single_node_partition_ops
            if allowed_single_node_partition_ops is not None
            else []
        )
        # Precomputed transitive downstream sets; used for cycle detection
        # when merging partitions.
        self.dependency_viewer = _DependencyViewer(graph_module)

    def _is_node_supported(self, node: Node) -> bool:
        """Ask the operator-support oracle whether *node* can be partitioned."""
        return self.operator_support.is_node_supported(
            dict(self.graph_module.named_modules()), node
        )

    def propose_partitions(self) -> list[Partition]:
        """Compute and return the proposed (non-empty) partitions.

        Walks the graph in reverse topological order, seeds a fresh partition
        for every supported, unassigned node, then tries to merge candidate
        partitions pairwise, rejecting any merge that would introduce a cycle
        in the partition-level dependency graph.
        """
        # partition id -> set of partition ids reachable downstream of it
        # (through user edges); lets the cycle check avoid re-walking nodes.
        partition_map: dict[int, set] = collections.defaultdict(set)

        # node -> id of the partition it currently belongs to.
        assignment: dict[Node, int] = {}
        # partition id -> Partition object.
        partitions_by_id: dict[int, Partition] = {}
        # node -> partition id assigned when the node seeded a partition.
        # NOTE(review): written below but never read in this file — looks
        # vestigial; confirm before removing.
        nodes_order: dict[Node, int] = {}
        # partition id -> smallest creation order among all partitions that
        # were merged into it; used to order merge candidates.
        partitions_order: dict[int, int] = {}
        # partition id -> user nodes of the partition's members that lie
        # outside the partition itself.
        partition_users: dict[int, set] = {}
        # Monotonic source of fresh partition ids.
        new_partition_id = itertools.count()

        def maybe_merge_partition(self_id: int, other_id: int):
            """Merge ``other_id`` into ``self_id`` unless it creates a cycle.

            Returns ``(surviving_id, merged)``; on rejection the original
            ``self_id`` is returned with ``merged=False``.
            """
            self_nodes = partitions_by_id[self_id].nodes
            other_nodes = partitions_by_id[other_id].nodes

            def dfs_iter_find_cycle(all_user_nodes: set[Node]):
                # A cycle exists iff some external user of the merged
                # partition can reach back into it — either directly, or via
                # a partition that already depends on one of the two halves.
                for user_node in all_user_nodes:
                    visited_partition_ids = set()

                    for path_node in self.dependency_viewer.downstreams_of(user_node):
                        # The path re-enters the would-be merged partition.
                        if path_node in self_nodes or path_node in other_nodes:
                            return True

                        if path_node in assignment:
                            partition_id = assignment[path_node]
                            # Already checked this partition for this user.
                            if partition_id in visited_partition_ids:
                                continue
                            p_map = partition_map[partition_id]
                            # That partition transitively depends on one of
                            # the halves, so merging would close a cycle.
                            if self_id in p_map or other_id in p_map:
                                return True

                            visited_partition_ids.add(partition_id)

                return False

            # External users of the combined partition: users that belong to
            # either half become internal after the merge and are excluded.
            all_user_nodes = partition_users[self_id] | partition_users[other_id]
            all_user_nodes.difference_update(other_nodes, self_nodes)

            if dfs_iter_find_cycle(all_user_nodes):
                # Merge rejected; leave both partitions untouched.
                return self_id, False

            # Absorb the smaller partition into the larger one to minimize
            # the number of per-node reassignments below.
            merge_id, removed_id = self_id, other_id
            if len(self_nodes) < len(other_nodes):
                merge_id, removed_id = removed_id, merge_id

            partitions_by_id[merge_id].nodes.update(partitions_by_id[removed_id].nodes)
            for node in partitions_by_id[removed_id].nodes:
                assignment[node] = merge_id
            del partitions_by_id[removed_id]

            # Merged partition inherits the earlier creation order.
            partitions_order[merge_id] = min(
                partitions_order[merge_id], partitions_order[removed_id]
            )
            del partitions_order[removed_id]

            # Union the downstream-partition sets of both halves.
            partition_map[merge_id] = partition_map[merge_id].union(
                partition_map[removed_id]
            )
            del partition_map[removed_id]

            # External users were already computed above for the merged pair.
            partition_users[merge_id] = all_user_nodes
            del partition_users[removed_id]

            return merge_id, True

        def merge_single_node(node: Node, id: Optional[int]):
            """(Re)assign *node* to partition *id*; ``id=None`` unassigns."""

            def _update_partition_map(node: Node, id: int):
                # Partition `id` now depends on every partition containing a
                # user of `node`, plus everything those partitions depend on.
                for user_node in node.users:
                    target_id = assignment.get(user_node, None)
                    if target_id is not None:
                        partition_map[id].add(target_id)
                        partition_map[id].update(partition_map[target_id])

            # Detach the node from its current partition first, if any.
            if node in assignment:
                partitions_by_id[assignment[node]].remove_node(node)

            if id is None:
                assignment.pop(node)
            elif id not in partitions_by_id:
                # First node of a brand-new partition.
                assignment[node] = id
                partitions_by_id[id] = Partition(id=id, nodes=[node])
                partition_users[id] = set(node.users)
                _update_partition_map(node, id)
            else:
                assignment[node] = id
                partitions_by_id[id].add_node(node)

        logger.debug("Proposing partitions...")

        # Reverse topological order: users are visited before their inputs.
        for node in reversed(self.graph_module.graph.nodes):
            # dict used as an ordered set of candidate partition ids.
            merge_candidates: dict[int, None] = {}

            # A supported, not-yet-assigned node seeds its own partition and
            # becomes the first merge candidate.
            if self._is_node_supported(node) and node not in assignment:
                partition_id = next(new_partition_id)
                nodes_order[node] = partition_id
                partitions_order[partition_id] = partition_id
                merge_single_node(node, partition_id)
                merge_candidates[partition_id] = None

            # Every live partition, ordered by earliest creation order, is
            # also a merge candidate (duplicates collapse in the dict).
            for partition_id, _ in sorted(
                partitions_order.items(), key=operator.itemgetter(1)
            ):
                merge_candidates[partition_id] = None

            merge_candidates_list = list(merge_candidates.keys())
            if len(merge_candidates_list) > 1:
                self_id = merge_candidates_list[0]
                for other_id in merge_candidates_list[1:]:
                    # The surviving id accumulates successful merges so later
                    # candidates are merged into the growing partition.
                    self_id, _ = maybe_merge_partition(self_id, other_id)

        # Post-pass: a node whose users are ALL `getitem` calls produces a
        # tuple; pull those getitem users into the producer's partition so a
        # fused submodule does not emit a tuple unpacked outside of it.
        logger.debug("Reassigning getitem nodes to its producer node's partition...")
        nodes_reassignment: dict[Node, int] = {}
        for node in self.graph_module.graph.nodes:
            is_tuple_output = True
            for user in node.users:
                if (
                    user.op != "call_function"
                    or _get_qualified_name(user.target) != "_operator.getitem"
                ):
                    is_tuple_output = False
                    break

            # Align each getitem user with the producer's partition id
            # (id may be None, in which case the user is unassigned).
            if is_tuple_output:
                id = assignment.get(node, None)
                for user in node.users:
                    if assignment.get(user, None) != id:
                        nodes_reassignment[user] = id
        for node, id in nodes_reassignment.items():
            merge_single_node(node, id)

        # Optionally drop partitions with at most one "compute" node, since
        # fusing those rarely pays off.
        if not self.allows_single_node_partition:
            logger.debug("Filtering out single node partitions...")
            default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
            non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops))
            partitions_to_remove: list[int] = []
            for id, partition in partitions_by_id.items():
                compute_node_count = 0
                for node in partition.nodes:
                    if node.op == "call_function":
                        assert callable(node.target)
                        if _get_qualified_name(node.target) not in non_compute_ops:
                            compute_node_count += 1
                        # Explicitly allowed ops count as compute even when
                        # they are non-compute, keeping their partition alive.
                        if (
                            _get_qualified_name(node.target)
                            in self.allowed_single_node_partition_ops
                        ):
                            compute_node_count += 1
                if compute_node_count <= 1:
                    partitions_to_remove.append(id)
            for id in partitions_to_remove:
                del partitions_by_id[id]

        logger.debug("Partitions proposed:")
        for id, partition in partitions_by_id.items():
            logger.debug(
                "partition #%s: %s", id, [node.name for node in partition.nodes]
            )

        # Reassignment above can leave a partition empty; skip those.
        return [
            partition for partition in partitions_by_id.values() if partition.size() > 0
        ]

    def fuse_partitions(
        self, partitions: list[Partition], prefix: str = "fused_"
    ) -> GraphModule:
        """Fuse each partition's nodes into a submodule named ``{prefix}{i}``
        and return the rewritten GraphModule."""
        logger.debug("Fusing partitions...")

        # Partition stores nodes as dict keys, which iterate in insertion
        # order; fuse_by_partitions receives one node container per partition.
        return fuse_by_partitions(
            self.graph_module,
            [partition.nodes for partition in partitions],
            prefix=prefix,
        )

    def remove_bookend_non_compute_ops(self, partitions: list[Partition]):
        """Strip boundary ("bookend") non-compute ops from each partition,
        in place.

        A non-compute node is removed when it is "transparent" on its input
        side (every in-partition node upstream of it is also transparent) or
        on its output side (symmetrically, downstream) — i.e. it sits at the
        edge of the partition with no compute node between it and the
        partition boundary.
        """
        non_compute_ops = set(self.non_compute_ops)

        def is_non_compute_node(node: Node):
            # Only call_function nodes whose qualified name is in the
            # configured non-compute set qualify.
            return (
                node.op == "call_function"
                and _get_qualified_name(node.target) in non_compute_ops
            )

        # Memoized transparency results, shared across all partitions.
        # NOTE(review): keyed by node only, so this assumes a node appears in
        # at most one partition — confirm callers guarantee that.
        transparent_input_nodes: dict[Node, bool] = {}
        transparent_output_nodes: dict[Node, bool] = {}

        def is_transparent_input_node(
            node: Node, partition: set[Node], removed_nodes: set[Node]
        ):
            # Placeholders, nodes outside the partition, and nodes already
            # scheduled for removal never block transparency.
            if (
                node.op == "placeholder"
                or (node not in partition)
                or (node in removed_nodes)
            ):
                return True
            if node in transparent_input_nodes:
                return transparent_input_nodes[node]
            if is_non_compute_node(node):
                # Transparent iff every input is transparent as well.
                for input_n in node.all_input_nodes:
                    if not is_transparent_input_node(input_n, partition, removed_nodes):
                        transparent_input_nodes[node] = False
                        return False
                transparent_input_nodes[node] = True
                return True
            transparent_input_nodes[node] = False
            return False

        def is_transparent_output_node(
            node: Node, partition: set[Node], removed_nodes: set[Node]
        ):
            # Mirror image of is_transparent_input_node, following users
            # instead of inputs.
            if (
                node.op == "placeholder"
                or (node not in partition)
                or (node in removed_nodes)
            ):
                return True
            if node in transparent_output_nodes:
                return transparent_output_nodes[node]
            if is_non_compute_node(node):
                for output_n in node.users:
                    if not is_transparent_output_node(
                        output_n, partition, removed_nodes
                    ):
                        transparent_output_nodes[node] = False
                        return False
                transparent_output_nodes[node] = True
                return True
            transparent_output_nodes[node] = False
            return False

        for partition in partitions:
            # Nodes scheduled for removal from this partition; also fed back
            # into the transparency checks so removals can cascade.
            remove_node: set[Node] = set()
            for node in partition.nodes:
                if is_non_compute_node(node) and (
                    is_transparent_input_node(node, set(partition.nodes), remove_node)
                    or is_transparent_output_node(
                        node, set(partition.nodes), remove_node
                    )
                ):
                    remove_node.add(node)

            if len(remove_node) != 0:
                for node in remove_node:
                    partition.nodes.pop(node, None)

    def partition_and_fuse(self, prefix: str = "fused_") -> GraphModule:
        """Convenience wrapper: propose partitions, then fuse them."""
        partitions = self.propose_partitions()
        fused_gm = self.fuse_partitions(partitions, prefix=prefix)
        return fused_gm
|
|
|
|