File size: 482 Bytes
77d2ccf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# visualize.py
import networkx as nx
import matplotlib.pyplot as plt


def draw_tree(tree, *, node_size=2000, seed=None, ax=None, show=True):
    """Draw a nested dict tree as a directed graph using networkx/matplotlib.

    Each node is a mapping with a ``"title"`` key, used as both the graph
    node identifier and its label, and an optional ``"children"`` iterable of
    nodes of the same shape. Because nodes are keyed by title, two tree nodes
    that share a title are merged into a single graph node.

    Args:
        tree: Root node of the tree.
        node_size: Marker size forwarded to ``nx.draw`` (default 2000, the
            previously hard-coded value).
        seed: Optional seed for ``nx.spring_layout``. ``None`` keeps the
            original non-reproducible random layout.
        ax: Optional matplotlib Axes to draw on. ``None`` lets ``nx.draw``
            use the current Axes, as before.
        show: If True (the default), call ``plt.show()`` after drawing.

    Returns:
        The ``nx.DiGraph`` that was built and drawn.

    Raises:
        KeyError: If any node lacks a ``"title"`` key.
    """
    G = nx.DiGraph()

    # Sentinel for "root has no parent". A truthiness test (or ``is None``)
    # would wrongly drop edges whose parent title is "", 0 or None.
    no_parent = object()

    # Iterative depth-first walk instead of recursion, so very deep trees do
    # not hit Python's recursion limit. Children are pushed in reverse so they
    # pop in listed order, reproducing the recursive preorder (and therefore
    # the same node/edge insertion order).
    stack = [(tree, no_parent)]
    while stack:
        node, parent = stack.pop()
        title = node["title"]
        G.add_node(title)
        if parent is not no_parent:
            G.add_edge(parent, title)
        # Treat a missing or None "children" entry as a leaf.
        children = list(node.get("children") or [])
        stack.extend((child, title) for child in reversed(children))

    pos = nx.spring_layout(G, seed=seed)
    nx.draw(G, pos, ax=ax, with_labels=True, node_size=node_size)
    if show:
        plt.show()
    return G