datamatters24 commited on
Commit
fa8cc5c
·
verified ·
1 Parent(s): fd37e45

Upload notebooks/01_exploration/11_entity_explorer.ipynb with huggingface_hub

Browse files
notebooks/01_exploration/11_entity_explorer.ipynb ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Entity Explorer\n",
8
+ "\n",
9
+ "Explore named entities extracted by the OCR/NLP pipeline:\n",
10
+ "- Top entities by frequency for each major entity type (PERSON, ORG, GPE, DATE)\n",
11
+ "- Entity type distribution\n",
12
+ "- Entity count per collection heatmap"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "metadata": {},
18
+ "source": [
19
+ "import pandas as pd\n",
20
+ "import matplotlib.pyplot as plt\n",
21
+ "import seaborn as sns\n",
22
+ "\n",
23
+ "from research_lib.db import fetch_df\n",
24
+ "from research_lib.plotting import set_style, save_fig, COLLECTION_COLORS\n",
25
+ "\n",
26
+ "set_style()\n",
27
+ "print(\"Libraries loaded.\")"
28
+ ],
29
+ "execution_count": null,
30
+ "outputs": []
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "metadata": {},
35
+ "source": [
36
+ "# Top 50 entities by frequency for PERSON, ORG, GPE, DATE\n",
37
+ "entity_types = [\"PERSON\", \"ORG\", \"GPE\", \"DATE\"]\n",
38
+ "\n",
39
+ "fig, axes = plt.subplots(2, 2, figsize=(20, 20))\n",
40
+ "axes = axes.flatten()\n",
41
+ "\n",
42
+ "for idx, etype in enumerate(entity_types):\n",
43
+ " df_top = fetch_df(f\"\"\"\n",
44
+ " SELECT entity_text, COUNT(*) AS freq\n",
45
+ " FROM entities\n",
46
+ " WHERE entity_type = '{etype}'\n",
47
+ " GROUP BY entity_text\n",
48
+ " ORDER BY freq DESC\n",
49
+ " LIMIT 50\n",
50
+ " \"\"\")\n",
51
+ "\n",
52
+ " ax = axes[idx]\n",
53
+ " if len(df_top) > 0:\n",
54
+ " # Plot top 30 for readability, store full 50 in data\n",
55
+ " plot_df = df_top.head(30)\n",
56
+ " ax.barh(\n",
57
+ " range(len(plot_df) - 1, -1, -1),\n",
58
+ " plot_df[\"freq\"],\n",
59
+ " color=sns.color_palette(\"viridis\", len(plot_df)),\n",
60
+ " )\n",
61
+ " ax.set_yticks(range(len(plot_df) - 1, -1, -1))\n",
62
+ " ax.set_yticklabels(plot_df[\"entity_text\"], fontsize=8)\n",
63
+ " ax.set_xlabel(\"Frequency\")\n",
64
+ " ax.set_title(f\"Top {etype} Entities\")\n",
65
+ "\n",
66
+ "plt.tight_layout()\n",
67
+ "save_fig(fig, \"top_entities_by_type\")\n",
68
+ "plt.show()\n",
69
+ "\n",
70
+ "# Print full top 50 for each type\n",
71
+ "for etype in entity_types:\n",
72
+ " df_top = fetch_df(f\"\"\"\n",
73
+ " SELECT entity_text, COUNT(*) AS freq\n",
74
+ " FROM entities\n",
75
+ " WHERE entity_type = '{etype}'\n",
76
+ " GROUP BY entity_text\n",
77
+ " ORDER BY freq DESC\n",
78
+ " LIMIT 50\n",
79
+ " \"\"\")\n",
80
+ " print(f\"\\n=== Top 50 {etype} ===\")\n",
81
+ " print(df_top.to_string(index=False))"
82
+ ],
83
+ "execution_count": null,
84
+ "outputs": []
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "metadata": {},
89
+ "source": [
90
+ "# Entity type distribution pie chart\n",
91
+ "df_type_dist = fetch_df(\"\"\"\n",
92
+ " SELECT entity_type, COUNT(*) AS freq\n",
93
+ " FROM entities\n",
94
+ " GROUP BY entity_type\n",
95
+ " ORDER BY freq DESC\n",
96
+ "\"\"\")\n",
97
+ "\n",
98
+ "fig, ax = plt.subplots(figsize=(10, 10))\n",
99
+ "colors = sns.color_palette(\"Set2\", len(df_type_dist))\n",
100
+ "wedges, texts, autotexts = ax.pie(\n",
101
+ " df_type_dist[\"freq\"],\n",
102
+ " labels=df_type_dist[\"entity_type\"],\n",
103
+ " autopct=\"%1.1f%%\",\n",
104
+ " colors=colors,\n",
105
+ " pctdistance=0.85,\n",
106
+ ")\n",
107
+ "for autotext in autotexts:\n",
108
+ " autotext.set_fontsize(9)\n",
109
+ "ax.set_title(\"Entity Type Distribution\")\n",
110
+ "plt.tight_layout()\n",
111
+ "save_fig(fig, \"entity_type_distribution\")\n",
112
+ "plt.show()\n",
113
+ "\n",
114
+ "print(\"\\nEntity type counts:\")\n",
115
+ "df_type_dist[\"pct\"] = (df_type_dist[\"freq\"] / df_type_dist[\"freq\"].sum() * 100).round(1)\n",
116
+ "print(df_type_dist.to_string(index=False))\n",
117
+ "print(f\"\\nTotal entities: {df_type_dist['freq'].sum():,}\")"
118
+ ],
119
+ "execution_count": null,
120
+ "outputs": []
121
+ },
122
+ {
123
+ "cell_type": "code",
124
+ "metadata": {},
125
+ "source": [
126
+ "# Entity count per collection heatmap\n",
127
+ "df_heatmap = fetch_df(\"\"\"\n",
128
+ " SELECT d.source_section, e.entity_type, COUNT(*) AS freq\n",
129
+ " FROM entities e\n",
130
+ " JOIN pages p ON p.id = e.page_id\n",
131
+ " JOIN documents d ON d.id = p.document_id\n",
132
+ " GROUP BY d.source_section, e.entity_type\n",
133
+ " ORDER BY d.source_section, e.entity_type\n",
134
+ "\"\"\")\n",
135
+ "\n",
136
+ "pivot = df_heatmap.pivot_table(\n",
137
+ " index=\"source_section\", columns=\"entity_type\", values=\"freq\", fill_value=0\n",
138
+ ")\n",
139
+ "\n",
140
+ "fig, ax = plt.subplots(figsize=(14, 8))\n",
141
+ "sns.heatmap(\n",
142
+ " pivot,\n",
143
+ " annot=True,\n",
144
+ " fmt=\",.0f\",\n",
145
+ " cmap=\"YlOrRd\",\n",
146
+ " linewidths=0.5,\n",
147
+ " ax=ax,\n",
148
+ ")\n",
149
+ "ax.set_title(\"Entity Count by Collection and Type\")\n",
150
+ "ax.set_xlabel(\"Entity Type\")\n",
151
+ "ax.set_ylabel(\"Collection\")\n",
152
+ "plt.tight_layout()\n",
153
+ "save_fig(fig, \"entity_collection_heatmap\")\n",
154
+ "plt.show()"
155
+ ],
156
+ "execution_count": null,
157
+ "outputs": []
158
+ }
159
+ ],
160
+ "metadata": {
161
+ "kernelspec": {
162
+ "display_name": "Python 3",
163
+ "language": "python",
164
+ "name": "python3"
165
+ },
166
+ "language_info": {
167
+ "name": "python",
168
+ "version": "3.10.0"
169
+ }
170
+ },
171
+ "nbformat": 4,
172
+ "nbformat_minor": 5
173
+ }