math_roots / src /graph.py
thearn's picture
refactor
121ffc2
from typing import Dict, Any, Set, Tuple, List
from streamlit_agraph import Node, Edge
def build_tree_structure(graph: Dict[str, Any], root_id: int) -> Dict[int, Dict[str, Any]]:
    """Build the ancestor tree of ``root_id`` from an advisor graph.

    Args:
        graph: Mapping with a ``"nodes"`` entry that maps node ids to node
            dicts. Each node dict may carry an ``"advisors"`` list of node ids.
        root_id: Id of the node the traversal starts from.

    Returns:
        A dict keyed by node id containing every node reachable from
        ``root_id`` by following ``"advisors"`` links. Each value is a copy of
        the node's fields plus:

        * ``"depth"``: the minimum number of advisor links between the root
          and this node (the root itself is 0),
        * ``"children"``: ids of reachable nodes that list this node as an
          advisor,
        * ``"advisors"``: the node's advisor list (empty if it had none).

        The result is empty when ``root_id`` is not present in the graph.
    """
    nodes = graph.get("nodes", {})

    # Breadth-first, level-by-level walk up the advisor links. Compared with a
    # recursive depth-first walk this (a) assigns each node its shortest
    # distance from the root regardless of the order advisors are listed in,
    # and (b) cannot exhaust the interpreter's recursion limit on long chains.
    depths: Dict[int, int] = {root_id: 0}
    frontier: List[int] = [root_id]
    level = 0
    while frontier:
        level += 1
        next_frontier: List[int] = []
        for current_id in frontier:
            for advisor_id in nodes.get(current_id, {}).get("advisors", []):
                # Skip dangling references and anything already reached
                # (which also makes cycles terminate).
                if advisor_id in nodes and advisor_id not in depths:
                    depths[advisor_id] = level
                    next_frontier.append(advisor_id)
        frontier = next_frontier

    # Keep the iteration order of ``nodes`` so downstream consumers see nodes
    # in the same order as the input graph.
    tree: Dict[int, Dict[str, Any]] = {
        node_id: {
            **node,
            "depth": depths[node_id],
            "children": [],
            "advisors": node.get("advisors", []),
        }
        for node_id, node in nodes.items()
        if node_id in depths
    }

    # Invert the advisor links to obtain child lists, restricted to nodes that
    # made it into the tree.
    for node_id, node in tree.items():
        for advisor_id in node["advisors"]:
            if advisor_id in tree:
                tree[advisor_id]["children"].append(node_id)

    return tree
def create_hierarchical_view(
    graph: Dict[str, Any],
    root_id: int,
    max_depth: int = None
) -> Tuple[List[Node], List[Edge]]:
    """Create ``Node``/``Edge`` objects laying out the ancestor tree by depth.

    Args:
        graph: Advisor graph; passed unchanged to ``build_tree_structure``.
        root_id: Id of the node whose ancestors are displayed.
        max_depth: When not ``None``, nodes whose depth exceeds this value
            (and edges touching them) are left out.

    Returns:
        ``(nodes_list, edges_list)``. Nodes are emitted depth by depth and,
        within a depth, ordered by year. Each edge points from an advisor to
        the node that lists it.
    """
    tree = build_tree_structure(graph, root_id)
    colors = ["#ff6b6b", "#4ecdc4", "#45b7d1", "#96ceb4", "#feca57", "#ff9ff3", "#54a0ff"]

    def within_limit(depth: int) -> bool:
        return max_depth is None or depth <= max_depth

    def or_na(value: Any) -> Any:
        # A key that is present but holds None should read "N/A", not "None".
        return "N/A" if value is None else value

    # Group the visible nodes by depth, preserving the tree's iteration order.
    nodes_by_depth: Dict[int, List[Tuple[int, Dict[str, Any]]]] = {}
    for node_id, node in tree.items():
        if within_limit(node["depth"]):
            nodes_by_depth.setdefault(node["depth"], []).append((node_id, node))

    nodes_list = []
    edges_list = []
    for depth in sorted(nodes_by_depth):
        # Nodes with a missing year sort as if they were from 1400.
        depth_nodes = sorted(
            nodes_by_depth[depth],
            key=lambda item: item[1].get('year') or 1400,
        )
        # One colour per depth level, cycling through the palette.
        color = colors[depth % len(colors)]

        for i, (node_id, node) in enumerate(depth_nodes):
            raw_year = node.get('year')
            name = node.get("name", str(node_id))
            year_str = f" ({raw_year})" if raw_year is not None else ""
            label = f"{name}{year_str}"

            # Position: each depth is a 300-unit band on y, nudged by the year
            # (1500 is used when the year is missing); x spreads siblings
            # 180 units apart with a 20-unit shift per depth.
            year = raw_year or 1500
            base_y = depth * 300
            year_offset = (year - 1400) * 0.2
            x_pos = i * 180 + (depth * 20)
            y_pos = base_y + year_offset

            title = (
                f"Name: {name}\n"
                f"Year: {or_na(raw_year)}\n"
                f"Institution: {or_na(node.get('institution'))}"
            )
            ag_node = Node(
                id=str(node_id),
                label=label,
                size=25 if node_id == root_id else 20,  # root is drawn larger
                color=color,
                title=title,
                x=x_pos,
                y=y_pos,
                font={"color": "white", "size": 12}
            )
            nodes_list.append(ag_node)

            # Only link to advisors that are themselves visible.
            for advisor_id in node["advisors"]:
                if advisor_id in tree and within_limit(tree[advisor_id]["depth"]):
                    edge = Edge(
                        source=str(advisor_id),
                        target=str(node_id),
                        color="#666666"
                    )
                    edges_list.append(edge)

    return nodes_list, edges_list
def tree_to_dot(graph: Dict[str, Any]) -> str:
    """Render the advisor graph as Graphviz DOT source.

    Args:
        graph: Mapping with a ``"nodes"`` entry that maps node ids to node
            dicts. The fields read from each node are ``"name"``, ``"year"``,
            ``"institution"`` and ``"advisors"``.

    Returns:
        A ``digraph`` definition with one box per node (labelled
        ``name (year)``, with a tooltip listing id, name, year and
        institution) and one edge from each advisor to the node that lists
        it. Advisors that are not themselves in ``"nodes"`` are skipped.
    """

    def escape(value: Any) -> str:
        # Text is interpolated into double-quoted DOT strings; backslashes,
        # quotes and raw newlines must be escaped or the output is not
        # valid DOT.
        return (
            str(value)
            .replace("\\", "\\\\")
            .replace('"', '\\"')
            .replace("\n", "\\n")
        )

    def or_na(value: Any) -> str:
        # A key that is present but holds None should read "N/A", not "None".
        return "N/A" if value is None else escape(value)

    nodes = graph.get("nodes", {})
    lines = [
        "digraph G {",
        " rankdir=TB;",
        ' node [shape=box, style="rounded,filled", fillcolor=lightyellow];',
        ' edge [arrowhead=vee];'
    ]

    # Node declarations.
    for node_id, node in nodes.items():
        dot_id = escape(node_id)
        name = escape(node.get("name", str(node_id)))
        year = node.get('year')
        year_str = f" ({escape(year)})" if year is not None else " (Year Unknown)"
        label = f"{name}{year_str}"
        # "\\n" emits a literal backslash-n, which DOT treats as a line break.
        tooltip = (
            f"ID: {dot_id}\\nName: {name}\\nYear: {or_na(year)}"
            f"\\nInstitution: {or_na(node.get('institution'))}"
        )
        lines.append(f' "{dot_id}" [label="{label}", tooltip="{tooltip}"];')

    # Edges point from advisor to student.
    for node_id, node in nodes.items():
        for adv_id in node.get("advisors", []):
            if adv_id in nodes:
                lines.append(f' "{escape(adv_id)}" -> "{escape(node_id)}";')

    lines.append("}")
    return "\n".join(lines)