SorrelC's picture
Update app.py
e086f57 verified
raw
history blame
25.3 kB
import gradio as gr
import plotly.graph_objects as go
import networkx as nx
import pandas as pd
from collections import defaultdict
# Marker colour (hex) used for each entity type in the network plot.
# Types missing from this mapping are drawn with a fallback colour by the
# plotting code.
ENTITY_COLORS = {
    'PERSON': '#00B894',        # Green
    'LOCATION': '#A0E7E5',      # Light Cyan
    'EVENT': '#4ECDC4',         # Teal
    'ORGANIZATION': '#55A3FF',  # Light Blue
    'DATE': '#FF6B6B',          # Red
}
# Relationship labels offered in the "Type" dropdown of the relationship
# builder, in display order.
RELATIONSHIP_TYPES = [
    'works_with',
    'located_in',
    'participated_in',
    'member_of',
    'occurred_at',
    'employed_by',
    'founded',
    'attended',
    'knows',
    'related_to',
    'collaborates_with',
    'other',
]
class NetworkGraphBuilder:
    """Collect entities and relationships and render them as a network graph.

    Entities and relationships are kept as plain dicts in ``self.entities``
    and ``self.relationships``. ``build_graph`` turns them into an undirected
    NetworkX graph and ``create_plotly_graph`` renders such a graph as a
    Plotly figure.
    """

    # Marker diameter of a node without connections, and the extra diameter
    # given to the most connected node of the graph.
    _BASE_NODE_SIZE = 20
    _NODE_SIZE_SPAN = 30

    def __init__(self):
        self.entities = []
        self.relationships = []

    def add_entity(self, name, entity_type, record_id):
        """Add an entity to the collection.

        The name is stripped of surrounding whitespace; ``None`` and blank
        names are ignored instead of raising.
        """
        cleaned = '' if name is None else str(name).strip()
        if not cleaned:
            return
        self.entities.append({
            'name': cleaned,
            'type': entity_type,
            'record_id': record_id
        })

    def add_relationship(self, source, target, rel_type):
        """Add a relationship between two different entities.

        Names are stripped *before* validation so that blank names and
        self-loops that differ only in whitespace are rejected.
        """
        source = '' if source is None else str(source).strip()
        target = '' if target is None else str(target).strip()
        if not source or not target or source == target:
            return
        self.relationships.append({
            'source': source,
            'target': target,
            'type': rel_type
        })

    def build_graph(self):
        """Build an undirected NetworkX graph from the collected data.

        Nodes are keyed by entity name, so an entity entered in several
        records becomes a single node carrying the attributes of its last
        occurrence. Relationships whose endpoints are not known entities are
        skipped.
        """
        G = nx.Graph()
        for entity in self.entities:
            G.add_node(
                entity['name'],
                entity_type=entity['type'],
                record_id=entity['record_id']
            )
        for rel in self.relationships:
            if rel['source'] in G.nodes and rel['target'] in G.nodes:
                # The graph is undirected, so the endpoints are also stored
                # as attributes to keep the direction the user entered.
                G.add_edge(
                    rel['source'],
                    rel['target'],
                    relationship=rel['type'],
                    source=rel['source'],
                    target=rel['target']
                )
        return G

    @staticmethod
    def _scale_node_sizes(degrees, max_degree):
        """Map node degrees to marker sizes between 20 and 50.

        ``max_degree`` is floored at 1 so that a graph (or entity type) made
        only of isolated nodes does not cause a division by zero.
        """
        divisor = max(max_degree, 1)
        return [
            NetworkGraphBuilder._BASE_NODE_SIZE
            + (degree / divisor) * NetworkGraphBuilder._NODE_SIZE_SPAN
            for degree in degrees
        ]

    def create_plotly_graph(self, G, layout_type='spring'):
        """Create an interactive Plotly visualization of ``G``.

        Args:
            G: NetworkX graph whose nodes may carry an ``entity_type``
                attribute and whose edges may carry ``relationship``,
                ``source`` and ``target`` attributes.
            layout_type: ``'spring'``, ``'circular'`` or ``'kamada_kawai'``;
                any other value selects the shell layout.

        Returns:
            A ``plotly.graph_objects.Figure``, or ``None`` for an empty graph.
        """
        if len(G.nodes) == 0:
            return None

        # Choose layout
        if layout_type == 'spring':
            pos = nx.spring_layout(G, k=2, iterations=50)
        elif layout_type == 'circular':
            pos = nx.circular_layout(G)
        elif layout_type == 'kamada_kawai':
            pos = nx.kamada_kawai_layout(G)
        else:
            pos = nx.shell_layout(G)

        # One line trace and one mid-point text label per edge
        edge_traces = []
        edge_labels = []
        for node_a, node_b, attrs in G.edges(data=True):
            x0, y0 = pos[node_a]
            x1, y1 = pos[node_b]
            edge_traces.append(go.Scatter(
                x=[x0, x1, None],
                y=[y0, y1, None],
                mode='lines',
                line=dict(width=2, color='#888'),
                hoverinfo='none',
                showlegend=False
            ))
            rel_type = attrs.get('relationship') or ''
            # Fall back to iteration order for graphs built elsewhere.
            source = attrs.get('source', node_a)
            target = attrs.get('target', node_b)
            edge_labels.append(go.Scatter(
                x=[(x0 + x1) / 2],
                y=[(y0 + y1) / 2],
                mode='text',
                text=[rel_type],
                textfont=dict(size=10, color='#555'),
                hoverinfo='text',
                hovertext=f"{source} → {rel_type} → {target}",
                showlegend=False
            ))

        # Group nodes by entity type so each type gets its own legend entry
        node_groups = defaultdict(
            lambda: {'x': [], 'y': [], 'text': [], 'hovertext': [], 'degree': []}
        )
        for node, attrs in G.nodes(data=True):
            entity_type = attrs.get('entity_type', 'UNKNOWN')
            group = node_groups[entity_type]
            x, y = pos[node]
            group['x'].append(x)
            group['y'].append(y)
            group['text'].append(node)

            # Hover text lists at most five neighbours
            connections = list(G.neighbors(node))
            hover_info = (
                f"<b>{node}</b><br>"
                f"Type: {entity_type}<br>"
                f"Connections: {len(connections)}<br>"
            )
            if connections:
                hover_info += f"Connected to: {', '.join(connections[:5])}"
                if len(connections) > 5:
                    hover_info += f"... and {len(connections) - 5} more"
            group['hovertext'].append(hover_info)
            group['degree'].append(G.degree(node))

        # Sizes are scaled against the largest degree in the WHOLE graph so
        # that marker sizes are comparable across entity types.
        max_degree = max((degree for _, degree in G.degree()), default=0)

        traces = edge_traces + edge_labels
        for entity_type, group in node_groups.items():
            traces.append(go.Scatter(
                x=group['x'],
                y=group['y'],
                mode='markers+text',
                marker=dict(
                    size=self._scale_node_sizes(group['degree'], max_degree),
                    color=ENTITY_COLORS.get(entity_type, '#CCCCCC'),
                    line=dict(width=2, color='white')
                ),
                text=group['text'],
                textposition='top center',
                textfont=dict(size=10, color='#333'),
                hovertext=group['hovertext'],
                hoverinfo='text',
                name=entity_type,
                showlegend=True
            ))

        # Create figure
        fig = go.Figure(
            data=traces,
            layout=go.Layout(
                title=dict(
                    text='<b>Entity Network Graph</b><br><sub>Node size indicates number of connections</sub>',
                    x=0.5,
                    xanchor='center'
                ),
                showlegend=True,
                hovermode='closest',
                margin=dict(b=20, l=5, r=5, t=80),
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                plot_bgcolor='#fafafa',
                height=700,
                legend=dict(
                    title=dict(text='<b>Entity Types</b>'),
                    orientation='v',
                    yanchor='top',
                    y=1,
                    xanchor='left',
                    x=1.02
                )
            )
        )
        return fig
def collect_entities_from_records(*args):
    """Collect all entities from the record input fields.

    Args:
        *args: Flat sequence of textbox values, five per record in the order
            person, location, event, organization, date, for up to six
            records. Missing trailing values are treated as empty.

    Returns:
        A list whose first item is a Markdown summary of the entities found,
        followed by ten ``gr.update`` objects (5 relationships x 2 dropdowns)
        that replace the dropdown choices with the entity names and clear the
        current selection.
    """
    entity_types = ('PERSON', 'LOCATION', 'EVENT', 'ORGANIZATION', 'DATE')
    num_records = 6
    num_dropdowns = 10

    builder = NetworkGraphBuilder()
    for record_index in range(num_records):
        start = record_index * len(entity_types)
        # Slicing + zip silently tolerates a short argument list.
        record_values = args[start:start + len(entity_types)]
        for entity_type, value in zip(entity_types, record_values):
            if value:
                builder.add_entity(value, entity_type, record_index + 1)

    # The same name may be entered in several records; offer it only once in
    # the dropdowns while keeping first-seen order.
    entity_names = list(dict.fromkeys(e['name'] for e in builder.entities))

    # Counts are per entered value, so a name repeated across records is
    # counted each time it appears.
    counts = {entity_type: 0 for entity_type in entity_types}
    for entity in builder.entities:
        counts[entity['type']] += 1

    summary = (
        "\n"
        "### 📊 Identified Entities\n"
        f"- **Total entities:** {len(builder.entities)}\n"
        f"- **People:** {counts['PERSON']}\n"
        f"- **Locations:** {counts['LOCATION']}\n"
        f"- **Events:** {counts['EVENT']}\n"
        f"- **Organizations:** {counts['ORGANIZATION']}\n"
        f"- **Dates:** {counts['DATE']}\n"
        "Now define relationships between these entities on the right →\n"
    )

    dropdown_updates = [
        gr.update(choices=entity_names, value=None) for _ in range(num_dropdowns)
    ]
    return [summary] + dropdown_updates
def generate_network_graph(*args):
    """Generate the network graph and its statistics from all UI inputs.

    Args:
        *args: Flat sequence of input values: 30 entity fields (6 records x
            person, location, event, organization, date), then 15
            relationship fields (5 x source, type, target), then the layout
            name as the last value.

    Returns:
        A ``(figure, markdown)`` tuple. ``figure`` is ``None`` when there are
        no entities or when an error occurred; ``markdown`` then carries the
        explanation.
    """
    import logging  # local import: only needed for the error path below

    entity_types = ('PERSON', 'LOCATION', 'EVENT', 'ORGANIZATION', 'DATE')
    num_records = 6
    relationship_start = num_records * len(entity_types)  # 30
    num_relationships = 5
    fields_per_relationship = 3

    try:
        builder = NetworkGraphBuilder()

        # Collect entities (first 30 args: 6 records x 5 fields)
        for record_index in range(num_records):
            start = record_index * len(entity_types)
            record_values = args[start:start + len(entity_types)]
            for entity_type, value in zip(entity_types, record_values):
                if value:
                    builder.add_entity(value, entity_type, record_index + 1)

        # Collect relationships (next args: 5 relationships x 3 fields)
        for rel_index in range(num_relationships):
            start = relationship_start + rel_index * fields_per_relationship
            triple = args[start:start + fields_per_relationship]
            if len(triple) < fields_per_relationship:
                continue
            source, rel_type, target = triple
            if source and target:
                # A cleared "Type" dropdown yields None; use the UI default.
                builder.add_relationship(source, target, rel_type or 'related_to')

        # Layout type is the last arg; fall back to 'spring' when it is
        # absent or the dropdown was cleared.
        layout_type = (args[-1] if len(args) > relationship_start else None) or 'spring'

        G = builder.build_graph()
        if len(G.nodes) == 0:
            return None, "❌ **No entities to display.** Please enter entities in Step 1 and click 'Identify Entities' first."

        # Create visualization (even if no relationships, show isolated nodes)
        fig = builder.create_plotly_graph(G, layout_type)

        stats = (
            "\n"
            "### 📈 Network Statistics\n"
            f"- **Nodes (Entities):** {G.number_of_nodes()}\n"
            f"- **Edges (Relationships):** {G.number_of_edges()}\n"
        )
        if len(G.edges) == 0:
            stats += "\n⚠️ **No relationships defined** - showing isolated nodes only.\n"
            stats += "\n*Define relationships in Step 2 to see connections between entities.*\n"
        else:
            degrees = dict(G.degree())
            stats += f"- **Network Density:** {nx.density(G):.3f}\n"
            stats += f"- **Average Connections per Node:** {sum(degrees.values()) / G.number_of_nodes():.2f}\n"
            # Three most connected nodes
            top_nodes = sorted(degrees.items(), key=lambda item: item[1], reverse=True)[:3]
            stats += "\n**Most Connected Entities:**\n"
            for node, degree in top_nodes:
                stats += f"- {node}: {degree} connections\n"
        return fig, stats
    except Exception as e:
        # UI boundary: report the problem to the user, but keep the traceback
        # in the server log instead of discarding it.
        logging.getLogger(__name__).exception("Failed to generate network graph")
        error_msg = (
            "\n"
            "### ❌ Error Generating Graph\n"
            f"An error occurred: {str(e)}\n"
            "**Troubleshooting:**\n"
            "1. Make sure you've clicked \"Identify Entities\" first\n"
            "2. Check that you have at least one entity entered\n"
            "3. If problem persists, try refreshing the page\n"
        )
        return None, error_msg
def create_interface():
    """Build and return the Gradio Blocks app.

    Layout: an intro, a two-column row (entity records on the left,
    relationship builder and layout options on the right), the statistics and
    plot outputs, clickable examples, and an information footer. Multi-line
    string literals are kept flush-left so their content is unchanged.

    Returns:
        The ``gr.Blocks`` instance (not yet launched).
    """

    def record_fields(record_number):
        """Render one record group and return its five textboxes in order."""
        with gr.Group():
            gr.Markdown(f"### Record {record_number}")
            with gr.Row():
                person = gr.Textbox(label="👤 Person", placeholder="e.g., Winston Churchill")
                location = gr.Textbox(label="📍 Location", placeholder="e.g., London")
            with gr.Row():
                event = gr.Textbox(label="📅 Event", placeholder="e.g., Battle of Britain")
                org = gr.Textbox(label="🏢 Organization", placeholder="e.g., Royal Air Force")
            date = gr.Textbox(label="🗓️ Date", placeholder="e.g., 1940")
        return [person, location, event, org, date]

    with gr.Blocks(title="Basic Network Explorer", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
# Basic Network Explorer
Build interactive social network graphs by entering entities extracted through Named Entity Recognition (NER).
This tool demonstrates how NER can be used to visualize relationships and connections in text data.
### How to use this tool:
1. **📝 Enter entities** in the records on the left (people, locations, events, organizations, dates)
2. **🔗 Click "Identify Entities"** to populate the relationship dropdowns
3. **🤝 Define relationships** on the right by selecting entities and connection types
4. **🎨 Click "Generate Network Graph"** to visualize your network
5. **👁️ Explore** the interactive graph - hover over nodes and edges for details
6. **🔄 Refresh the page** to start over with new data
""")
        # Add tip box
        gr.HTML("""
<div style="background-color: #fff3cd; border: 1px solid #ffeaa7; border-radius: 8px; padding: 12px; margin: 15px 0;">
<strong style="color: #856404;">💡 Top tip:</strong> This tool works best when you have already identified entities from text using NER. Try the NER Explorer Tool first to extract entities automatically!
</div>
""")

        entity_inputs = []
        relationship_inputs = []
        source_dropdowns = []
        target_dropdowns = []

        # Two-column layout: Entities on left, Relationships on right
        with gr.Row():
            # LEFT COLUMN: Entity Inputs
            with gr.Column(scale=1):
                with gr.Accordion("📚 Step 1: Enter Entities from Your Records", open=True):
                    # First 4 records (always visible)
                    for record_number in range(1, 5):
                        entity_inputs.extend(record_fields(record_number))
                    # Additional records (collapsible)
                    with gr.Accordion("➕ Additional Records (5-6)", open=False):
                        for record_number in range(5, 7):
                            entity_inputs.extend(record_fields(record_number))
                    collect_btn = gr.Button("🔍 Identify Entities", variant="primary", size="lg")
                    entity_summary = gr.Markdown()

            # RIGHT COLUMN: Relationship Builder (ALWAYS VISIBLE)
            with gr.Column(scale=1):
                with gr.Accordion("🤝 Step 2: Define Relationships Between Entities", open=True):
                    gr.Markdown("*First identify entities, then define relationships below:*")
                    for _ in range(5):
                        with gr.Row():
                            # allow_custom_value lets the Examples below fill
                            # these dropdowns before "Identify Entities" has
                            # populated their choices; names that match no
                            # entity are ignored when the graph is built.
                            source = gr.Dropdown(
                                label="From",
                                choices=[],
                                interactive=True,
                                allow_custom_value=True,
                                scale=2
                            )
                            rel_type = gr.Dropdown(
                                label="Type",
                                choices=RELATIONSHIP_TYPES,
                                value="related_to",
                                interactive=True,
                                scale=2
                            )
                            target = gr.Dropdown(
                                label="To",
                                choices=[],
                                interactive=True,
                                allow_custom_value=True,
                                scale=2
                            )
                        relationship_inputs.extend([source, rel_type, target])
                        source_dropdowns.append(source)
                        target_dropdowns.append(target)
                with gr.Accordion("🎨 Step 3: Customize and Generate", open=True):
                    layout_type = gr.Dropdown(
                        label="Graph Layout",
                        choices=['spring', 'circular', 'kamada_kawai', 'shell'],
                        value='spring',
                        info="Choose how nodes are arranged"
                    )
                    generate_btn = gr.Button("🔍 Generate Network Graph", variant="primary", size="lg")

        # Output section
        gr.HTML("<hr style='margin: 30px 0;'>")
        with gr.Row():
            network_stats = gr.Markdown()
        with gr.Row():
            network_plot = gr.Plot(label="Interactive Network Graph")

        # Every input of the graph generator, in the positional order that
        # generate_network_graph and the example rows rely on.
        all_inputs = entity_inputs + relationship_inputs + [layout_type]

        # Examples
        with gr.Column():
            gr.Markdown("""
### 💡 No example entities to test? No problem!
Simply click on one of the examples provided below, and the fields will be populated for you.
""", elem_id="examples-heading")
            gr.Examples(
                examples=[
                    [
                        # === ENTITY RECORDS ===
                        # Record 1
                        "Winston Churchill", "London", "Battle of Britain", "War Cabinet", "1940",
                        # Record 2
                        "Clement Attlee", "London", "Potsdam Conference", "Labour Party", "1945",
                        # Record 3
                        "Field Marshal Montgomery", "North Africa", "Battle of El Alamein", "Eighth Army", "1942",
                        # Record 4
                        "Winston Churchill", "Yalta", "Yalta Conference", "War Cabinet", "February 1945",
                        # Record 5
                        "King George VI", "London", "Victory in Europe Day", "British Monarchy", "May 1945",
                        # Record 6
                        "Field Marshal Montgomery", "Lüneburg Heath", "German Surrender", "British Army", "May 1945",
                        # === RELATIONSHIPS ===
                        # Relationship 1
                        "Winston Churchill", "works_with", "Clement Attlee",
                        # Relationship 2
                        "Winston Churchill", "participated_in", "Battle of Britain",
                        # Relationship 3
                        "Field Marshal Montgomery", "participated_in", "Battle of El Alamein",
                        # Relationship 4
                        "Winston Churchill", "participated_in", "Yalta Conference",
                        # Relationship 5
                        "Clement Attlee", "participated_in", "Potsdam Conference",
                        # Layout type
                        "spring"
                    ],
                    [
                        # === ENTITY RECORDS ===
                        # Record 1 - Pride and Prejudice
                        "Elizabeth Bennet", "Longbourn", "Meryton Assembly", "", "Autumn 1811",
                        # Record 2
                        "Mr Darcy", "Pemberley", "Meryton Assembly", "", "Autumn 1811",
                        # Record 3
                        "Jane Bennet", "Longbourn", "Netherfield Ball", "", "November 1811",
                        # Record 4
                        "Mr Bingley", "Netherfield", "Netherfield Ball", "", "November 1811",
                        # Record 5
                        "Elizabeth Bennet", "Rosings", "Easter Visit", "", "Spring 1812",
                        # Record 6
                        "Mr Darcy", "Rosings", "First Proposal", "", "Spring 1812",
                        # === RELATIONSHIPS ===
                        # Relationship 1
                        "Elizabeth Bennet", "knows", "Mr Darcy",
                        # Relationship 2
                        "Jane Bennet", "knows", "Mr Bingley",
                        # Relationship 3
                        "Elizabeth Bennet", "located_in", "Longbourn",
                        # Relationship 4
                        "Mr Darcy", "located_in", "Pemberley",
                        # Relationship 5
                        "Elizabeth Bennet", "participated_in", "Meryton Assembly",
                        # Layout type
                        "spring"
                    ]
                ],
                inputs=all_inputs,
                label="Examples"
            )

        # Add custom CSS to match NER tool styling
        gr.HTML("""
<style>
/* Make the Examples label text black */
.gradio-examples-label {
color: black !important;
}
h4.examples-label, .examples-label {
color: black !important;
}
#examples-heading + div label,
#examples-heading + div .label-text {
color: black !important;
}
</style>
""")

        # Wire up the interface
        # Collect entities button - updates the summary, then the five source
        # dropdowns, then the five target dropdowns.
        collect_btn.click(
            fn=collect_entities_from_records,
            inputs=entity_inputs,
            outputs=[entity_summary] + source_dropdowns + target_dropdowns
        )
        # Generate graph button
        generate_btn.click(
            fn=generate_network_graph,
            inputs=all_inputs,
            outputs=[network_plot, network_stats]
        )

        # Information footer
        gr.HTML("""
<hr style="margin-top: 40px; margin-bottom: 20px;">
<div style="background-color: #f8f9fa; padding: 20px; border-radius: 8px; margin-top: 20px;">
<h4 style="margin-top: 0;">ℹ️ About This Tool</h4>
<p style="font-size: 14px; line-height: 1.8;">
This tool demonstrates how <strong>Named Entity Recognition (NER)</strong> can be combined with
<strong>network analysis</strong> to visualize relationships in text data. In real-world applications,
entities would be automatically extracted from text using NER models, and relationships could be
identified through co-occurrence analysis, dependency parsing, or machine learning.
</p>
<p style="font-size: 14px; line-height: 1.8; margin-bottom: 0;">
<strong>Built with:</strong> Gradio, NetworkX, and Plotly |
<strong>Graph Layouts:</strong> Spring (force-directed), Circular, Kamada-Kawai, Shell
</p>
</div>
<br>
<hr style="margin-top: 40px; margin-bottom: 20px;">
<div style="background-color: #f8f9fa; padding: 20px; border-radius: 8px; margin-top: 20px; text-align: center;">
<p style="font-size: 14px; line-height: 1.8; margin: 0;">
This <strong>Basic Network Explorer</strong> tool was created as part of a Bodleian Libraries (University of Oxford) Sassoon Research Fellowship.
</p><br><br>
<p style="font-size: 14px; line-height: 1.8; margin: 0;">
The code for this tool was built with the aid of Claude Sonnet 4.5.
</p>
</div>
""")
    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks app and start the Gradio server.
    # The app object is bound to the module-level name `demo` before launch.
    demo = create_interface()
    demo.launch()