|
|
from collections import deque |
|
|
from typing import Any |
|
|
|
|
|
from torch.fx import Graph, map_arg, Node |
|
|
from torch.utils._ordered_set import OrderedSet |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_flat_args( |
|
|
node: Node, node_to_additional_deps: dict[Node, OrderedSet[Node]] |
|
|
) -> list[Node]: |
|
|
args = list[Any]() |
|
|
map_arg((node.args, node.kwargs), args.append) |
|
|
if node in node_to_additional_deps: |
|
|
args.extend(node_to_additional_deps[node]) |
|
|
return args |
|
|
|
|
|
|
|
|
def _get_flat_args_unique( |
|
|
node: Node, node_to_additional_deps: dict[Node, OrderedSet[Node]] |
|
|
) -> OrderedSet[Node]: |
|
|
args = OrderedSet[Node]() |
|
|
map_arg((node.args, node.kwargs), args.add) |
|
|
if node in node_to_additional_deps: |
|
|
args.update(node_to_additional_deps[node]) |
|
|
return args |
|
|
|
|
|
|
|
|
def _detect_cycles(
    graph: Graph, node_to_additional_deps: dict[Node, OrderedSet[Node]]
) -> str:
    """Look for a dependency cycle reachable from the outputs of ``graph``.

    Performs an iterative depth-first walk that starts at every ``output``
    node and follows dependencies (the node's flattened args/kwargs plus the
    entries of ``node_to_additional_deps``). A cycle exists exactly when the
    walk reaches a node that is already on the current DFS path.

    Nodes whose dependencies have been fully explored are remembered in
    ``visited`` and never walked again. Without this, a node reachable through
    many distinct paths (e.g. stacked diamond patterns) would be re-explored
    once per path, which is exponential in the worst case. Skipping such nodes
    cannot hide a cycle: had one been reachable from them, it would already
    have been reported the first time they were explored.

    Args:
        graph: The FX graph to inspect.
        node_to_additional_deps: Extra dependency edges per node that are not
            expressed through ``args``/``kwargs``.

    Returns:
        ``"no cycle detected"`` when no cycle is found, otherwise
        ``"cycle detected in path: <deque>"`` where the deque is the DFS path
        from the output node, ending with the node that closes the cycle.
    """
    # The DFS path from the current origin to the node being expanded, plus a
    # set mirror of it for O(1) membership tests.
    current_path: deque[Node] = deque()
    current_path_set: set[Node] = set()
    # Explicit DFS stack of (node to visit, node it was discovered from).
    pending: deque[tuple[Node, Node]] = deque()
    # Every node that has ever been put on the path. A node in here that is no
    # longer on the path has been completely explored without finding a cycle.
    visited: set[Node] = set()

    def add_to_current_path(node: Node) -> None:
        current_path.append(node)
        current_path_set.add(node)

    def pop_current_path() -> None:
        node = current_path.pop()
        current_path_set.remove(node)

    def current_path_head() -> Node:
        return current_path[-1]

    for origin in graph.find_nodes(op="output"):
        # Each output starts a fresh path; `visited` is deliberately kept so
        # work done for earlier outputs is not repeated.
        current_path.clear()
        current_path_set.clear()
        add_to_current_path(origin)
        visited.add(origin)
        for child in _get_flat_args_unique(origin, node_to_additional_deps):
            pending.append((child, origin))

        while pending:
            cur_node, parent = pending.pop()

            # Backtrack: unwind the path until its head is the node from which
            # `cur_node` was discovered.
            while current_path and current_path_head() != parent:
                pop_current_path()

            if not isinstance(cur_node, Node):
                continue

            # Must be checked before `visited`: nodes on the path are also
            # visited, and hitting one of them is what defines a cycle.
            if cur_node in current_path_set:
                current_path.append(cur_node)
                return f"cycle detected in path: {current_path}"

            # Already fully explored via another path; nothing new to find.
            if cur_node in visited:
                continue

            visited.add(cur_node)
            add_to_current_path(cur_node)

            for child in _get_flat_args_unique(cur_node, node_to_additional_deps):
                pending.append((child, cur_node))

    return "no cycle detected"
|
|
|