File size: 4,332 Bytes
121ffc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from typing import Dict, Any, Set, Tuple, List
from streamlit_agraph import Node, Edge

def build_tree_structure(graph: Dict[str, Any], root_id: int) -> Dict[int, Dict[str, Any]]:
    """Collect the advisor ancestry of ``root_id`` into a depth-annotated mapping.

    Starting at ``root_id``, the ``"advisors"`` links found in
    ``graph["nodes"]`` are followed breadth-first. Every reachable node is
    therefore assigned the minimum number of advisor hops separating it from
    the root, regardless of the order in which advisors are listed. Cycles
    are harmless because a node is only assigned a depth once.

    Args:
        graph: Mapping with a ``"nodes"`` entry that maps node id to a node
            dict. A node dict may carry an ``"advisors"`` list of node ids;
            a missing or ``None`` value is treated as no advisors.
        root_id: Id of the node the traversal starts from (depth 0).

    Returns:
        A dict keyed by node id, in the iteration order of
        ``graph["nodes"]``, holding only the nodes reachable from
        ``root_id``. Each value is a shallow copy of the node dict extended
        with ``"depth"``, ``"children"`` (ids of reachable nodes that list
        this node as an advisor) and ``"advisors"``. The result is empty
        when ``root_id`` is not present in ``graph["nodes"]``.
    """
    nodes = graph.get("nodes", {})

    # Level-by-level walk: everything discovered while expanding level d
    # sits at depth d + 1. Iterative, so long chains cannot hit the
    # interpreter's recursion limit.
    depths: Dict[int, int] = {root_id: 0}
    frontier = [root_id]
    depth = 0
    while frontier:
        depth += 1
        next_frontier = []
        for node_id in frontier:
            for advisor_id in nodes.get(node_id, {}).get("advisors") or []:
                # Ids that do not resolve to a known node are ignored; ids
                # already assigned a depth were reached by a shorter path.
                if advisor_id in nodes and advisor_id not in depths:
                    depths[advisor_id] = depth
                    next_frontier.append(advisor_id)
        frontier = next_frontier

    tree: Dict[int, Dict[str, Any]] = {
        node_id: {
            **node,
            "depth": depths[node_id],
            "children": [],
            "advisors": node.get("advisors") or [],
        }
        for node_id, node in nodes.items()
        if node_id in depths
    }

    # Invert the advisor links so each advisor knows its reachable students.
    for node_id, node in tree.items():
        for advisor_id in node["advisors"]:
            if advisor_id in tree:
                tree[advisor_id]["children"].append(node_id)

    return tree

def create_hierarchical_view(
    graph: Dict[str, Any],
    root_id: int,
    max_depth: int = None
) -> Tuple[List[Node], List[Edge]]:
    """Build agraph nodes and edges for the advisor ancestry of ``root_id``.

    The ancestry is obtained from ``build_tree_structure``. Nodes are grouped
    by depth, ordered by year within each depth, coloured per depth and given
    explicit ``x``/``y`` coordinates. One edge is emitted from each advisor
    to its student when both ends are inside the depth limit.

    Args:
        graph: Mapping with a ``"nodes"`` entry, as accepted by
            ``build_tree_structure``.
        root_id: Id of the node the view is centred on; it is drawn slightly
            larger than the others.
        max_depth: Largest depth to include. ``None`` (the default) means no
            limit.

    Returns:
        A ``(nodes, edges)`` tuple of ``Node`` and ``Edge`` objects.
    """
    tree = build_tree_structure(graph, root_id)
    colors = ["#ff6b6b", "#4ecdc4", "#45b7d1", "#96ceb4", "#feca57", "#ff9ff3", "#54a0ff"]

    def within_limit(depth: int) -> bool:
        return max_depth is None or depth <= max_depth

    def or_na(value: Any) -> Any:
        # A key that is present but None must read "N/A", not "None".
        return "N/A" if value is None else value

    nodes_by_depth: Dict[int, List[Tuple[int, Dict[str, Any]]]] = {}
    for node_id, node in tree.items():
        if within_limit(node["depth"]):
            nodes_by_depth.setdefault(node["depth"], []).append((node_id, node))

    nodes_list = []
    edges_list = []
    for depth in sorted(nodes_by_depth):
        # Colour and vertical band are shared by every node of this depth.
        color = colors[depth % len(colors)]
        base_y = depth * 300
        # Undated nodes sort as 1400 but are positioned as 1500 below.
        # NOTE(review): the two fallbacks differ in the original; kept as-is,
        # confirm this is intended.
        depth_nodes = sorted(
            nodes_by_depth[depth],
            key=lambda item: item[1].get('year') or 1400,
        )
        for i, (node_id, node) in enumerate(depth_nodes):
            name = node.get("name", str(node_id))
            year = node.get('year')
            year_str = f" ({year})" if year is not None else ""
            # Later years are nudged downwards inside the depth band.
            year_offset = ((year or 1500) - 1400) * 0.2
            nodes_list.append(
                Node(
                    id=str(node_id),
                    label=f"{name}{year_str}",
                    size=25 if node_id == root_id else 20,
                    color=color,
                    title=f"Name: {name}\nYear: {or_na(year)}\nInstitution: {or_na(node.get('institution'))}",
                    x=i * 180 + (depth * 20),
                    y=base_y + year_offset,
                    font={"color": "white", "size": 12}
                )
            )
            for advisor_id in node["advisors"]:
                # Skip advisors that are unreachable or beyond the depth
                # limit, so no edge points at a node that is not drawn.
                if advisor_id in tree and within_limit(tree[advisor_id]["depth"]):
                    edges_list.append(
                        Edge(
                            source=str(advisor_id),
                            target=str(node_id),
                            color="#666666"
                        )
                    )
    return nodes_list, edges_list

def tree_to_dot(graph: Dict[str, Any]) -> str:
    """Render the whole graph as Graphviz DOT source.

    Every entry of ``graph["nodes"]`` becomes a box node labelled with its
    name and year and carrying a tooltip with id, name, year and institution.
    Every advisor id that resolves to a known node becomes an
    ``advisor -> student`` edge.

    Text interpolated into double-quoted DOT strings has backslashes and
    double quotes escaped, so such characters in a name or institution
    cannot terminate the string early or be read as an escape sequence.

    Args:
        graph: Mapping with a ``"nodes"`` entry that maps node id to a node
            dict. The keys read here are ``"name"``, ``"year"``,
            ``"institution"`` and ``"advisors"``.

    Returns:
        The DOT document as a single newline-joined string.
    """
    def dot_escape(value: Any) -> str:
        # Backslash first, otherwise the backslash added for quotes would be
        # doubled again.
        return str(value).replace("\\", "\\\\").replace('"', '\\"')

    def or_na(value: Any) -> Any:
        # A key that is present but None must read "N/A", not "None".
        return "N/A" if value is None else value

    nodes = graph.get("nodes", {})
    lines = [
        "digraph G {",
        "  rankdir=TB;",
        '  node [shape=box, style="rounded,filled", fillcolor=lightyellow];',
        '  edge [arrowhead=vee];'
    ]
    for node_id, node in nodes.items():
        ident = dot_escape(node_id)
        name = dot_escape(node.get("name", str(node_id)))
        year = node.get("year")
        year_str = f" ({dot_escape(year)})" if year is not None else " (Year Unknown)"
        label = f"{name}{year_str}"
        # The "\\n" pieces are literal backslash-n: DOT's own line break
        # escape, which is why values are escaped individually rather than
        # the assembled tooltip.
        tooltip = (
            f"ID: {ident}\\nName: {name}"
            f"\\nYear: {dot_escape(or_na(year))}"
            f"\\nInstitution: {dot_escape(or_na(node.get('institution')))}"
        )
        lines.append(f'  "{ident}" [label="{label}", tooltip="{tooltip}"];')
    for node_id, node in nodes.items():
        for adv_id in node.get("advisors") or []:
            if adv_id in nodes:
                lines.append(f'  "{dot_escape(adv_id)}" -> "{dot_escape(node_id)}";')
    lines.append("}")
    return "\n".join(lines)