from collections import deque
from typing import Any, Dict, List, Optional, Set, Tuple

from streamlit_agraph import Edge, Node

# Palette cycled per depth (generation) in the hierarchical view.
_DEPTH_COLORS = ("#ff6b6b", "#4ecdc4", "#45b7d1", "#96ceb4", "#feca57", "#ff9ff3", "#54a0ff")

# Layout constants for create_hierarchical_view (values unchanged from the
# original inline numbers).
# A node whose ``year`` is missing/falsy sorts within its row as if dated 1400,
# but is positioned vertically as if dated 1500.
# NOTE(review): confirm the two differing fallback years are intentional.
_SORT_YEAR_DEFAULT = 1400
_POSITION_YEAR_DEFAULT = 1500
_YEAR_ORIGIN = 1400  # year that maps to zero vertical offset within a row
_YEAR_OFFSET_SCALE = 0.2  # vertical layout units per year
_LEVEL_SPACING_Y = 300  # vertical distance between depth rows
_NODE_SPACING_X = 180  # horizontal distance between nodes in a row
_LEVEL_SHIFT_X = 20  # extra horizontal shift per depth so rows are staggered


def _display_value(value: Any) -> str:
    """Return ``value`` as text, or ``"N/A"`` when it is ``None``."""
    return "N/A" if value is None else str(value)


def _dot_escape(value: Any) -> str:
    """Escape ``value`` for embedding inside a double-quoted DOT string.

    Backslashes are doubled and double quotes are backslash-escaped so that
    arbitrary text (names, institutions, ids) cannot terminate the quoted
    string early; raw newlines are turned into the DOT line-break escape so
    every statement stays on a single line.
    """
    text = str(value)
    return text.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n")


def build_tree_structure(graph: Dict[str, Any], root_id: int) -> Dict[int, Dict[str, Any]]:
    """Collect every ancestor reachable from ``root_id`` via ``advisors`` links.

    Args:
        graph: Mapping whose ``"nodes"`` entry maps node id -> node dict; a node
            may list the ids of its advisors under ``"advisors"``.
        root_id: Id of the node to start from.

    Returns:
        Dict mapping each reachable node id that is present in
        ``graph["nodes"]`` (in ``graph["nodes"]`` order) to a shallow copy of
        its node dict with three keys set:

        * ``"depth"`` - number of advisor hops on the shortest path from the
          root (the root itself is 0);
        * ``"children"`` - ids of nodes in the result that list this node as
          an advisor;
        * ``"advisors"`` - the node's ``"advisors"`` list, or ``[]``.
    """
    nodes = graph.get("nodes", {})

    # Breadth-first search gives every node its minimal hop count. A
    # depth-first walk with a shared visited set would record whichever path
    # it explored first, so when two advisor paths reach the same ancestor the
    # depth would depend on advisor-list order. BFS also avoids recursion
    # limits on very long lineages.
    depths: Dict[int, int] = {root_id: 0}
    queue = deque([root_id])
    while queue:
        current = queue.popleft()
        for advisor_id in nodes.get(current, {}).get("advisors", []):
            if advisor_id in nodes and advisor_id not in depths:
                depths[advisor_id] = depths[current] + 1
                queue.append(advisor_id)

    tree: Dict[int, Dict[str, Any]] = {
        node_id: {
            **node,
            "depth": depths[node_id],
            "children": [],
            "advisors": node.get("advisors", []),
        }
        for node_id, node in nodes.items()
        if node_id in depths
    }

    # Invert the advisor links so each ancestor knows its students in the tree.
    for node_id, node in tree.items():
        for advisor_id in node["advisors"]:
            if advisor_id in tree:
                tree[advisor_id]["children"].append(node_id)
    return tree


def create_hierarchical_view(
    graph: Dict[str, Any],
    root_id: int,
    max_depth: Optional[int] = None,
) -> Tuple[List[Node], List[Edge]]:
    """Build streamlit-agraph nodes and edges laid out by generation.

    Nodes reachable from ``root_id`` (see ``build_tree_structure``) are grouped
    into rows by depth. Within a row they are ordered by year and spaced
    horizontally. The vertical position is the row baseline plus a small
    offset proportional to the year.

    Args:
        graph: Graph with a ``"nodes"`` mapping, as for ``build_tree_structure``.
        root_id: Id of the starting node; it is drawn slightly larger.
        max_depth: If given, nodes deeper than this, and edges from such
            advisors, are omitted.

    Returns:
        ``(nodes, edges)``, where each edge points from an advisor to the node
        that lists it.
    """
    tree = build_tree_structure(graph, root_id)

    def visible(depth: int) -> bool:
        return max_depth is None or depth <= max_depth

    nodes_by_depth: Dict[int, List[Tuple[int, Dict[str, Any]]]] = {}
    for node_id, node in tree.items():
        if visible(node["depth"]):
            nodes_by_depth.setdefault(node["depth"], []).append((node_id, node))

    nodes_list: List[Node] = []
    edges_list: List[Edge] = []
    for depth in sorted(nodes_by_depth):
        color = _DEPTH_COLORS[depth % len(_DEPTH_COLORS)]
        # sorted() is stable, so nodes with equal years keep tree order.
        row = sorted(
            nodes_by_depth[depth],
            key=lambda item: item[1].get("year") or _SORT_YEAR_DEFAULT,
        )
        for i, (node_id, node) in enumerate(row):
            year = node.get("year")
            name = node.get("name", str(node_id))
            label = f"{name}" if year is None else f"{name} ({year})"

            plotted_year = year or _POSITION_YEAR_DEFAULT
            x_pos = i * _NODE_SPACING_X + depth * _LEVEL_SHIFT_X
            y_pos = depth * _LEVEL_SPACING_Y + (plotted_year - _YEAR_ORIGIN) * _YEAR_OFFSET_SCALE

            nodes_list.append(
                Node(
                    id=str(node_id),
                    label=label,
                    size=25 if node_id == root_id else 20,
                    color=color,
                    # Missing (None) year/institution show as "N/A", not "None".
                    title=(
                        f"Name: {name}\n"
                        f"Year: {_display_value(year)}\n"
                        f"Institution: {_display_value(node.get('institution'))}"
                    ),
                    x=x_pos,
                    y=y_pos,
                    font={"color": "white", "size": 12},
                )
            )

            for advisor_id in node["advisors"]:
                advisor = tree.get(advisor_id)
                if advisor is not None and visible(advisor["depth"]):
                    edges_list.append(
                        Edge(source=str(advisor_id), target=str(node_id), color="#666666")
                    )
    return nodes_list, edges_list


def tree_to_dot(graph: Dict[str, Any]) -> str:
    """Render the whole graph as Graphviz DOT source.

    Every node in ``graph["nodes"]`` is emitted, not only those reachable from
    a root. Each node is labelled with its name and year and carries a tooltip
    listing id, name, year and institution. An ``advisor -> student`` edge is
    added for each advisor id that is itself a node of the graph. All
    interpolated text is escaped so the output stays valid DOT.
    """
    nodes = graph.get("nodes", {})
    lines = [
        "digraph G {",
        " rankdir=TB;",
        ' node [shape=box, style="rounded,filled", fillcolor=lightyellow];',
        ' edge [arrowhead=vee];',
    ]

    for node_id, node in nodes.items():
        name = node.get("name", str(node_id))
        year = node.get("year")
        label = f"{name} ({year})" if year is not None else f"{name} (Year Unknown)"
        # Escape each field separately, then join with the DOT "\n" line-break
        # escape (a literal backslash + n, which must itself stay unescaped).
        tooltip = "\\n".join(
            _dot_escape(part)
            for part in (
                f"ID: {node_id}",
                f"Name: {name}",
                f"Year: {_display_value(year)}",
                f"Institution: {_display_value(node.get('institution'))}",
            )
        )
        lines.append(
            f' "{_dot_escape(node_id)}" [label="{_dot_escape(label)}", tooltip="{tooltip}"];'
        )

    for node_id, node in nodes.items():
        for adv_id in node.get("advisors", []):
            if adv_id in nodes:
                lines.append(f' "{_dot_escape(adv_id)}" -> "{_dot_escape(node_id)}";')

    lines.append("}")
    return "\n".join(lines)