SorrelC's picture
Update app.py
bca4782 verified
raw
history blame
24.8 kB
import gradio as gr
import networkx as nx
import plotly.graph_objects as go
import numpy as np
# Node colour per entity type. The keys are the entity-type tags that the
# record-processing functions assign, and the plot draws one legend entry per
# key, in this order.
ENTITY_COLORS = dict(
    PERSON='#00B894',        # Green
    LOCATION='#A0E7E5',      # Light Cyan
    EVENT='#4ECDC4',         # Teal
    ORGANIZATION='#55A3FF',  # Light Blue
    DATE='#FF6B6B',          # Red
)

# Choices offered in each "Relationship Type" dropdown (order is display order).
RELATIONSHIP_TYPES = """
    works_with located_in participated_in member_of occurred_at
    employed_by founded attended knows related_to collaborates_with
    married_to sibling_of parent_of lives_at wrote other
""".split()
class NetworkGraphBuilder:
    """Accumulate entities and relationships and render them as a network.

    Call ``add_entity`` / ``add_relationship`` for the user's inputs, then
    ``build_graph`` to obtain an undirected ``networkx.Graph`` and
    ``create_plotly_graph`` to draw it as an interactive Plotly figure.
    """

    # Node colour used for any entity type that has no entry in ENTITY_COLORS.
    FALLBACK_NODE_COLOR = '#B2BEC3'

    def __init__(self):
        # Entity dicts with keys 'name', 'type', 'record_id'.
        self.entities = []
        # Relationship dicts with keys 'source', 'target', 'type'.
        self.relationships = []

    def add_entity(self, name, entity_type, record_id):
        """Add an entity unless an entity with the same name already exists.

        The name is whitespace-stripped and blank/None names are ignored.
        Duplicate detection is case-insensitive via ``str.casefold`` (so
        "Straße" and "STRASSE" count as the same name) and spans all entity
        types: the first spelling and type seen are kept.
        """
        if not name or not name.strip():
            return
        clean_name = name.strip()
        key = clean_name.casefold()
        if any(e['name'].casefold() == key for e in self.entities):
            return
        self.entities.append({
            'name': clean_name,
            'type': entity_type,
            'record_id': record_id,
        })

    def add_relationship(self, source, target, rel_type):
        """Record a relationship between two entity names.

        Both names are whitespace-stripped. The relationship is ignored if
        either name is blank or if both names are identical (no self-loops).
        A falsy ``rel_type`` defaults to ``'related_to'``. Endpoints are not
        validated here; ``build_graph`` skips edges to unknown entities.
        """
        if not (source and target):
            return
        src, tgt = source.strip(), target.strip()
        if not src or not tgt or src == tgt:
            return
        self.relationships.append({
            'source': src,
            'target': tgt,
            'type': rel_type if rel_type else 'related_to',
        })

    def build_graph(self):
        """Return an undirected ``nx.Graph`` of the collected data.

        Each entity becomes a node (attributes ``entity_type`` and
        ``record_id``). Each relationship whose endpoints are both existing
        nodes becomes an edge with a ``relationship`` attribute; others are
        dropped. A repeated pair keeps the last relationship type added.
        """
        G = nx.Graph()
        for entity in self.entities:
            G.add_node(
                entity['name'],
                entity_type=entity['type'],
                record_id=entity['record_id'],
            )
        for rel in self.relationships:
            if rel['source'] in G.nodes and rel['target'] in G.nodes:
                G.add_edge(rel['source'], rel['target'], relationship=rel['type'])
        return G

    def create_plotly_graph(self, G):
        """Render ``G`` as an interactive Plotly figure; None if it has no nodes.

        Nodes are placed with a seeded spring layout (same graph -> same
        drawing), grouped into one trace per entity type so each type gets a
        legend entry, sized by degree and labelled by name. Edges are grey
        lines with the relationship type printed at each midpoint.
        """
        if len(G.nodes) == 0:
            return None

        pos = nx.spring_layout(G, k=2, iterations=50, seed=42)

        edge_x, edge_y = [], []
        edge_mid_x, edge_mid_y, edge_labels = [], [], []
        for source, target, data in G.edges(data=True):
            x0, y0 = pos[source]
            x1, y1 = pos[target]
            # The None entry breaks the polyline so each edge is its own segment.
            edge_x.extend([x0, x1, None])
            edge_y.extend([y0, y1, None])
            edge_mid_x.append((x0 + x1) / 2)
            edge_mid_y.append((y0 + y1) / 2)
            edge_labels.append(data.get('relationship', ''))

        # showlegend=False: these helper traces are unnamed, so Plotly would
        # otherwise list them in the legend as "trace 0" / "trace 1".
        edge_trace = go.Scatter(
            x=edge_x, y=edge_y,
            line=dict(width=2, color='#888'),
            hoverinfo='none',
            mode='lines',
            showlegend=False,
        )
        edge_label_trace = go.Scatter(
            x=edge_mid_x, y=edge_mid_y,
            mode='text',
            text=edge_labels,
            textposition='middle center',
            textfont=dict(size=10, color='#666'),
            hoverinfo='none',
            showlegend=False,
        )

        # Known types first (ENTITY_COLORS order), then any other type present
        # in the graph, so that no node is left undrawn.
        entity_types = list(ENTITY_COLORS)
        for _, data in G.nodes(data=True):
            if data.get('entity_type') not in entity_types:
                entity_types.append(data.get('entity_type'))

        node_traces = []
        for entity_type in entity_types:
            nodes_of_type = [n for n, d in G.nodes(data=True) if d.get('entity_type') == entity_type]
            if not nodes_of_type:
                continue
            color = ENTITY_COLORS.get(entity_type, self.FALLBACK_NODE_COLOR)

            node_x, node_y, node_text, node_hover, node_sizes = [], [], [], [], []
            for node in nodes_of_type:
                x, y = pos[node]
                node_x.append(x)
                node_y.append(y)
                node_text.append(node)

                connections = list(G.neighbors(node))
                degree = len(connections)
                # Better-connected nodes are drawn larger.
                node_sizes.append(30 + (degree * 10))

                hover = f"<b>{node}</b><br>Type: {entity_type}<br>Connections: {degree}"
                if connections:
                    hover += f"<br>Connected to: {', '.join(connections[:5])}"
                    if len(connections) > 5:
                        hover += f"<br>... +{len(connections) - 5} more"
                node_hover.append(hover)

            node_traces.append(go.Scatter(
                x=node_x, y=node_y,
                mode='markers+text',
                name=str(entity_type),
                text=node_text,
                textposition='top center',
                textfont=dict(size=12, color='#333'),
                hoverinfo='text',
                hovertext=node_hover,
                marker=dict(
                    size=node_sizes,
                    color=color,
                    line=dict(width=2, color='white'),
                    symbol='circle',
                ),
            ))

        return go.Figure(
            data=[edge_trace, edge_label_trace] + node_traces,
            layout=go.Layout(
                title=dict(
                    text='Interactive Network Graph',
                    font=dict(size=20),
                ),
                showlegend=True,
                legend=dict(
                    yanchor="top",
                    y=0.99,
                    xanchor="left",
                    x=0.01,
                    bgcolor="rgba(255,255,255,0.8)",
                ),
                hovermode='closest',
                margin=dict(b=20, l=5, r=5, t=50),
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                plot_bgcolor='#fafafa',
                paper_bgcolor='#fafafa',
                height=600,
            ),
        )
def collect_entities_from_records(
    p1, l1, e1, o1, d1,
    p2, l2, e2, o2, d2,
    p3, l3, e3, o3, d3,
    p4, l4, e4, o4, d4,
    p5, l5, e5, o5, d5,
    p6, l6, e6, o6, d6
):
    """Gather the Step 1 entities and refresh the relationship dropdowns.

    Every five consecutive arguments are one record's (person, location,
    event, organization, date) values, records 1-6 in order. Returns an
    11-tuple: a Markdown summary of the entities, then ten ``gr.update``
    objects (source/target of relationship rows 1-5) that set each
    dropdown's choices to the sorted entity names and clear its value.
    """
    field_types = ('PERSON', 'LOCATION', 'EVENT', 'ORGANIZATION', 'DATE')
    flat_values = (
        p1, l1, e1, o1, d1,
        p2, l2, e2, o2, d2,
        p3, l3, e3, o3, d3,
        p4, l4, e4, o4, d4,
        p5, l5, e5, o5, d5,
        p6, l6, e6, o6, d6,
    )
    width = len(field_types)

    builder = NetworkGraphBuilder()
    for start in range(0, len(flat_values), width):
        record_number = start // width + 1
        for value, kind in zip(flat_values[start:start + width], field_types):
            if value:
                builder.add_entity(value, kind, record_number)

    entity_names = sorted(entity['name'] for entity in builder.entities)

    # Tally entities per type (every entity here has one of field_types).
    counts = dict.fromkeys(field_types, 0)
    for entity in builder.entities:
        if entity['type'] in counts:
            counts[entity['type']] += 1

    summary = f"""
### πŸ“Š Identified Entities ({len(builder.entities)} total)
| Type | Count |
|------|-------|
| πŸ‘€ People | {counts['PERSON']} |
| πŸ“ Locations | {counts['LOCATION']} |
| πŸ“… Events | {counts['EVENT']} |
| 🏒 Organizations | {counts['ORGANIZATION']} |
| πŸ—“οΈ Dates | {counts['DATE']} |
**Entities found:** {', '.join(entity_names) if entity_names else 'None'}
➑️ Now define relationships between these entities below
"""

    # Order: source1, target1, source2, target2, ... source5, target5.
    dropdown_updates = [gr.update(choices=entity_names, value=None) for _ in range(10)]
    return (summary, *dropdown_updates)
def generate_network_graph(
    p1, l1, e1, o1, d1,
    p2, l2, e2, o2, d2,
    p3, l3, e3, o3, d3,
    p4, l4, e4, o4, d4,
    p5, l5, e5, o5, d5,
    p6, l6, e6, o6, d6,
    src1, rel1, tgt1,
    src2, rel2, tgt2,
    src3, rel3, tgt3,
    src4, rel4, tgt4,
    src5, rel5, tgt5
):
    """Build the network from all Step 1 / Step 2 inputs and render it.

    The first 30 arguments are six records of (person, location, event,
    organization, date); the last 15 are five (source, relationship type,
    target) triples. Returns ``(plotly_figure, markdown_stats)``.

    Any exception is caught and turned into an error figure plus a Markdown
    message with the traceback, so the UI always receives two outputs.
    """
    try:
        builder = NetworkGraphBuilder()

        records = [
            (p1, l1, e1, o1, d1),
            (p2, l2, e2, o2, d2),
            (p3, l3, e3, o3, d3),
            (p4, l4, e4, o4, d4),
            (p5, l5, e5, o5, d5),
            (p6, l6, e6, o6, d6),
        ]
        for record_id, (person, location, event, org, date) in enumerate(records, 1):
            if person:
                builder.add_entity(person, 'PERSON', record_id)
            if location:
                builder.add_entity(location, 'LOCATION', record_id)
            if event:
                builder.add_entity(event, 'EVENT', record_id)
            if org:
                builder.add_entity(org, 'ORGANIZATION', record_id)
            if date:
                builder.add_entity(date, 'DATE', record_id)

        relationships = [
            (src1, rel1, tgt1),
            (src2, rel2, tgt2),
            (src3, rel3, tgt3),
            (src4, rel4, tgt4),
            (src5, rel5, tgt5),
        ]
        for source, rel_type, target in relationships:
            if source and target:
                builder.add_relationship(source, target, rel_type)

        G = builder.build_graph()

        if len(G.nodes) == 0:
            empty_fig = go.Figure()
            empty_fig.add_annotation(
                text="No entities to display.<br>Please enter entities in Step 1 and click 'Identify Entities' first.",
                xref="paper", yref="paper",
                x=0.5, y=0.5, showarrow=False,
                font=dict(size=16, color="#666")
            )
            empty_fig.update_layout(
                height=400,
                plot_bgcolor='#fafafa',
                paper_bgcolor='#fafafa'
            )
            return empty_fig, "❌ **No entities to display.** Please enter entities in Step 1 first."

        fig = builder.create_plotly_graph(G)

        stats = f"""
### πŸ“ˆ Network Statistics
| Metric | Value |
|--------|-------|
| **Nodes (Entities)** | {G.number_of_nodes()} |
| **Edges (Relationships)** | {G.number_of_edges()} |
"""
        if len(G.edges) == 0:
            stats += "\n⚠️ **No relationships defined** - showing isolated nodes. Add relationships in Step 2!\n"
        else:
            degrees = dict(G.degree())
            density = nx.density(G)
            avg_degree = sum(degrees.values()) / G.number_of_nodes()
            # These rows continue the Markdown table above (no blank line).
            stats += f"| **Network Density** | {density:.3f} |\n"
            stats += f"| **Avg. Connections** | {avg_degree:.2f} |\n"

            # Only rank nodes that actually have connections; isolated nodes
            # (degree 0) are not "most connected" even when fewer than three
            # nodes have edges. sorted() is stable, so ties keep node order.
            connected = [(node, degree) for node, degree in degrees.items() if degree > 0]
            top_nodes = sorted(connected, key=lambda item: item[1], reverse=True)[:3]
            stats += "\n**Most Connected Entities:**\n"
            for node, degree in top_nodes:
                noun = "connection" if degree == 1 else "connections"
                stats += f"- **{node}**: {degree} {noun}\n"

        # build_graph drops relationships whose endpoints are not current
        # entities (e.g. a record was edited after the dropdowns were
        # filled); report them instead of losing them silently.
        skipped = sum(
            1 for rel in builder.relationships
            if rel['source'] not in G.nodes or rel['target'] not in G.nodes
        )
        if skipped:
            stats += (
                f"\n**Note:** {skipped} relationship(s) skipped because an endpoint is not "
                "among the current entities. Click 'Identify Entities' again and re-select them.\n"
            )

        return fig, stats

    except Exception as e:
        # UI boundary: report the failure in both outputs rather than raising.
        import traceback
        error_trace = traceback.format_exc()
        error_fig = go.Figure()
        error_fig.add_annotation(
            text=f"Error: {str(e)}",
            xref="paper", yref="paper",
            x=0.5, y=0.5, showarrow=False,
            font=dict(size=14, color="red")
        )
        error_fig.update_layout(height=300, plot_bgcolor='#fafafa', paper_bgcolor='#fafafa')
        error_msg = f"""
### ❌ Error Generating Graph
{str(e)}
<details>
<summary>Technical details</summary>
```
{error_trace}
```
</details>
"""
        return error_fig, error_msg
def load_austen_example():
    """Return Step 1 field values for a *Pride and Prejudice* example.

    The result is a flat 30-tuple: six records of (person, location, event,
    organization, date). An empty string leaves that field blank.
    """
    records = [
        ("Elizabeth Bennet", "Longbourn", "Meryton Ball", "Bennet Family", "1811"),
        ("Mr. Darcy", "Pemberley", "Netherfield Ball", "Darcy Estate", "1811"),
        ("Jane Bennet", "Netherfield", "", "Bennet Family", ""),
        ("Mr. Bingley", "Netherfield", "", "", ""),
        ("Mr. Wickham", "Meryton", "", "Militia", ""),
        ("Charlotte Lucas", "Hunsford", "", "", ""),
    ]
    return tuple(field for record in records for field in record)
def load_wwii_example():
    """Return Step 1 field values for a Second World War example.

    A flat 30-tuple of six (person, location, event, organization, date)
    records; only the first four are populated, the last two are blank.
    """
    blank_record = ("",) * 5
    populated = (
        ("Winston Churchill", "London", "Battle of Britain", "War Cabinet", "1940"),
        ("Clement Attlee", "London", "Potsdam Conference", "Labour Party", "1945"),
        ("Field Marshal Montgomery", "North Africa", "Battle of El Alamein", "Eighth Army", "1942"),
        ("Franklin D. Roosevelt", "Washington D.C.", "D-Day", "Allied Forces", "1944"),
    )
    values = []
    for record in populated + (blank_record, blank_record):
        values.extend(record)
    return tuple(values)
def create_interface():
    """Build the Gradio Blocks UI, wire its event handlers and return it.

    Layout: intro and tip, example-loader buttons, a two-column row with the
    Step 1 entity records (left) and the Step 2 relationship dropdowns
    (right), the generate button, the Step 3 plot and statistics, a colour
    legend generated from ENTITY_COLORS, and a footer. The app is returned
    unlaunched.
    """
    # Must match the five (source, type, target) triples generate_network_graph
    # accepts and the ten dropdown updates collect_entities_from_records returns.
    relationship_rows = 5
    # Example placeholders, shown only in records 1-4.
    example_placeholders = {
        'person': "e.g., Elizabeth Bennet",
        'location': "e.g., Longbourn",
        'event': "e.g., Meryton Ball",
        'org': "e.g., Bennet Family",
        'date': "e.g., 1811",
    }

    def add_record(index, placeholders):
        """Create the five textboxes of record ``index`` (0-based) and return them in order."""
        with gr.Group():
            gr.Markdown(f"**Record {index+1}**")
            with gr.Row():
                person = gr.Textbox(label="πŸ‘€ Person", placeholder=placeholders.get('person'), scale=1)
                location = gr.Textbox(label="πŸ“ Location", placeholder=placeholders.get('location'), scale=1)
            with gr.Row():
                event = gr.Textbox(label="πŸ“… Event", placeholder=placeholders.get('event'), scale=1)
                org = gr.Textbox(label="🏒 Organization", placeholder=placeholders.get('org'), scale=1)
            date = gr.Textbox(label="πŸ—“οΈ Date", placeholder=placeholders.get('date'))
        return [person, location, event, org, date]

    with gr.Blocks(title="Basic Network Explorer", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
# πŸ•ΈοΈ Basic Network Explorer
Build interactive network graphs by entering entities extracted through Named Entity Recognition (NER).
Explore relationships between people, places, events, organizations and dates.
### How to use this tool:
1. **πŸ“ Enter entities** in the records below (or load an example)
2. **πŸ” Click "Identify Entities"** to collect and list all entities
3. **🀝 Define relationships** between entities using the dropdowns
4. **🎨 Click "Generate Network Graph"** to visualize
5. **πŸ‘οΈ Explore** - hover over nodes for details, zoom and pan the graph
""")
        gr.HTML("""
<div style="background-color: #fff3cd; border: 1px solid #ffeaa7; border-radius: 8px; padding: 12px; margin: 15px 0;">
<strong style="color: #856404;">πŸ’‘ Top tip:</strong> Start with just a few entities and relationships to see how it works!
</div>
""")

        # All 30 entity textboxes, in record order (5 fields per record).
        entity_inputs = []

        with gr.Row():
            gr.Markdown("### πŸ’‘ Quick Start - Load an Example:")
        with gr.Row():
            austen_btn = gr.Button("πŸ“š Jane Austen (Pride & Prejudice)", variant="secondary")
            wwii_btn = gr.Button("βš”οΈ WWII History", variant="secondary")
        gr.HTML("<hr style='margin: 20px 0;'>")

        with gr.Row():
            # LEFT COLUMN: Entity Inputs
            with gr.Column(scale=1):
                gr.Markdown("## πŸ“š Step 1: Enter Entities")
                with gr.Accordion("Records 1-4", open=True):
                    for i in range(4):
                        entity_inputs.extend(add_record(i, example_placeholders))
                with gr.Accordion("Records 5-6 (Optional)", open=False):
                    for i in range(4, 6):
                        entity_inputs.extend(add_record(i, {}))
                collect_btn = gr.Button("πŸ” Identify Entities", variant="primary", size="lg")
                entity_summary = gr.Markdown()

            # RIGHT COLUMN: Relationships
            with gr.Column(scale=1):
                gr.Markdown("## 🀝 Step 2: Define Relationships")
                gr.Markdown("*First click 'Identify Entities' to populate the dropdowns*")
                source_dropdowns = []
                target_dropdowns = []
                # Flat (source, type, target) * relationship_rows, in row order.
                relationship_inputs = []
                for row in range(relationship_rows):
                    if row > 0:
                        gr.HTML("<hr style='margin: 10px 0; border-top: 1px dashed #ccc;'>")
                    source = gr.Dropdown(label="From", choices=[], scale=2)
                    with gr.Row():
                        rel_type = gr.Dropdown(label="Relationship Type", choices=RELATIONSHIP_TYPES, value="related_to", scale=2)
                        target = gr.Dropdown(label="To", choices=[], scale=2)
                    source_dropdowns.append(source)
                    target_dropdowns.append(target)
                    relationship_inputs.extend([source, rel_type, target])

        gr.HTML("<hr style='margin: 30px 0;'>")
        generate_btn = gr.Button("🎨 Generate Network Graph", variant="primary", size="lg")

        gr.Markdown("## πŸ“Š Step 3: View Results")
        with gr.Row():
            with gr.Column(scale=2):
                network_plot = gr.Plot(label="Interactive Network Graph")
            with gr.Column(scale=1):
                network_stats = gr.Markdown()

        # Example buttons fill all 30 entity textboxes.
        austen_btn.click(fn=load_austen_example, inputs=[], outputs=entity_inputs)
        wwii_btn.click(fn=load_wwii_example, inputs=[], outputs=entity_inputs)

        # Outputs: summary, then source1, target1, ..., source5, target5.
        collect_btn.click(
            fn=collect_entities_from_records,
            inputs=entity_inputs,
            outputs=[entity_summary] + [
                dropdown
                for pair in zip(source_dropdowns, target_dropdowns)
                for dropdown in pair
            ],
        )

        generate_btn.click(
            fn=generate_network_graph,
            inputs=entity_inputs + relationship_inputs,
            outputs=[network_plot, network_stats],
        )

        # Colour legend built from ENTITY_COLORS so it can never drift from
        # the colours actually used in the plot.
        legend_items = "\n".join(
            '<span style="display: flex; align-items: center; gap: 5px;">\n'
            f'<span style="width: 20px; height: 20px; border-radius: 50%; background-color: {color}; display: inline-block;"></span>\n'
            f'<strong>{entity_type.title()}</strong>\n'
            '</span>'
            for entity_type, color in ENTITY_COLORS.items()
        )
        gr.HTML(f"""
<div style="background-color: #f8f9fa; padding: 15px; border-radius: 8px; margin-top: 20px;">
<h4 style="margin-top: 0;">🎨 Entity Color Legend</h4>
<div style="display: flex; flex-wrap: wrap; gap: 15px; align-items: center;">
{legend_items}
</div>
</div>
""")

        # Footer
        gr.HTML("""
<hr style="margin: 40px 0 20px 0;">
<div style="text-align: center; color: #666; font-size: 14px; padding: 20px;">
<p><strong>Basic Network Explorer</strong> | Bodleian Libraries (University of Oxford) Sassoon Research Fellowship</p>
<p>Built with the aid of Claude</p>
</div>
""")
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, keep it bound to the module-level
    # name `demo`, and start the Gradio server with default launch options.
    demo = create_interface()
    demo.launch()