datamatters24 committed on
Commit
9e72c3a
·
verified ·
1 Parent(s): a329614

Upload notebooks/02_entity_network/23_network_viz.ipynb with huggingface_hub

Browse files
notebooks/02_entity_network/23_network_viz.ipynb ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# 23 - Network Visualization\n",
8
+ "\n",
9
+ "Interactive notebook for visualizing entity co-occurrence networks.\n",
10
+ "\n",
11
+ "Loads the network JSON exported by `22_network_analysis` and renders an interactive\n",
12
+ "Plotly network graph with community coloring and centrality-based node sizing.\n",
13
+ "Edges are drawn at a uniform width; co-occurrence weight is used only as the edge weight for the spring layout."
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": null,
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "import sys\n",
23
+ "sys.path.insert(0, '/opt/epstein_env/research')\n",
24
+ "\n",
25
+ "import json\n",
26
+ "import numpy as np\n",
27
+ "import networkx as nx\n",
28
+ "import plotly.graph_objects as go\n",
29
+ "from pathlib import Path\n",
30
+ "\n",
31
+ "from research_lib.config import WEB_ASSETS_DIR\n",
32
+ "from research_lib.plotting import set_style, save_fig, COLLECTION_COLORS"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": null,
38
+ "metadata": {},
39
+ "outputs": [],
40
+ "source": [
41
+ "# Select which network to load\n",
42
+ "# Change this to visualize different collections\n",
43
+ "network_file = 'network_all.json'\n",
44
+ "\n",
45
+ "network_path = WEB_ASSETS_DIR / network_file\n",
46
+ "print(f'Loading network from: {network_path}')\n",
47
+ "\n",
48
+ "with open(network_path, 'r') as f:\n",
49
+ " network_data = json.load(f)\n",
50
+ "\n",
51
+ "print(f'Nodes: {len(network_data[\"nodes\"])}')\n",
52
+ "print(f'Links: {len(network_data[\"links\"])}')"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "code",
57
+ "execution_count": null,
58
+ "metadata": {},
59
+ "outputs": [],
60
+ "source": [
61
+ "# Rebuild NetworkX graph from JSON for layout computation\n",
62
+ "G = nx.Graph()\n",
63
+ "\n",
64
+ "for node in network_data['nodes']:\n",
65
+ " G.add_node(\n",
66
+ " node['id'],\n",
67
+ " label=node.get('label', node['id']),\n",
68
+ " type=node.get('type', 'unknown'),\n",
69
+ " size=node.get('size', 1),\n",
70
+ " community=node.get('community', 0),\n",
71
+ " centrality=node.get('centrality', 0),\n",
72
+ " )\n",
73
+ "\n",
74
+ "for link in network_data['links']:\n",
75
+ " G.add_edge(link['source'], link['target'], weight=link.get('weight', 1))\n",
76
+ "\n",
77
+ "print(f'Graph reconstructed: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges')"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": null,
83
+ "metadata": {},
84
+ "outputs": [],
85
+ "source": [
86
+ "# Compute layout using spring layout (Fruchterman-Reingold)\n",
87
+ "print('Computing layout...')\n",
88
+ "pos = nx.spring_layout(G, k=1.5/np.sqrt(G.number_of_nodes()), iterations=50, seed=42, weight='weight')\n",
89
+ "print('Layout computed.')"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": null,
95
+ "metadata": {},
96
+ "outputs": [],
97
+ "source": [
98
+ "# Community color mapping\n",
99
+ "community_ids = sorted(set(G.nodes[n]['community'] for n in G.nodes))\n",
100
+ "# Generate distinct colors for communities\n",
101
+ "import plotly.express as px\n",
102
+ "color_palette = px.colors.qualitative.Set3 + px.colors.qualitative.Pastel + px.colors.qualitative.Bold\n",
103
+ "community_colors = {cid: color_palette[i % len(color_palette)] for i, cid in enumerate(community_ids)}\n",
104
+ "\n",
105
+ "print(f'Communities: {len(community_ids)}')"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": null,
111
+ "metadata": {},
112
+ "outputs": [],
113
+ "source": [
114
+ "# Build the edge trace (all edges drawn as a single Scatter trace)\n",
115
+ "edge_traces = []\n",
116
+ "weights = [G.edges[u, v]['weight'] for u, v in G.edges]\n",
117
+ "max_weight = max(weights) if weights else 1\n",
118
+ "\n",
119
+ "edge_x = []\n",
120
+ "edge_y = []\n",
121
+ "for u, v in G.edges:\n",
122
+ " x0, y0 = pos[u]\n",
123
+ " x1, y1 = pos[v]\n",
124
+ " edge_x.extend([x0, x1, None])\n",
125
+ " edge_y.extend([y0, y1, None])\n",
126
+ "\n",
127
+ "edge_trace = go.Scatter(\n",
128
+ " x=edge_x, y=edge_y,\n",
129
+ " line=dict(width=0.5, color='#cccccc'),\n",
130
+ " hoverinfo='none',\n",
131
+ " mode='lines',\n",
132
+ " name='Edges',\n",
133
+ ")\n",
134
+ "\n",
135
+ "print(f'Edge trace built with {len(G.edges)} edges')"
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "execution_count": null,
141
+ "metadata": {},
142
+ "outputs": [],
143
+ "source": [
144
+ "# Build node trace\n",
145
+ "node_x = []\n",
146
+ "node_y = []\n",
147
+ "node_text = []\n",
148
+ "node_color = []\n",
149
+ "node_size = []\n",
150
+ "\n",
151
+ "centralities = [G.nodes[n]['centrality'] for n in G.nodes]\n",
152
+ "max_centrality = max(centralities) if centralities else 1\n",
153
+ "\n",
154
+ "for node in G.nodes:\n",
155
+ " x, y = pos[node]\n",
156
+ " node_x.append(x)\n",
157
+ " node_y.append(y)\n",
158
+ "\n",
159
+ " data = G.nodes[node]\n",
160
+ " degree = G.degree(node)\n",
161
+ " hover_text = (\n",
162
+ " f\"<b>{data['label']}</b><br>\"\n",
163
+ " f\"Type: {data['type']}<br>\"\n",
164
+ " f\"Degree: {degree}<br>\"\n",
165
+ " f\"Community: {data['community']}<br>\"\n",
166
+ " f\"Centrality: {data['centrality']:.6f}\"\n",
167
+ " )\n",
168
+ " node_text.append(hover_text)\n",
169
+ " node_color.append(community_colors.get(data['community'], '#999999'))\n",
170
+ "\n",
171
+ " # Node size proportional to centrality, with min/max bounds\n",
172
+ " size = 5 + 30 * (data['centrality'] / max_centrality) if max_centrality > 0 else 10\n",
173
+ " node_size.append(size)\n",
174
+ "\n",
175
+ "node_trace = go.Scatter(\n",
176
+ " x=node_x, y=node_y,\n",
177
+ " mode='markers',\n",
178
+ " hoverinfo='text',\n",
179
+ " hovertext=node_text,\n",
180
+ " marker=dict(\n",
181
+ " size=node_size,\n",
182
+ " color=node_color,\n",
183
+ " line=dict(width=1, color='#ffffff'),\n",
184
+ " ),\n",
185
+ " name='Entities',\n",
186
+ ")\n",
187
+ "\n",
188
+ "print(f'Node trace built with {G.number_of_nodes()} nodes')"
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "code",
193
+ "execution_count": null,
194
+ "metadata": {},
195
+ "outputs": [],
196
+ "source": [
197
+ "# Create the figure\n",
198
+ "fig = go.Figure(\n",
199
+ " data=[edge_trace, node_trace],\n",
200
+ " layout=go.Layout(\n",
201
+ " title=dict(\n",
202
+ " text='Entity Co-occurrence Network',\n",
203
+ " font=dict(size=20),\n",
204
+ " ),\n",
205
+ " showlegend=False,\n",
206
+ " hovermode='closest',\n",
207
+ " xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),\n",
208
+ " yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),\n",
209
+ " plot_bgcolor='white',\n",
210
+ " width=1200,\n",
211
+ " height=800,\n",
212
+ " margin=dict(l=20, r=20, t=60, b=20),\n",
213
+ " annotations=[\n",
214
+ " dict(\n",
215
+ " text=f'Nodes: {G.number_of_nodes()} | Edges: {G.number_of_edges()} | Communities: {len(community_ids)}',\n",
216
+ " xref='paper', yref='paper',\n",
217
+ " x=0.5, y=-0.02,\n",
218
+ " showarrow=False,\n",
219
+ " font=dict(size=12, color='gray'),\n",
220
+ " ),\n",
221
+ " ],\n",
222
+ " ),\n",
223
+ ")\n",
224
+ "\n",
225
+ "fig.show()"
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "code",
230
+ "execution_count": null,
231
+ "metadata": {},
232
+ "outputs": [],
233
+ "source": [
234
+ "# Save as HTML for standalone viewing\n",
235
+ "output_path = save_fig(fig, 'entity_network_interactive', formats=('html',))\n",
236
+ "print(f'Saved interactive visualization to: {output_path}')"
237
+ ]
238
+ },
239
+ {
240
+ "cell_type": "code",
241
+ "execution_count": null,
242
+ "metadata": {},
243
+ "outputs": [],
244
+ "source": [
245
+ "# Summary table: top entities\n",
246
+ "import pandas as pd\n",
247
+ "\n",
248
+ "node_stats = []\n",
249
+ "for node in G.nodes:\n",
250
+ " data = G.nodes[node]\n",
251
+ " node_stats.append({\n",
252
+ " 'entity': data['label'],\n",
253
+ " 'type': data['type'],\n",
254
+ " 'degree': G.degree(node),\n",
255
+ " 'community': data['community'],\n",
256
+ " 'centrality': data['centrality'],\n",
257
+ " })\n",
258
+ "\n",
259
+ "stats_df = pd.DataFrame(node_stats).sort_values('centrality', ascending=False)\n",
260
+ "print('Top 30 entities by centrality:')\n",
261
+ "stats_df.head(30)"
262
+ ]
263
+ }
264
+ ],
265
+ "metadata": {
266
+ "kernelspec": {
267
+ "display_name": "Python 3",
268
+ "language": "python",
269
+ "name": "python3"
270
+ },
271
+ "language_info": {
272
+ "name": "python",
273
+ "version": "3.10.0"
274
+ }
275
+ },
276
+ "nbformat": 4,
277
+ "nbformat_minor": 5
278
+ }