Spaces:
Running
Running
| """ | |
| Contains tests and a prototype implementation for the fanout algorithm in | |
| the LLVM refprune pass. | |
| """ | |
| try: | |
| from graphviz import Digraph | |
| except ImportError: | |
| pass | |
| from collections import defaultdict | |
# Name of the CFG entry block.  The backward check in
# FanoutAlgorithm.verify_non_overlapping() rejects a prune if it can reach
# this block from a decref block without passing the incref block.
ENTRY = "A"
# The following caseNN() functions each return a 3-tuple of
# (nodes, edges, expected):
# `nodes` maps each basic block (BB) to the list of incref/decref operations
#     inside it.
# `edges` maps each BB to the list of its successor BBs.
# `expected` maps each BB containing an incref to the set of BBs holding the
#     matching decrefs, or to None when the incref cannot be pruned.
def case1():
    """Incref in D, balanced by one decref in H and two decrefs in F.

    Below D there is a loop G -> I -> G, and I also exits to F; H leads to
    a further subgraph (J..Z) that is never walked because H holds a decref.
    """
    succ = dict(
        A=["B"],
        B=["C", "D"],
        C=[],
        D=["E", "F"],
        E=["G"],
        F=[],
        G=["H", "I"],
        I=["G", "F"],
        H=["J", "K"],
        J=["L", "M"],
        K=[],
        L=["Z"],
        M=["Z", "O", "P"],
        O=["Z"],
        P=["Z"],
        Z=[],
    )
    ops = defaultdict(list, D=["incref"], H=["decref"], F=["decref"] * 2)
    want = {"D": {"H", "F"}}
    return ops, succ, want
def case2():
    """Incref in A with decrefs in both B and C, where B also flows into C.

    The decref in C can execute after the one in B, so the prune is rejected.
    """
    succ = dict(A=["B", "C"], B=["C"], C=[])
    ops = defaultdict(list, A=["incref"], B=["decref"], C=["decref"])
    return ops, succ, {"A": None}
def case3():
    """case1 plus an edge H -> F linking the two decref blocks.

    F becomes reachable from another decref block, so the prune is rejected.
    """
    nodes, edges, _ = case1()
    edges["H"] += ["F"]
    return nodes, edges, {"D": None}
def case4():
    """case1 plus an edge from decref block H back up to E.

    E then has a decref block among its predecessors, so the prune is
    rejected.
    """
    graph = case1()
    nodes, edges = graph[0], graph[1]
    edges["H"].append("E")
    return nodes, edges, {"D": None}
def case5():
    """case1 plus an edge B -> I that enters D's region from before D.

    I can then be reached from the entry without running the incref, so the
    prune is rejected.
    """
    nodes, edges, _ = case1()
    edges["B"] += ["I"]
    return nodes, edges, {"D": None}
def case6():
    """case1 plus a backedge I -> B to a block before the incref block D.

    The forward walk from D can reach D again, so the prune is rejected.
    """
    graph = case1()
    nodes, edges = graph[:2]
    edges["I"].append("B")
    return nodes, edges, {"D": None}
def case7():
    """case1 plus a forward jump I -> M out of the decref-covered region.

    From M the exit block Z is reached without any decref, so the prune is
    rejected.
    """
    nodes, edges, _ = case1()
    edges["I"] += ["M"]
    return nodes, edges, {"D": None}
def case8():
    """Incref in A and a single decref in C, where both of A's paths join.

    A block named "entry:" precedes A.  Expected result: prunable with {C}.
    """
    succ = {"entry:": ["A"], "A": ["B", "C"], "B": ["C"], "C": []}
    ops = defaultdict(list, A=["incref"], C=["decref"])
    want = {"A": {"C"}}
    return ops, succ, want
def case9():
    """case8 plus a backedge C -> B from the decref block into the region.

    B then has the decref block C as a predecessor, so the prune is rejected.
    """
    nodes, edges, _ = case8()
    edges["C"] += ["B"]
    return nodes, edges, {"A": None}
def case10():
    """case8 plus a backedge C -> A to the incref block itself; still prunable."""
    graph = case8()
    nodes, edges = graph[0], graph[1]
    edges["C"].append("A")
    return nodes, edges, {"A": {"C"}}
def case11():
    """case8 plus a successor D of C that has no successors; still prunable."""
    nodes, edges, _ = case8()
    edges["C"] += ["D"]
    edges.update(D=[])
    return nodes, edges, {"A": {"C"}}
def case12():
    """case8 plus a loop C -> D -> A back to the incref block; still prunable."""
    graph = case8()
    nodes, edges = graph[:2]
    edges["C"].append("D")
    edges.update(D=["A"])
    return nodes, edges, {"A": {"C"}}
def case13():
    """case8 plus a loop C -> D -> B from the decref block back into the region.

    B becomes reachable after the decref in C, so the prune is rejected.
    """
    nodes, edges, _ = case8()
    edges["C"] += ["D"]
    edges.update(D=["B"])
    return nodes, edges, {"A": None}
def make_predecessor_map(edges):
    """Invert a successor map into ``{block: set of predecessor blocks}``.

    The result is a ``defaultdict(set)``: indexing a block that has no
    predecessors yields an empty set.
    """
    preds = defaultdict(set)
    pairs = ((src, dst) for src, outs in edges.items() for dst in outs)
    for src, dst in pairs:
        preds[dst].add(src)
    return preds
class FanoutAlgorithm:
    """Prototype of the refprune "fanout" check on a toy control-flow graph.

    For each block containing an ``"incref"`` the algorithm looks for a set
    of blocks containing ``"decref"`` such that:

    * forward: every path leaving the incref block reaches one of those
      decref blocks within a depth budget, without looping back to the
      incref block; and
    * backward: walking predecessors from each decref block always stops at
      the incref block, never reaching the entry block or another decref
      block on the way.

    When both hold, the incref block is reported together with that set of
    decref blocks; otherwise it is reported with None.
    """

    def __init__(self, nodes, edges, verbose=False):
        """Store the graph and precompute predecessor lists.

        nodes: mapping of block name -> list of "incref"/"decref" operations.
            Blocks without an entry are treated as having no operations.
        edges: mapping of block name -> list of successor block names.
            Blocks without an entry are treated as having no successors.
        verbose: when true, trace the search with print().
        """
        self.nodes = nodes
        self.edges = edges
        self.rev_edges = make_predecessor_map(edges)
        self.print = print if verbose else self._null_print

    def run(self):
        """Return the result of find_fanout_in_function()."""
        return self.find_fanout_in_function()

    def _null_print(self, *args, **kwargs):
        """Stand-in for print() when verbose output is disabled."""

    def find_fanout_in_function(self):
        """Run the fanout search from every block that contains an incref.

        Returns a dict mapping each incref block (in ``edges`` order) to its
        set of decref blocks, or to None when the incref cannot be pruned.
        """
        got = {}
        for cur_node in self.edges:
            # The search depends only on the block, so a block holding
            # several increfs needs just one search.  ``.get`` avoids
            # inserting keys into a caller-supplied defaultdict.
            if "incref" in self.nodes.get(cur_node, ()):
                decref_blocks = self.find_fanout(cur_node)
                self.print(">>", cur_node, "===", decref_blocks)
                got[cur_node] = decref_blocks
        return got

    def find_fanout(self, head_node):
        """Return the decref blocks that balance the incref in head_node.

        Returns None when the forward search finds no covering set of decref
        blocks, or when the backward non-overlap check fails.
        """
        decref_blocks = self.find_decref_candidates(head_node)
        self.print("candidates", decref_blocks)
        if not decref_blocks:
            return None
        if not self.verify_non_overlapping(
            head_node, decref_blocks, entry=ENTRY
        ):
            return None
        return set(decref_blocks)

    def verify_non_overlapping(self, head_node, decref_blocks, entry):
        """Backward check: every decref block is only reachable via head_node.

        From each decref block, walk predecessors without passing through
        ``head_node``.  Fail if the walk reaches ``entry`` or meets another
        member of ``decref_blocks`` as a predecessor; succeed otherwise.
        """
        self.print("verify_non_overlapping".center(80, "-"))
        for start in decref_blocks:
            visited = set()
            workstack = [start]
            while workstack:
                cur_node = workstack.pop()
                self.print("cur_node", cur_node, "|", workstack)
                if cur_node in visited:
                    continue
                if cur_node == entry:
                    self.print(
                        "!! failed because we arrived at entry", cur_node
                    )
                    return False
                visited.add(cur_node)
                preds = self.get_predecessors(cur_node)
                self.print(f" {cur_node} preds {preds}")
                for pred in preds:
                    if pred in decref_blocks:
                        self.print(
                            "!! reject because predecessor in decref_blocks"
                        )
                        return False
                    # The walk stops at the incref block.
                    if pred != head_node:
                        workstack.append(pred)
        return True

    def get_successors(self, node):
        """Return the successors of ``node`` as a tuple (empty if unknown)."""
        return tuple(self.edges.get(node, ()))

    def get_predecessors(self, node):
        """Return the predecessors of ``node`` as a tuple (empty if none)."""
        return tuple(self.rev_edges.get(node, ()))

    def has_decref(self, node):
        """Return True when ``node`` contains at least one decref."""
        return "decref" in self.nodes.get(node, ())

    def walk_child_for_decref(
        self, cur_node, path_stack, decref_blocks, depth=10
    ):
        """Forward DFS below the incref block looking for covering decrefs.

        ``path_stack`` is the tuple of blocks on the current path, starting
        with the incref block.  Decref blocks that are found are added to
        ``decref_blocks``.  Returns True when every path from ``cur_node``
        reaches a decref block.  Returns False when a path:

        * ends in a block with no successors and no decref,
        * loops back to the incref block (``path_stack[0]``), or
        * needs more than ``depth`` further levels.

        A loop back to any other block already on the path is skipped.
        """
        indent = " " * len(path_stack)
        self.print(indent, "walk", path_stack, cur_node)
        if depth <= 0:
            return False  # depth budget exhausted
        if cur_node in path_stack:
            # Reject a backedge to the incref block; skip other cycles.
            return cur_node != path_stack[0]
        if self.has_decref(cur_node):
            decref_blocks.add(cur_node)
            self.print(indent, "found decref")
            return True
        path_stack += (cur_node,)
        children = self.get_successors(cur_node)
        # A dead end without a decref fails.  Otherwise every child must
        # succeed; all() stops at the first failure.  The decremented depth
        # is passed down so the budget bounds the whole path rather than
        # being reset to the default at each level.
        found = bool(children) and all(
            self.walk_child_for_decref(
                child, path_stack, decref_blocks, depth - 1
            )
            for child in children
        )
        self.print(indent, f"ret {found}")
        return found

    def find_decref_candidates(self, cur_node):
        """Forward pass: collect decref blocks covering every path from cur_node.

        Returns the set of decref blocks found, or an empty set when some
        path is not covered or ``cur_node`` has no successors.
        """
        self.print("find_decref_candidates".center(80, "-"))
        path_stack = (cur_node,)
        decref_blocks = set()
        children = self.get_successors(cur_node)
        found = bool(children) and all(
            self.walk_child_for_decref(child, path_stack, decref_blocks)
            for child in children
        )
        return decref_blocks if found else set()
def check_once(case=None):
    """Render one test case with graphviz, then run the algorithm verbosely.

    case: one of the ``caseNN`` functions; defaults to ``case13``.

    Builds a ``graphviz.Digraph`` of the case's CFG, labels each block with
    its operations, and calls ``Digraph.view()``.  Then runs
    FanoutAlgorithm with verbose tracing.

    Raises RuntimeError if graphviz is not installed, and AssertionError if
    the result differs from the case's expected value.
    """
    # The module-level graphviz import is optional; import here so a missing
    # package gives a clear message instead of a NameError on Digraph.
    try:
        from graphviz import Digraph
    except ImportError as err:
        raise RuntimeError(
            "check_once() requires the 'graphviz' package"
        ) from err
    if case is None:
        case = case13
    nodes, edges, expected = case()
    # Render graph
    G = Digraph()
    for node in edges:
        G.node(node, shape="rect", label=f"{node}\n" + r"\l".join(nodes[node]))
    for node, children in edges.items():
        for child in children:
            G.edge(node, child)
    G.view()
    algo = FanoutAlgorithm(nodes, edges, verbose=True)
    got = algo.run()
    # Explicit check instead of `assert`, which `python -O` strips.
    if got != expected:
        raise AssertionError(
            f"{case.__name__}: expected {expected!r}, got {got!r}"
        )
def check_all():
    """Run every ``caseNN`` function in this module and check its result.

    Prints each case name as a banner.  Raises AssertionError naming the
    first case whose result differs from its expected value.  Prints
    "ALL PASSED" when every case matches.
    """
    # Snapshot globals() so the iteration is not affected by later changes.
    for name, fn in list(globals().items()):
        if not (name.startswith("case") and callable(fn)):
            continue
        print(name.center(80, "-"))
        nodes, edges, expected = fn()
        got = FanoutAlgorithm(nodes, edges).run()
        # Explicit check instead of `assert`: under `python -O` an assert is
        # stripped and the suite would report success without checking.
        if got != expected:
            raise AssertionError(
                f"{name}: expected {expected!r}, got {got!r}"
            )
    print("ALL PASSED")
if __name__ == "__main__":
    # Run every caseNN() scenario.  To debug a single case with a rendered
    # CFG and verbose tracing, call check_once() here instead.
    check_all()