thearn committed on
Commit
2293b42
·
1 Parent(s): 25dabdb

better graph

Browse files
Files changed (2) hide show
  1. app.py +338 -12
  2. pyproject.toml +2 -0
app.py CHANGED
@@ -2,8 +2,11 @@ import streamlit as st
2
  import asyncio
3
  import websockets
4
  import json
5
- import platform
6
- from typing import Dict, Any, Optional, List, Literal, TypedDict, cast
 
 
 
7
 
8
  GGRAPHER_URI = "wss://ggrphr.davidalber.net"
9
 
@@ -63,7 +66,190 @@ async def get_graph(payload: RequestPayload, progress_cb=None) -> Dict[str, Any]
63
  else:
64
  continue
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  def tree_to_dot(graph: Dict[str, Any]) -> str:
 
67
  nodes = graph.get("nodes", {})
68
  lines = [
69
  "digraph G {",
@@ -71,24 +257,52 @@ def tree_to_dot(graph: Dict[str, Any]) -> str:
71
  ' node [shape=box, style="rounded,filled", fillcolor=lightyellow];',
72
  ' edge [arrowhead=vee];'
73
  ]
74
- # Define nodes and their labels
75
  for node_id, node in nodes.items():
76
  name = node.get("name", str(node_id))
77
  year_str = f" ({node.get('year')})" if node.get('year') is not None else " (Year Unknown)"
78
  label = f"{name}{year_str}"
79
  tooltip = f"ID: {node_id}\\nName: {name}\\nYear: {node.get('year', 'N/A')}\\nInstitution: {node.get('institution', 'N/A')}"
80
  lines.append(f' "{node_id}" [label="{label}", tooltip="{tooltip}"];')
81
- # Define edges
82
  for node_id, node in nodes.items():
83
  for adv_id in node.get("advisors", []):
84
  if adv_id in nodes:
85
  lines.append(f' "{adv_id}" -> "{node_id}";')
 
86
  lines.append("}")
87
  return "\n".join(lines)
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  def main():
90
  st.title("Math Genealogy Ancestor Tree")
91
-
 
92
  mathematicians = [
93
  ("Tristan Hearn", 162833),
94
  ("Alexander Grothendieck", 31245),
@@ -100,13 +314,19 @@ def main():
100
  names = [f"{name} ({mid})" for name, mid in mathematicians]
101
  default_index = 0 # Tristan Hearn
102
 
103
- # initialize session state for mgp_id_str if not set
104
  if "mgp_id_str" not in st.session_state:
105
  st.session_state["mgp_id_str"] = str(mathematicians[default_index][1])
 
 
 
 
106
 
 
 
 
107
  mgp_id_str = st.text_input(
108
  "Enter MGP ID (integer):",
109
- value=st.session_state["mgp_id_str"],
110
  key="mgp_id_str",
111
  help="You can type a custom ID or use the selection below."
112
  )
@@ -124,30 +344,136 @@ def main():
124
  )
125
 
126
  progress_placeholder = st.empty()
127
- graph_placeholder = st.empty()
128
- run_btn = st.button("Show Ancestor Tree")
 
129
  if run_btn:
130
  mgp_id = get_id_from_input(st.session_state["mgp_id_str"])
131
  if mgp_id is None:
132
  st.error("Please enter a valid integer MGP ID.")
133
  return
 
134
  payload = make_payload(mgp_id)
135
  loop = asyncio.new_event_loop()
136
  asyncio.set_event_loop(loop)
 
137
  def progress_cb(progress):
138
  progress_placeholder.info(
139
  f"Queued: {progress['queued']} | Fetching: {progress['fetching']} | Done: {progress['done']}"
140
  )
 
141
  async def runner():
142
  graph = await get_graph(payload, progress_cb)
143
- dot = tree_to_dot(graph)
144
- graph_placeholder.graphviz_chart(dot)
 
145
  try:
146
  loop.run_until_complete(runner())
147
- progress_placeholder.success("Done!")
148
  except Exception as e:
149
  print(f"Error: {e}")
150
  progress_placeholder.error(f"Error: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
  if __name__ == "__main__":
153
  main()
 
2
  import asyncio
3
  import websockets
4
  import json
5
+ import tempfile
6
+ import os
7
+ from typing import Dict, Any, Optional, List, Literal, TypedDict, cast, Set, Tuple
8
+ from streamlit_agraph import agraph, Node, Edge, Config
9
+ import pyvis.network as net
10
 
11
  GGRAPHER_URI = "wss://ggrphr.davidalber.net"
12
 
 
66
  else:
67
  continue
68
 
69
def build_tree_structure(graph: Dict[str, Any], root_id: int) -> Dict[int, Dict[str, Any]]:
    """Build a depth-annotated ancestor tree from the graph data.

    Args:
        graph: Mapping with a "nodes" entry that maps node ids to node
            dicts; a node may list its advisors' ids under "advisors".
        root_id: Id of the node whose ancestors are collected.

    Returns:
        A dict keyed by node id holding every node of ``graph["nodes"]``
        reachable from ``root_id`` through "advisors" links. Each entry is
        a shallow copy of the node with three keys set:

        * "depth": the fewest advisor hops between the root and the node,
        * "children": ids of tree nodes that list this node as an advisor,
        * "advisors": the node's advisor ids (empty list when absent).

        The result is empty when ``root_id`` is not a key of the nodes
        mapping.
    """
    # NOTE(review): lookups assume the ids inside "advisors" have the same
    # type as the keys of the nodes mapping and as root_id -- confirm
    # against whatever produces the graph.
    nodes = graph.get("nodes", {})

    # Breadth-first walk, one generation at a time. The first time a node is
    # reached is therefore along a shortest path, so its depth is the minimum
    # number of generations back and does not depend on advisor ordering.
    # Being iterative, it also cannot hit the recursion limit on long lineages.
    depths = {root_id: 0}
    frontier = [root_id]
    while frontier:
        next_frontier = []
        for current_id in frontier:
            for advisor_id in nodes.get(current_id, {}).get("advisors", []):
                if advisor_id in nodes and advisor_id not in depths:
                    depths[advisor_id] = depths[current_id] + 1
                    next_frontier.append(advisor_id)
        frontier = next_frontier

    # Keep the iteration order of the nodes mapping for the reachable entries.
    tree = {
        node_id: {
            **node,
            "depth": depths[node_id],
            "children": [],
            "advisors": node.get("advisors", []),
        }
        for node_id, node in nodes.items()
        if node_id in depths
    }

    # Invert the advisor links so every advisor knows its students.
    for node_id, entry in tree.items():
        for advisor_id in entry["advisors"]:
            if advisor_id in tree:
                tree[advisor_id]["children"].append(node_id)

    return tree
109
+
110
def create_hierarchical_view(graph: Dict[str, Any], root_id: int, max_depth: Optional[int] = None) -> Tuple[List[Node], List[Edge]]:
    """Create streamlit-agraph nodes and edges for the hierarchical tree view.

    Args:
        graph: Graph data as accepted by ``build_tree_structure``.
        root_id: Id of the root node; it is drawn slightly larger.
        max_depth: When given, only nodes whose depth is at most this value
            are emitted, and edges are only drawn between emitted nodes.
            ``None`` disables the filter.

    Returns:
        A ``(nodes, edges)`` tuple. Nodes are coloured by depth and carry
        x/y position hints; edges point from advisor to student.
    """
    tree = build_tree_structure(graph, root_id)
    nodes_list = []
    edges_list = []

    # colour scheme by generation (cycled when the tree is deeper)
    colors = ["#ff6b6b", "#4ecdc4", "#45b7d1", "#96ceb4", "#feca57", "#ff9ff3", "#54a0ff"]

    def is_visible(entry: Dict[str, Any]) -> bool:
        return max_depth is None or entry["depth"] <= max_depth

    # group visible nodes by depth for positioning
    nodes_by_depth = {}
    for node_id, node in tree.items():
        if is_visible(node):
            nodes_by_depth.setdefault(node["depth"], []).append((node_id, node))

    # within a generation, order by year; unknown years sort as 1400
    for depth_nodes in nodes_by_depth.values():
        depth_nodes.sort(key=lambda item: item[1].get('year') or 1400)

    for depth in sorted(nodes_by_depth):
        color = colors[depth % len(colors)]
        base_y = depth * 300  # spacing between generations

        for i, (node_id, node) in enumerate(nodes_by_depth[depth]):
            name = node.get("name", str(node_id))
            known_year = node.get('year')
            year_str = f" ({known_year})" if known_year is not None else ""
            label = f"{name}{year_str}"

            # A year stored as None is as unknown as a missing one, so the
            # tooltip shows "N/A" instead of the literal text "None".
            year_text = known_year if known_year is not None else 'N/A'

            # more recent years get higher y values (appear lower on screen);
            # a missing year is positioned as if it were 1500
            year = known_year or 1500
            year_offset = (year - 1400) * 0.2  # scale factor for year spacing
            x_pos = i * 180 + (depth * 20)  # slight x offset per depth to avoid overlap
            y_pos = base_y + year_offset

            nodes_list.append(
                Node(
                    id=str(node_id),
                    label=label,
                    size=25 if node_id == root_id else 20,
                    color=color,
                    title=f"Name: {name}\nYear: {year_text}\nInstitution: {node.get('institution', 'N/A')}",
                    x=x_pos,
                    y=y_pos,
                )
            )

            # edges run advisor -> student, only towards advisors that are shown
            for advisor_id in node["advisors"]:
                if advisor_id in tree and is_visible(tree[advisor_id]):
                    edges_list.append(
                        Edge(
                            source=str(advisor_id),
                            target=str(node_id),
                            color="#666666",
                        )
                    )

    return nodes_list, edges_list
176
+
177
def create_pyvis_network(graph: Dict[str, Any], root_id: int) -> str:
    """Create an interactive network using pyvis and return it as HTML.

    Args:
        graph: Mapping with a "nodes" entry that maps node ids to node dicts.
        root_id: Id of the root node; it is drawn larger than the others.

    Returns:
        The HTML document written by pyvis for the network, as a string.
    """
    nodes = graph.get("nodes", {})

    # create network
    nt = net.Network(
        height="600px",
        width="100%",
        bgcolor="#ffffff",
        font_color="black",
        directed=True
    )

    # configure physics
    nt.set_options("""
    var options = {
      "physics": {
        "enabled": true,
        "hierarchicalRepulsion": {
          "centralGravity": 0.3,
          "springLength": 120,
          "springConstant": 0.01,
          "nodeDistance": 200,
          "damping": 0.09
        },
        "solver": "hierarchicalRepulsion"
      },
      "layout": {
        "hierarchical": {
          "enabled": true,
          "direction": "UD",
          "sortMethod": "directed"
        }
      }
    }
    """)

    # calculate depths for coloring
    tree = build_tree_structure(graph, root_id)
    colors = ["#ff6b6b", "#4ecdc4", "#45b7d1", "#96ceb4", "#feca57", "#ff9ff3", "#54a0ff"]

    # add nodes
    for node_id, node in nodes.items():
        name = node.get("name", str(node_id))
        known_year = node.get('year')
        year_str = f" ({known_year})" if known_year is not None else ""
        label = f"{name}{year_str}"

        # A year stored as None is as unknown as a missing one, so the
        # tooltip shows "N/A" instead of the literal text "None".
        year_text = known_year if known_year is not None else 'N/A'

        # nodes that are not in the tree fall back to the depth-0 colour
        depth = tree.get(node_id, {}).get("depth", 0)
        color = colors[depth % len(colors)]

        nt.add_node(
            node_id,
            label=label,
            title=f"Name: {name}\nYear: {year_text}\nInstitution: {node.get('institution', 'N/A')}",
            color=color,
            size=30 if node_id == root_id else 20
        )

    # add edges (advisor -> student), skipping advisors missing from the graph
    for node_id, node in nodes.items():
        for advisor_id in node.get("advisors", []):
            if advisor_id in nodes:
                nt.add_edge(advisor_id, node_id)

    # Round-trip through a temp file. The handle is closed right away so
    # pyvis can reopen the path itself, and the file is removed even when
    # saving or reading fails.
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".html")
    temp_file.close()
    try:
        nt.save_graph(temp_file.name)
        with open(temp_file.name, 'r') as f:
            html_content = f.read()
    finally:
        os.unlink(temp_file.name)

    return html_content
250
+
251
  def tree_to_dot(graph: Dict[str, Any]) -> str:
252
+ """Original graphviz DOT format (kept as fallback)."""
253
  nodes = graph.get("nodes", {})
254
  lines = [
255
  "digraph G {",
 
257
  ' node [shape=box, style="rounded,filled", fillcolor=lightyellow];',
258
  ' edge [arrowhead=vee];'
259
  ]
260
+
261
  for node_id, node in nodes.items():
262
  name = node.get("name", str(node_id))
263
  year_str = f" ({node.get('year')})" if node.get('year') is not None else " (Year Unknown)"
264
  label = f"{name}{year_str}"
265
  tooltip = f"ID: {node_id}\\nName: {name}\\nYear: {node.get('year', 'N/A')}\\nInstitution: {node.get('institution', 'N/A')}"
266
  lines.append(f' "{node_id}" [label="{label}", tooltip="{tooltip}"];')
267
+
268
  for node_id, node in nodes.items():
269
  for adv_id in node.get("advisors", []):
270
  if adv_id in nodes:
271
  lines.append(f' "{adv_id}" -> "{node_id}";')
272
+
273
  lines.append("}")
274
  return "\n".join(lines)
275
 
276
def display_tree_summary(graph: Dict[str, Any], root_id: int) -> None:
    """Display summary statistics about the tree as three Streamlit metrics.

    Shows the number of nodes in the tree, the largest depth found in it,
    and the root's name ("Unknown" when the root has no name). Nothing is
    rendered when the tree built for ``root_id`` is empty.

    Args:
        graph: Graph data as accepted by ``build_tree_structure``.
        root_id: Id of the root node of the tree.
    """
    tree = build_tree_structure(graph, root_id)

    if not tree:
        return

    # tree is non-empty past the guard, so max() needs no fallback
    max_depth = max(node["depth"] for node in tree.values())
    total_nodes = len(tree)
    root_name = tree.get(root_id, {}).get("name", "Unknown")

    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Total Mathematicians", total_nodes)
    with col2:
        st.metric("Generations Back", max_depth)
    with col3:
        st.metric("Root", root_name)
300
+
301
+
302
  def main():
303
  st.title("Math Genealogy Ancestor Tree")
304
+ st.write("Interactive visualization of academic advisor relationships from the Mathematics Genealogy Project")
305
+
306
  mathematicians = [
307
  ("Tristan Hearn", 162833),
308
  ("Alexander Grothendieck", 31245),
 
314
  names = [f"{name} ({mid})" for name, mid in mathematicians]
315
  default_index = 0 # Tristan Hearn
316
 
317
+ # initialize session state
318
  if "mgp_id_str" not in st.session_state:
319
  st.session_state["mgp_id_str"] = str(mathematicians[default_index][1])
320
+ if "graph_data" not in st.session_state:
321
+ st.session_state["graph_data"] = None
322
+ if "root_id" not in st.session_state:
323
+ st.session_state["root_id"] = None
324
 
325
+ # input section
326
+ st.subheader("Select Mathematician")
327
+
328
  mgp_id_str = st.text_input(
329
  "Enter MGP ID (integer):",
 
330
  key="mgp_id_str",
331
  help="You can type a custom ID or use the selection below."
332
  )
 
344
  )
345
 
346
  progress_placeholder = st.empty()
347
+
348
+ # fetch data
349
+ run_btn = st.button("Fetch Ancestor Tree", type="primary")
350
  if run_btn:
351
  mgp_id = get_id_from_input(st.session_state["mgp_id_str"])
352
  if mgp_id is None:
353
  st.error("Please enter a valid integer MGP ID.")
354
  return
355
+
356
  payload = make_payload(mgp_id)
357
  loop = asyncio.new_event_loop()
358
  asyncio.set_event_loop(loop)
359
+
360
  def progress_cb(progress):
361
  progress_placeholder.info(
362
  f"Queued: {progress['queued']} | Fetching: {progress['fetching']} | Done: {progress['done']}"
363
  )
364
+
365
  async def runner():
366
  graph = await get_graph(payload, progress_cb)
367
+ st.session_state["graph_data"] = graph
368
+ st.session_state["root_id"] = mgp_id
369
+
370
  try:
371
  loop.run_until_complete(runner())
372
+ progress_placeholder.success("Data fetched successfully!")
373
  except Exception as e:
374
  print(f"Error: {e}")
375
  progress_placeholder.error(f"Error: {e}")
376
+ return
377
+
378
+ # display visualizations if data is available
379
+ if st.session_state["graph_data"] is not None:
380
+ graph = st.session_state["graph_data"]
381
+ root_id = st.session_state["root_id"]
382
+
383
+ st.divider()
384
+
385
+ # show summary
386
+ display_tree_summary(graph, root_id)
387
+
388
+ st.divider()
389
+
390
+ # visualization options
391
+ st.subheader("Choose Visualization")
392
+
393
+ viz_option = st.radio(
394
+ "Select visualization type:",
395
+ ["Interactive Hierarchical Tree", "Interactive Network", "Traditional Graph (Graphviz)"],
396
+ help="Different views for exploring the genealogy tree"
397
+ )
398
+
399
+ if viz_option == "Interactive Hierarchical Tree":
400
+ st.write("**Hierarchical Tree View** - Best for exploring direct lineages")
401
+
402
+ # depth filter
403
+ tree = build_tree_structure(graph, root_id)
404
+ max_available_depth = max(node["depth"] for node in tree.values()) if tree else 0
405
+
406
+ if max_available_depth > 0:
407
+ depth_filter = st.slider(
408
+ "Show generations back:",
409
+ min_value=0,
410
+ max_value=max_available_depth,
411
+ value=min(3, max_available_depth),
412
+ help="Limit the number of generations to display for better readability"
413
+ )
414
+ else:
415
+ depth_filter = 0
416
+
417
+ # create hierarchical view
418
+ nodes_list, edges_list = create_hierarchical_view(graph, root_id, depth_filter)
419
+
420
+ if nodes_list:
421
+ config = Config(
422
+ width=800,
423
+ height=600,
424
+ directed=True,
425
+ physics=True,
426
+ hierarchical=True,
427
+ nodeHighlightBehavior=True,
428
+ highlightColor="#F7A7A6",
429
+ collapsible=False
430
+ )
431
+
432
+ agraph(nodes=nodes_list, edges=edges_list, config=config)
433
+ else:
434
+ st.warning("No data to display with current filters.")
435
+
436
+ elif viz_option == "Interactive Network":
437
+ st.write("**Interactive Network View** - Explore with zoom, pan, and physics simulation")
438
+
439
+ with st.spinner("Generating interactive network..."):
440
+ html_content = create_pyvis_network(graph, root_id)
441
+ st.components.v1.html(html_content, height=650)
442
+
443
+ else: # Traditional Graph
444
+ st.write("**Traditional Graph View** - Standard graphviz layout")
445
+ dot = tree_to_dot(graph)
446
+ st.graphviz_chart(dot)
447
+
448
+ # search functionality
449
+ st.divider()
450
+ st.subheader("Search Mathematicians")
451
+
452
+ nodes = graph.get("nodes", {})
453
+ search_term = st.text_input("Search by name:", placeholder="e.g., Gauss, Euler, Newton")
454
+
455
+ if search_term:
456
+ matches = []
457
+ for node_id, node in nodes.items():
458
+ name = node.get("name", "")
459
+ if search_term.lower() in name.lower():
460
+ year = node.get("year", "N/A")
461
+ institution = node.get("institution", "N/A")
462
+ matches.append({
463
+ "id": node_id,
464
+ "name": name,
465
+ "year": year,
466
+ "institution": institution
467
+ })
468
+
469
+ if matches:
470
+ st.write(f"Found {len(matches)} match(es):")
471
+ for match in matches[:10]: # limit to 10 results
472
+ st.write(f"• **{match['name']}** ({match['year']}) - {match['institution']} (ID: {match['id']})")
473
+ if len(matches) > 10:
474
+ st.write(f"... and {len(matches) - 10} more")
475
+ else:
476
+ st.write("No matches found.")
477
 
478
  if __name__ == "__main__":
479
  main()
pyproject.toml CHANGED
@@ -11,6 +11,8 @@ python = ">=3.9,<3.9.7 || >3.9.7,<4.0"
11
  streamlit = "^1.35.0"
12
  aiohttp = "^3.12.0"
13
  beautifulsoup4 = "^4.13.4"
 
 
14
 
15
  [tool.poetry.group.dev.dependencies]
16
  pytest = "^8.2.0"
 
11
  streamlit = "^1.35.0"
12
  aiohttp = "^3.12.0"
13
  beautifulsoup4 = "^4.13.4"
14
+ streamlit-agraph = "^0.0.45"
15
+ pyvis = "^0.3.2"
16
 
17
  [tool.poetry.group.dev.dependencies]
18
  pytest = "^8.2.0"