SorrelC commited on
Commit
bca4782
Β·
verified Β·
1 Parent(s): c83150a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +452 -186
app.py CHANGED
@@ -1,8 +1,7 @@
1
  import gradio as gr
2
  import networkx as nx
3
- from pyvis.network import Network
4
- import tempfile
5
- import os
6
 
7
  # Entity type colors
8
  ENTITY_COLORS = {
@@ -26,6 +25,11 @@ RELATIONSHIP_TYPES = [
26
  'knows',
27
  'related_to',
28
  'collaborates_with',
 
 
 
 
 
29
  'other'
30
  ]
31
 
@@ -36,20 +40,23 @@ class NetworkGraphBuilder:
36
 
37
  def add_entity(self, name, entity_type, record_id):
38
  """Add an entity to the collection"""
39
- if name.strip():
40
- self.entities.append({
41
- 'name': name.strip(),
42
- 'type': entity_type,
43
- 'record_id': record_id
44
- })
 
 
 
45
 
46
  def add_relationship(self, source, target, rel_type):
47
  """Add a relationship between entities"""
48
- if source and target and source != target:
49
  self.relationships.append({
50
  'source': source.strip(),
51
  'target': target.strip(),
52
- 'type': rel_type
53
  })
54
 
55
  def build_graph(self):
@@ -75,88 +82,153 @@ class NetworkGraphBuilder:
75
 
76
  return G
77
 
78
- def create_pyvis_graph(self, G):
79
- """Create interactive pyvis visualization"""
80
  if len(G.nodes) == 0:
81
  return None
82
 
83
- # Create pyvis network
84
- net = Network(height="600px", width="100%", bgcolor="#fafafa", font_color="#333")
85
- net.set_options("""
86
- {
87
- "physics": {
88
- "enabled": true,
89
- "barnesHut": {
90
- "gravitationalConstant": -8000,
91
- "springLength": 150,
92
- "springConstant": 0.04
93
- }
94
- },
95
- "nodes": {
96
- "font": {
97
- "size": 16
98
- }
99
- }
100
- }
101
- """)
102
 
103
- # Add nodes
104
- for node, data in G.nodes(data=True):
105
- entity_type = data.get('entity_type', 'UNKNOWN')
106
- color = ENTITY_COLORS.get(entity_type, '#CCCCCC')
107
-
108
- # Node size based on degree
109
- degree = G.degree(node)
110
- size = 20 + (degree * 5)
111
-
112
- # Create title (tooltip)
113
- connections = list(G.neighbors(node))
114
- title = f"{node}\nType: {entity_type}\nConnections: {len(connections)}"
115
- if connections:
116
- title += f"\nConnected to: {', '.join(connections[:5])}"
117
- if len(connections) > 5:
118
- title += f"... +{len(connections) - 5} more"
119
-
120
- net.add_node(node, label=node, color=color, size=size, title=title)
121
 
122
- # Add edges
123
  for edge in G.edges(data=True):
124
- rel_type = edge[2].get('relationship', '')
125
- net.add_edge(edge[0], edge[1], title=rel_type, label=rel_type)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
- # Save to temporary file
128
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.html', mode='w')
129
- net.save_graph(temp_file.name)
130
- temp_file.close()
 
 
 
 
 
131
 
132
- # Read the HTML content
133
- with open(temp_file.name, 'r', encoding='utf-8') as f:
134
- html_content = f.read()
135
 
136
- # Clean up
137
- os.unlink(temp_file.name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
- return html_content
140
 
141
- def collect_entities_from_records(*args):
 
 
 
 
 
 
 
 
142
  """Collect all entities from the input fields"""
143
  builder = NetworkGraphBuilder()
144
 
145
- # Each record has 5 entity fields (person, location, event, org, date)
146
- num_records = 6
147
- fields_per_record = 5
 
 
 
 
 
 
148
 
149
- for i in range(num_records):
150
- record_id = i + 1
151
- base_idx = i * fields_per_record
152
-
153
- # Extract entities for this record
154
- person = args[base_idx] if base_idx < len(args) else ""
155
- location = args[base_idx + 1] if base_idx + 1 < len(args) else ""
156
- event = args[base_idx + 2] if base_idx + 2 < len(args) else ""
157
- org = args[base_idx + 3] if base_idx + 3 < len(args) else ""
158
- date = args[base_idx + 4] if base_idx + 4 < len(args) else ""
159
-
160
  if person:
161
  builder.add_entity(person, 'PERSON', record_id)
162
  if location:
@@ -169,44 +241,78 @@ def collect_entities_from_records(*args):
169
  builder.add_entity(date, 'DATE', record_id)
170
 
171
  # Create list of all entity names for relationship dropdowns
172
- entity_names = [e['name'] for e in builder.entities]
 
 
 
 
 
 
 
 
 
173
 
174
  # Create summary
175
  summary = f"""
176
- ### πŸ“Š Identified Entities
177
- - **Total entities:** {len(builder.entities)}
178
- - **People:** {sum(1 for e in builder.entities if e['type'] == 'PERSON')}
179
- - **Locations:** {sum(1 for e in builder.entities if e['type'] == 'LOCATION')}
180
- - **Events:** {sum(1 for e in builder.entities if e['type'] == 'EVENT')}
181
- - **Organizations:** {sum(1 for e in builder.entities if e['type'] == 'ORGANIZATION')}
182
- - **Dates:** {sum(1 for e in builder.entities if e['type'] == 'DATE')}
 
 
 
183
 
184
- Now define relationships between these entities on the right β†’
185
  """
186
 
187
- # Return summary and update all dropdowns
188
- dropdown_updates = [gr.update(choices=entity_names, value=None)] * 10
189
- return [summary] + dropdown_updates
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
- def generate_network_graph(*args):
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  """Generate the network graph from all inputs"""
193
  try:
194
  builder = NetworkGraphBuilder()
195
 
196
- # Collect entities
197
- num_records = 6
198
- fields_per_record = 5
199
-
200
- for i in range(num_records):
201
- record_id = i + 1
202
- base_idx = i * fields_per_record
203
-
204
- person = args[base_idx] if base_idx < len(args) else ""
205
- location = args[base_idx + 1] if base_idx + 1 < len(args) else ""
206
- event = args[base_idx + 2] if base_idx + 2 < len(args) else ""
207
- org = args[base_idx + 3] if base_idx + 3 < len(args) else ""
208
- date = args[base_idx + 4] if base_idx + 4 < len(args) else ""
209
-
210
  if person:
211
  builder.add_entity(person, 'PERSON', record_id)
212
  if location:
@@ -218,16 +324,16 @@ def generate_network_graph(*args):
218
  if date:
219
  builder.add_entity(date, 'DATE', record_id)
220
 
221
- # Collect relationships
222
- relationship_start = 30
223
- num_relationships = 5
 
 
 
 
 
224
 
225
- for i in range(num_relationships):
226
- base_idx = relationship_start + (i * 3)
227
- source = args[base_idx] if base_idx < len(args) else None
228
- rel_type = args[base_idx + 1] if base_idx + 1 < len(args) else None
229
- target = args[base_idx + 2] if base_idx + 2 < len(args) else None
230
-
231
  if source and target:
232
  builder.add_relationship(source, target, rel_type)
233
 
@@ -235,61 +341,128 @@ def generate_network_graph(*args):
235
  G = builder.build_graph()
236
 
237
  if len(G.nodes) == 0:
238
- return None, "❌ **No entities to display.** Please enter entities in Step 1 and click 'Identify Entities' first."
 
 
 
 
 
 
 
 
 
 
 
 
239
 
240
  # Create visualization
241
- html_graph = builder.create_pyvis_graph(G)
242
 
243
  # Create statistics
244
  stats = f"""
245
  ### πŸ“ˆ Network Statistics
246
- - **Nodes (Entities):** {G.number_of_nodes()}
247
- - **Edges (Relationships):** {G.number_of_edges()}
 
 
248
  """
249
 
250
  if len(G.edges) == 0:
251
- stats += "\n⚠️ **No relationships defined** - showing isolated nodes only.\n"
252
  else:
253
- stats += f"- **Network Density:** {nx.density(G):.3f}\n"
254
- stats += f"- **Average Connections per Node:** {sum(dict(G.degree()).values()) / G.number_of_nodes():.2f}\n"
 
 
255
 
256
  # Find most connected nodes
257
  degrees = dict(G.degree())
258
  top_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)[:3]
259
  stats += "\n**Most Connected Entities:**\n"
260
  for node, degree in top_nodes:
261
- stats += f"- {node}: {degree} connections\n"
262
 
263
- return html_graph, stats
264
 
265
  except Exception as e:
266
  import traceback
267
  error_trace = traceback.format_exc()
 
 
 
 
 
 
 
 
 
 
268
  error_msg = f"""
269
  ### ❌ Error Generating Graph
270
 
271
  {str(e)}
272
 
273
- **Technical details:**
 
 
274
  ```
275
  {error_trace}
276
  ```
 
277
  """
278
- return None, error_msg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
 
280
  def create_interface():
281
  with gr.Blocks(title="Basic Network Explorer", theme=gr.themes.Soft()) as demo:
282
  gr.Markdown("""
283
- # Basic Network Explorer
284
 
285
- Build interactive social network graphs by entering entities extracted through Named Entity Recognition (NER).
 
286
 
287
  ### How to use this tool:
288
- 1. **πŸ“ Enter entities** in the records on the left
289
- 2. **πŸ”— Click "Identify Entities"** to populate the dropdowns
290
- 3. **🀝 Define relationships** on the right
291
  4. **🎨 Click "Generate Network Graph"** to visualize
292
- 5. **πŸ‘οΈ Explore** - drag nodes, zoom, hover for details
293
  """)
294
 
295
  gr.HTML("""
@@ -298,80 +471,142 @@ def create_interface():
298
  </div>
299
  """)
300
 
 
301
  entity_inputs = []
302
 
 
 
 
303
  with gr.Row():
304
- # LEFT: Entity Inputs
 
 
 
 
 
 
305
  with gr.Column(scale=1):
306
- with gr.Accordion("πŸ“š Step 1: Enter Entities", open=True):
 
 
307
  for i in range(4):
308
  with gr.Group():
309
  gr.Markdown(f"**Record {i+1}**")
310
- person = gr.Textbox(label="πŸ‘€ Person", placeholder="e.g., Winston Churchill")
311
- location = gr.Textbox(label="πŸ“ Location", placeholder="e.g., London")
312
- event = gr.Textbox(label="πŸ“… Event", placeholder="e.g., Battle of Britain")
313
- org = gr.Textbox(label="🏒 Organization", placeholder="e.g., War Cabinet")
314
- date = gr.Textbox(label="πŸ—“οΈ Date", placeholder="e.g., 1940")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  entity_inputs.extend([person, location, event, org, date])
316
-
317
- with gr.Accordion("βž• Records 5-6", open=False):
318
- for i in range(4, 6):
319
- with gr.Group():
320
- gr.Markdown(f"**Record {i+1}**")
321
- person = gr.Textbox(label="πŸ‘€ Person")
322
- location = gr.Textbox(label="πŸ“ Location")
323
- event = gr.Textbox(label="πŸ“… Event")
324
- org = gr.Textbox(label="🏒 Organization")
325
- date = gr.Textbox(label="πŸ—“οΈ Date")
326
- entity_inputs.extend([person, location, event, org, date])
327
 
328
  collect_btn = gr.Button("πŸ” Identify Entities", variant="primary", size="lg")
329
  entity_summary = gr.Markdown()
330
 
331
- # RIGHT: Relationships & Graph
332
  with gr.Column(scale=1):
333
- with gr.Accordion("🀝 Step 2: Define Relationships", open=True):
334
- relationship_inputs = []
335
- for i in range(5):
336
- with gr.Row():
337
- source = gr.Dropdown(label="From", choices=[], scale=2)
338
- rel_type = gr.Dropdown(label="Type", choices=RELATIONSHIP_TYPES, value="related_to", scale=2)
339
- target = gr.Dropdown(label="To", choices=[], scale=2)
340
- relationship_inputs.extend([source, rel_type, target])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
 
342
- generate_btn = gr.Button("πŸ” Generate Network Graph", variant="primary", size="lg")
343
 
344
- gr.HTML("<hr style='margin: 20px 0;'>")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
  network_stats = gr.Markdown()
346
- network_plot = gr.HTML(label="Interactive Network Graph")
347
-
348
- # Examples
349
- gr.Markdown("### πŸ’‘ Try an example:")
350
- gr.Examples(
351
- examples=[[
352
- "Winston Churchill", "London", "Battle of Britain", "War Cabinet", "1940",
353
- "Clement Attlee", "London", "Potsdam Conference", "Labour Party", "1945",
354
- "Field Marshal Montgomery", "North Africa", "Battle of El Alamein", "Eighth Army", "1942",
355
- "Winston Churchill", "Yalta", "Yalta Conference", "War Cabinet", "February 1945",
356
- "", "", "", "", "",
357
- "", "", "", "", "",
358
- "Winston Churchill", "works_with", "Clement Attlee",
359
- "Winston Churchill", "participated_in", "Battle of Britain",
360
- "Field Marshal Montgomery", "participated_in", "Battle of El Alamein",
361
- "", "", "",
362
- "", "", "",
363
- ]],
364
- inputs=entity_inputs + relationship_inputs,
365
- label="WWII Example"
366
  )
367
 
368
- # Wire up
369
  collect_btn.click(
370
  fn=collect_entities_from_records,
371
  inputs=entity_inputs,
372
- outputs=[entity_summary] + relationship_inputs[::3] + relationship_inputs[2::3]
 
 
 
 
 
 
 
373
  )
374
 
 
375
  all_inputs = entity_inputs + relationship_inputs
376
  generate_btn.click(
377
  fn=generate_network_graph,
@@ -379,16 +614,47 @@ def create_interface():
379
  outputs=[network_plot, network_stats]
380
  )
381
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
  gr.HTML("""
383
- <hr style="margin: 40px 0;">
384
- <div style="text-align: center; color: #666; font-size: 14px;">
385
- <p>Basic Network Explorer | Bodleian Libraries (University of Oxford) Sassoon Research Fellowship</p>
386
- <p>Built with the aid of Claude Sonnet 4.5</p>
387
  </div>
388
  """)
389
 
390
  return demo
391
 
 
392
  if __name__ == "__main__":
393
  demo = create_interface()
394
  demo.launch()
 
1
  import gradio as gr
2
  import networkx as nx
3
+ import plotly.graph_objects as go
4
+ import numpy as np
 
5
 
6
  # Entity type colors
7
  ENTITY_COLORS = {
 
25
  'knows',
26
  'related_to',
27
  'collaborates_with',
28
+ 'married_to',
29
+ 'sibling_of',
30
+ 'parent_of',
31
+ 'lives_at',
32
+ 'wrote',
33
  'other'
34
  ]
35
 
 
40
 
41
  def add_entity(self, name, entity_type, record_id):
42
  """Add an entity to the collection"""
43
+ if name and name.strip():
44
+ # Avoid duplicates
45
+ existing = [e for e in self.entities if e['name'].lower() == name.strip().lower()]
46
+ if not existing:
47
+ self.entities.append({
48
+ 'name': name.strip(),
49
+ 'type': entity_type,
50
+ 'record_id': record_id
51
+ })
52
 
53
  def add_relationship(self, source, target, rel_type):
54
  """Add a relationship between entities"""
55
+ if source and target and source.strip() and target.strip() and source.strip() != target.strip():
56
  self.relationships.append({
57
  'source': source.strip(),
58
  'target': target.strip(),
59
+ 'type': rel_type if rel_type else 'related_to'
60
  })
61
 
62
  def build_graph(self):
 
82
 
83
  return G
84
 
85
+ def create_plotly_graph(self, G):
86
+ """Create interactive Plotly visualization"""
87
  if len(G.nodes) == 0:
88
  return None
89
 
90
+ # Use spring layout for positioning
91
+ pos = nx.spring_layout(G, k=2, iterations=50, seed=42)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
+ # Create edge traces
94
+ edge_x = []
95
+ edge_y = []
96
+ edge_labels = []
97
+ edge_mid_x = []
98
+ edge_mid_y = []
 
 
 
 
 
 
 
 
 
 
 
 
99
 
 
100
  for edge in G.edges(data=True):
101
+ x0, y0 = pos[edge[0]]
102
+ x1, y1 = pos[edge[1]]
103
+ edge_x.extend([x0, x1, None])
104
+ edge_y.extend([y0, y1, None])
105
+
106
+ # Midpoint for edge label
107
+ mid_x = (x0 + x1) / 2
108
+ mid_y = (y0 + y1) / 2
109
+ edge_mid_x.append(mid_x)
110
+ edge_mid_y.append(mid_y)
111
+ edge_labels.append(edge[2].get('relationship', ''))
112
+
113
+ edge_trace = go.Scatter(
114
+ x=edge_x, y=edge_y,
115
+ line=dict(width=2, color='#888'),
116
+ hoverinfo='none',
117
+ mode='lines'
118
+ )
119
 
120
+ # Edge labels trace
121
+ edge_label_trace = go.Scatter(
122
+ x=edge_mid_x, y=edge_mid_y,
123
+ mode='text',
124
+ text=edge_labels,
125
+ textposition='middle center',
126
+ textfont=dict(size=10, color='#666'),
127
+ hoverinfo='none'
128
+ )
129
 
130
+ # Create node traces - one per entity type for legend
131
+ node_traces = []
 
132
 
133
+ for entity_type, color in ENTITY_COLORS.items():
134
+ nodes_of_type = [n for n, d in G.nodes(data=True) if d.get('entity_type') == entity_type]
135
+ if not nodes_of_type:
136
+ continue
137
+
138
+ node_x = []
139
+ node_y = []
140
+ node_text = []
141
+ node_hover = []
142
+ node_sizes = []
143
+
144
+ for node in nodes_of_type:
145
+ x, y = pos[node]
146
+ node_x.append(x)
147
+ node_y.append(y)
148
+ node_text.append(node)
149
+
150
+ # Calculate connections
151
+ connections = list(G.neighbors(node))
152
+ degree = len(connections)
153
+ node_sizes.append(30 + (degree * 10))
154
+
155
+ # Hover text
156
+ hover = f"<b>{node}</b><br>Type: {entity_type}<br>Connections: {degree}"
157
+ if connections:
158
+ hover += f"<br>Connected to: {', '.join(connections[:5])}"
159
+ if len(connections) > 5:
160
+ hover += f"<br>... +{len(connections) - 5} more"
161
+ node_hover.append(hover)
162
+
163
+ node_trace = go.Scatter(
164
+ x=node_x, y=node_y,
165
+ mode='markers+text',
166
+ name=entity_type,
167
+ text=node_text,
168
+ textposition='top center',
169
+ textfont=dict(size=12, color='#333'),
170
+ hoverinfo='text',
171
+ hovertext=node_hover,
172
+ marker=dict(
173
+ size=node_sizes,
174
+ color=color,
175
+ line=dict(width=2, color='white'),
176
+ symbol='circle'
177
+ )
178
+ )
179
+ node_traces.append(node_trace)
180
+
181
+ # Create figure
182
+ fig = go.Figure(
183
+ data=[edge_trace, edge_label_trace] + node_traces,
184
+ layout=go.Layout(
185
+ title=dict(
186
+ text='Interactive Network Graph',
187
+ font=dict(size=20)
188
+ ),
189
+ showlegend=True,
190
+ legend=dict(
191
+ yanchor="top",
192
+ y=0.99,
193
+ xanchor="left",
194
+ x=0.01,
195
+ bgcolor="rgba(255,255,255,0.8)"
196
+ ),
197
+ hovermode='closest',
198
+ margin=dict(b=20, l=5, r=5, t=50),
199
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
200
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
201
+ plot_bgcolor='#fafafa',
202
+ paper_bgcolor='#fafafa',
203
+ height=600
204
+ )
205
+ )
206
 
207
+ return fig
208
 
209
+
210
+ def collect_entities_from_records(
211
+ p1, l1, e1, o1, d1,
212
+ p2, l2, e2, o2, d2,
213
+ p3, l3, e3, o3, d3,
214
+ p4, l4, e4, o4, d4,
215
+ p5, l5, e5, o5, d5,
216
+ p6, l6, e6, o6, d6
217
+ ):
218
  """Collect all entities from the input fields"""
219
  builder = NetworkGraphBuilder()
220
 
221
+ # Process each record
222
+ records = [
223
+ (p1, l1, e1, o1, d1),
224
+ (p2, l2, e2, o2, d2),
225
+ (p3, l3, e3, o3, d3),
226
+ (p4, l4, e4, o4, d4),
227
+ (p5, l5, e5, o5, d5),
228
+ (p6, l6, e6, o6, d6),
229
+ ]
230
 
231
+ for record_id, (person, location, event, org, date) in enumerate(records, 1):
 
 
 
 
 
 
 
 
 
 
232
  if person:
233
  builder.add_entity(person, 'PERSON', record_id)
234
  if location:
 
241
  builder.add_entity(date, 'DATE', record_id)
242
 
243
  # Create list of all entity names for relationship dropdowns
244
+ entity_names = sorted([e['name'] for e in builder.entities])
245
+
246
+ # Count by type
247
+ counts = {
248
+ 'PERSON': sum(1 for e in builder.entities if e['type'] == 'PERSON'),
249
+ 'LOCATION': sum(1 for e in builder.entities if e['type'] == 'LOCATION'),
250
+ 'EVENT': sum(1 for e in builder.entities if e['type'] == 'EVENT'),
251
+ 'ORGANIZATION': sum(1 for e in builder.entities if e['type'] == 'ORGANIZATION'),
252
+ 'DATE': sum(1 for e in builder.entities if e['type'] == 'DATE'),
253
+ }
254
 
255
  # Create summary
256
  summary = f"""
257
+ ### πŸ“Š Identified Entities ({len(builder.entities)} total)
258
+ | Type | Count |
259
+ |------|-------|
260
+ | πŸ‘€ People | {counts['PERSON']} |
261
+ | πŸ“ Locations | {counts['LOCATION']} |
262
+ | πŸ“… Events | {counts['EVENT']} |
263
+ | 🏒 Organizations | {counts['ORGANIZATION']} |
264
+ | πŸ—“οΈ Dates | {counts['DATE']} |
265
+
266
+ **Entities found:** {', '.join(entity_names) if entity_names else 'None'}
267
 
268
+ ➑️ Now define relationships between these entities below
269
  """
270
 
271
+ # Return summary and update all 10 dropdowns (5 source + 5 target)
272
+ # Each dropdown gets updated with the entity names
273
+ return (
274
+ summary,
275
+ gr.update(choices=entity_names, value=None), # source 1
276
+ gr.update(choices=entity_names, value=None), # target 1
277
+ gr.update(choices=entity_names, value=None), # source 2
278
+ gr.update(choices=entity_names, value=None), # target 2
279
+ gr.update(choices=entity_names, value=None), # source 3
280
+ gr.update(choices=entity_names, value=None), # target 3
281
+ gr.update(choices=entity_names, value=None), # source 4
282
+ gr.update(choices=entity_names, value=None), # target 4
283
+ gr.update(choices=entity_names, value=None), # source 5
284
+ gr.update(choices=entity_names, value=None), # target 5
285
+ )
286
 
287
+
288
+ def generate_network_graph(
289
+ p1, l1, e1, o1, d1,
290
+ p2, l2, e2, o2, d2,
291
+ p3, l3, e3, o3, d3,
292
+ p4, l4, e4, o4, d4,
293
+ p5, l5, e5, o5, d5,
294
+ p6, l6, e6, o6, d6,
295
+ src1, rel1, tgt1,
296
+ src2, rel2, tgt2,
297
+ src3, rel3, tgt3,
298
+ src4, rel4, tgt4,
299
+ src5, rel5, tgt5
300
+ ):
301
  """Generate the network graph from all inputs"""
302
  try:
303
  builder = NetworkGraphBuilder()
304
 
305
+ # Process each record
306
+ records = [
307
+ (p1, l1, e1, o1, d1),
308
+ (p2, l2, e2, o2, d2),
309
+ (p3, l3, e3, o3, d3),
310
+ (p4, l4, e4, o4, d4),
311
+ (p5, l5, e5, o5, d5),
312
+ (p6, l6, e6, o6, d6),
313
+ ]
314
+
315
+ for record_id, (person, location, event, org, date) in enumerate(records, 1):
 
 
 
316
  if person:
317
  builder.add_entity(person, 'PERSON', record_id)
318
  if location:
 
324
  if date:
325
  builder.add_entity(date, 'DATE', record_id)
326
 
327
+ # Process relationships
328
+ relationships = [
329
+ (src1, rel1, tgt1),
330
+ (src2, rel2, tgt2),
331
+ (src3, rel3, tgt3),
332
+ (src4, rel4, tgt4),
333
+ (src5, rel5, tgt5),
334
+ ]
335
 
336
+ for source, rel_type, target in relationships:
 
 
 
 
 
337
  if source and target:
338
  builder.add_relationship(source, target, rel_type)
339
 
 
341
  G = builder.build_graph()
342
 
343
  if len(G.nodes) == 0:
344
+ empty_fig = go.Figure()
345
+ empty_fig.add_annotation(
346
+ text="No entities to display.<br>Please enter entities in Step 1 and click 'Identify Entities' first.",
347
+ xref="paper", yref="paper",
348
+ x=0.5, y=0.5, showarrow=False,
349
+ font=dict(size=16, color="#666")
350
+ )
351
+ empty_fig.update_layout(
352
+ height=400,
353
+ plot_bgcolor='#fafafa',
354
+ paper_bgcolor='#fafafa'
355
+ )
356
+ return empty_fig, "❌ **No entities to display.** Please enter entities in Step 1 first."
357
 
358
  # Create visualization
359
+ fig = builder.create_plotly_graph(G)
360
 
361
  # Create statistics
362
  stats = f"""
363
  ### πŸ“ˆ Network Statistics
364
+ | Metric | Value |
365
+ |--------|-------|
366
+ | **Nodes (Entities)** | {G.number_of_nodes()} |
367
+ | **Edges (Relationships)** | {G.number_of_edges()} |
368
  """
369
 
370
  if len(G.edges) == 0:
371
+ stats += "\n⚠️ **No relationships defined** - showing isolated nodes. Add relationships in Step 2!\n"
372
  else:
373
+ density = nx.density(G)
374
+ avg_degree = sum(dict(G.degree()).values()) / G.number_of_nodes()
375
+ stats += f"| **Network Density** | {density:.3f} |\n"
376
+ stats += f"| **Avg. Connections** | {avg_degree:.2f} |\n"
377
 
378
  # Find most connected nodes
379
  degrees = dict(G.degree())
380
  top_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)[:3]
381
  stats += "\n**Most Connected Entities:**\n"
382
  for node, degree in top_nodes:
383
+ stats += f"- **{node}**: {degree} connections\n"
384
 
385
+ return fig, stats
386
 
387
  except Exception as e:
388
  import traceback
389
  error_trace = traceback.format_exc()
390
+
391
+ error_fig = go.Figure()
392
+ error_fig.add_annotation(
393
+ text=f"Error: {str(e)}",
394
+ xref="paper", yref="paper",
395
+ x=0.5, y=0.5, showarrow=False,
396
+ font=dict(size=14, color="red")
397
+ )
398
+ error_fig.update_layout(height=300, plot_bgcolor='#fafafa', paper_bgcolor='#fafafa')
399
+
400
  error_msg = f"""
401
  ### ❌ Error Generating Graph
402
 
403
  {str(e)}
404
 
405
+ <details>
406
+ <summary>Technical details</summary>
407
+
408
  ```
409
  {error_trace}
410
  ```
411
+ </details>
412
  """
413
+ return error_fig, error_msg
414
+
415
+
416
+ def load_austen_example():
417
+ """Load the Jane Austen Pride and Prejudice example"""
418
+ return (
419
+ # Record 1
420
+ "Elizabeth Bennet", "Longbourn", "Meryton Ball", "Bennet Family", "1811",
421
+ # Record 2
422
+ "Mr. Darcy", "Pemberley", "Netherfield Ball", "Darcy Estate", "1811",
423
+ # Record 3
424
+ "Jane Bennet", "Netherfield", "", "Bennet Family", "",
425
+ # Record 4
426
+ "Mr. Bingley", "Netherfield", "", "", "",
427
+ # Record 5
428
+ "Mr. Wickham", "Meryton", "", "Militia", "",
429
+ # Record 6
430
+ "Charlotte Lucas", "Hunsford", "", "", "",
431
+ )
432
+
433
+
434
+ def load_wwii_example():
435
+ """Load a WWII history example"""
436
+ return (
437
+ # Record 1
438
+ "Winston Churchill", "London", "Battle of Britain", "War Cabinet", "1940",
439
+ # Record 2
440
+ "Clement Attlee", "London", "Potsdam Conference", "Labour Party", "1945",
441
+ # Record 3
442
+ "Field Marshal Montgomery", "North Africa", "Battle of El Alamein", "Eighth Army", "1942",
443
+ # Record 4
444
+ "Franklin D. Roosevelt", "Washington D.C.", "D-Day", "Allied Forces", "1944",
445
+ # Record 5
446
+ "", "", "", "", "",
447
+ # Record 6
448
+ "", "", "", "", "",
449
+ )
450
+
451
 
452
  def create_interface():
453
  with gr.Blocks(title="Basic Network Explorer", theme=gr.themes.Soft()) as demo:
454
  gr.Markdown("""
455
+ # πŸ•ΈοΈ Basic Network Explorer
456
 
457
+ Build interactive network graphs by entering entities extracted through Named Entity Recognition (NER).
458
+ Explore relationships between people, places, events, organizations and dates.
459
 
460
  ### How to use this tool:
461
+ 1. **πŸ“ Enter entities** in the records below (or load an example)
462
+ 2. **πŸ” Click "Identify Entities"** to collect and list all entities
463
+ 3. **🀝 Define relationships** between entities using the dropdowns
464
  4. **🎨 Click "Generate Network Graph"** to visualize
465
+ 5. **πŸ‘οΈ Explore** - hover over nodes for details, zoom and pan the graph
466
  """)
467
 
468
  gr.HTML("""
 
471
  </div>
472
  """)
473
 
474
+ # Store all entity input components
475
  entity_inputs = []
476
 
477
+ # Example buttons
478
+ with gr.Row():
479
+ gr.Markdown("### πŸ’‘ Quick Start - Load an Example:")
480
  with gr.Row():
481
+ austen_btn = gr.Button("πŸ“š Jane Austen (Pride & Prejudice)", variant="secondary")
482
+ wwii_btn = gr.Button("βš”οΈ WWII History", variant="secondary")
483
+
484
+ gr.HTML("<hr style='margin: 20px 0;'>")
485
+
486
+ with gr.Row():
487
+ # LEFT COLUMN: Entity Inputs
488
  with gr.Column(scale=1):
489
+ gr.Markdown("## πŸ“š Step 1: Enter Entities")
490
+
491
+ with gr.Accordion("Records 1-4", open=True):
492
  for i in range(4):
493
  with gr.Group():
494
  gr.Markdown(f"**Record {i+1}**")
495
+ with gr.Row():
496
+ person = gr.Textbox(label="πŸ‘€ Person", placeholder="e.g., Elizabeth Bennet", scale=1)
497
+ location = gr.Textbox(label="πŸ“ Location", placeholder="e.g., Longbourn", scale=1)
498
+ with gr.Row():
499
+ event = gr.Textbox(label="πŸ“… Event", placeholder="e.g., Meryton Ball", scale=1)
500
+ org = gr.Textbox(label="🏒 Organization", placeholder="e.g., Bennet Family", scale=1)
501
+ date = gr.Textbox(label="πŸ—“οΈ Date", placeholder="e.g., 1811")
502
+ entity_inputs.extend([person, location, event, org, date])
503
+
504
+ with gr.Accordion("Records 5-6 (Optional)", open=False):
505
+ for i in range(4, 6):
506
+ with gr.Group():
507
+ gr.Markdown(f"**Record {i+1}**")
508
+ with gr.Row():
509
+ person = gr.Textbox(label="πŸ‘€ Person", scale=1)
510
+ location = gr.Textbox(label="πŸ“ Location", scale=1)
511
+ with gr.Row():
512
+ event = gr.Textbox(label="πŸ“… Event", scale=1)
513
+ org = gr.Textbox(label="🏒 Organization", scale=1)
514
+ date = gr.Textbox(label="πŸ—“οΈ Date")
515
  entity_inputs.extend([person, location, event, org, date])
 
 
 
 
 
 
 
 
 
 
 
516
 
517
  collect_btn = gr.Button("πŸ” Identify Entities", variant="primary", size="lg")
518
  entity_summary = gr.Markdown()
519
 
520
+ # RIGHT COLUMN: Relationships
521
  with gr.Column(scale=1):
522
+ gr.Markdown("## 🀝 Step 2: Define Relationships")
523
+ gr.Markdown("*First click 'Identify Entities' to populate the dropdowns*")
524
+
525
+ # Create relationship inputs with explicit variable names
526
+ source1 = gr.Dropdown(label="From", choices=[], scale=2)
527
+ with gr.Row():
528
+ rel_type1 = gr.Dropdown(label="Relationship Type", choices=RELATIONSHIP_TYPES, value="related_to", scale=2)
529
+ target1 = gr.Dropdown(label="To", choices=[], scale=2)
530
+
531
+ gr.HTML("<hr style='margin: 10px 0; border-top: 1px dashed #ccc;'>")
532
+
533
+ source2 = gr.Dropdown(label="From", choices=[], scale=2)
534
+ with gr.Row():
535
+ rel_type2 = gr.Dropdown(label="Relationship Type", choices=RELATIONSHIP_TYPES, value="related_to", scale=2)
536
+ target2 = gr.Dropdown(label="To", choices=[], scale=2)
537
+
538
+ gr.HTML("<hr style='margin: 10px 0; border-top: 1px dashed #ccc;'>")
539
+
540
+ source3 = gr.Dropdown(label="From", choices=[], scale=2)
541
+ with gr.Row():
542
+ rel_type3 = gr.Dropdown(label="Relationship Type", choices=RELATIONSHIP_TYPES, value="related_to", scale=2)
543
+ target3 = gr.Dropdown(label="To", choices=[], scale=2)
544
 
545
+ gr.HTML("<hr style='margin: 10px 0; border-top: 1px dashed #ccc;'>")
546
 
547
+ source4 = gr.Dropdown(label="From", choices=[], scale=2)
548
+ with gr.Row():
549
+ rel_type4 = gr.Dropdown(label="Relationship Type", choices=RELATIONSHIP_TYPES, value="related_to", scale=2)
550
+ target4 = gr.Dropdown(label="To", choices=[], scale=2)
551
+
552
+ gr.HTML("<hr style='margin: 10px 0; border-top: 1px dashed #ccc;'>")
553
+
554
+ source5 = gr.Dropdown(label="From", choices=[], scale=2)
555
+ with gr.Row():
556
+ rel_type5 = gr.Dropdown(label="Relationship Type", choices=RELATIONSHIP_TYPES, value="related_to", scale=2)
557
+ target5 = gr.Dropdown(label="To", choices=[], scale=2)
558
+
559
+ # Collect relationship inputs
560
+ relationship_inputs = [
561
+ source1, rel_type1, target1,
562
+ source2, rel_type2, target2,
563
+ source3, rel_type3, target3,
564
+ source4, rel_type4, target4,
565
+ source5, rel_type5, target5
566
+ ]
567
+
568
+ gr.HTML("<hr style='margin: 30px 0;'>")
569
+
570
+ # Generate button
571
+ generate_btn = gr.Button("🎨 Generate Network Graph", variant="primary", size="lg")
572
+
573
+ # Output section
574
+ gr.Markdown("## πŸ“Š Step 3: View Results")
575
+
576
+ with gr.Row():
577
+ with gr.Column(scale=2):
578
+ network_plot = gr.Plot(label="Interactive Network Graph")
579
+ with gr.Column(scale=1):
580
  network_stats = gr.Markdown()
581
+
582
+ # Wire up the example buttons
583
+ austen_btn.click(
584
+ fn=load_austen_example,
585
+ inputs=[],
586
+ outputs=entity_inputs
587
+ )
588
+
589
+ wwii_btn.click(
590
+ fn=load_wwii_example,
591
+ inputs=[],
592
+ outputs=entity_inputs
 
 
 
 
 
 
 
 
593
  )
594
 
595
+ # Wire up collect entities button
596
  collect_btn.click(
597
  fn=collect_entities_from_records,
598
  inputs=entity_inputs,
599
+ outputs=[
600
+ entity_summary,
601
+ source1, target1,
602
+ source2, target2,
603
+ source3, target3,
604
+ source4, target4,
605
+ source5, target5
606
+ ]
607
  )
608
 
609
+ # Wire up generate graph button
610
  all_inputs = entity_inputs + relationship_inputs
611
  generate_btn.click(
612
  fn=generate_network_graph,
 
614
  outputs=[network_plot, network_stats]
615
  )
616
 
617
+ # Color legend
618
+ gr.HTML("""
619
+ <div style="background-color: #f8f9fa; padding: 15px; border-radius: 8px; margin-top: 20px;">
620
+ <h4 style="margin-top: 0;">🎨 Entity Color Legend</h4>
621
+ <div style="display: flex; flex-wrap: wrap; gap: 15px; align-items: center;">
622
+ <span style="display: flex; align-items: center; gap: 5px;">
623
+ <span style="width: 20px; height: 20px; border-radius: 50%; background-color: #00B894; display: inline-block;"></span>
624
+ <strong>Person</strong>
625
+ </span>
626
+ <span style="display: flex; align-items: center; gap: 5px;">
627
+ <span style="width: 20px; height: 20px; border-radius: 50%; background-color: #A0E7E5; display: inline-block;"></span>
628
+ <strong>Location</strong>
629
+ </span>
630
+ <span style="display: flex; align-items: center; gap: 5px;">
631
+ <span style="width: 20px; height: 20px; border-radius: 50%; background-color: #4ECDC4; display: inline-block;"></span>
632
+ <strong>Event</strong>
633
+ </span>
634
+ <span style="display: flex; align-items: center; gap: 5px;">
635
+ <span style="width: 20px; height: 20px; border-radius: 50%; background-color: #55A3FF; display: inline-block;"></span>
636
+ <strong>Organization</strong>
637
+ </span>
638
+ <span style="display: flex; align-items: center; gap: 5px;">
639
+ <span style="width: 20px; height: 20px; border-radius: 50%; background-color: #FF6B6B; display: inline-block;"></span>
640
+ <strong>Date</strong>
641
+ </span>
642
+ </div>
643
+ </div>
644
+ """)
645
+
646
+ # Footer
647
  gr.HTML("""
648
+ <hr style="margin: 40px 0 20px 0;">
649
+ <div style="text-align: center; color: #666; font-size: 14px; padding: 20px;">
650
+ <p><strong>Basic Network Explorer</strong> | Bodleian Libraries (University of Oxford) Sassoon Research Fellowship</p>
651
+ <p>Built with the aid of Claude</p>
652
  </div>
653
  """)
654
 
655
  return demo
656
 
657
+
658
  if __name__ == "__main__":
659
  demo = create_interface()
660
  demo.launch()