datamatters24 commited on
Commit
42338b9
·
verified ·
1 Parent(s): c11f363

Upload notebooks/03_topic_classification/30_topic_modeling.ipynb with huggingface_hub

Browse files
notebooks/03_topic_classification/30_topic_modeling.ipynb ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# 30 - Topic Modeling\n",
8
+ "\n",
9
+ "Pipeline notebook for BERTopic-based topic modeling using pre-computed embeddings.\n",
10
+ "\n",
11
+ "Loads document embeddings (averaged page embeddings) from the database, fits a BERTopic\n",
12
+ "model with UMAP + HDBSCAN, and stores discovered topics and document-topic assignments\n",
13
+ "in the `topics` and `document_topics` tables."
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": null,
19
+ "metadata": {
20
+ "tags": [
21
+ "parameters"
22
+ ]
23
+ },
24
+ "outputs": [],
25
+ "source": [
26
+ "# Parameters\n",
27
+ "source_section = \"doj_disclosures\"\n",
28
+ "min_topic_size = 50\n",
29
+ "nr_topics = \"auto\"\n",
30
+ "sample_size = 100000"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": null,
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "import sys\n",
40
+ "sys.path.insert(0, '/opt/epstein_env/research')\n",
41
+ "\n",
42
+ "import numpy as np\n",
43
+ "import pandas as pd\n",
44
+ "from tqdm.auto import tqdm\n",
45
+ "\n",
46
+ "from research_lib.db import fetch_df, fetch_all, bulk_insert, get_conn\n",
47
+ "from research_lib.incremental import start_run, finish_run"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "code",
52
+ "execution_count": null,
53
+ "metadata": {},
54
+ "outputs": [],
55
+ "source": [
56
+ "# Start run\n",
57
+ "run_id = start_run(\n",
58
+ " 'topic_modeling',\n",
59
+ " source_section=source_section,\n",
60
+ " parameters={\n",
61
+ " 'min_topic_size': min_topic_size,\n",
62
+ " 'nr_topics': nr_topics,\n",
63
+ " 'sample_size': sample_size,\n",
64
+ " },\n",
65
+ ")\n",
66
+ "print(f'Started run {run_id}')"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": null,
72
+ "metadata": {},
73
+ "outputs": [],
74
+ "source": [
75
+ "# Load document-level embeddings (average of page embeddings)\n",
76
+ "where_clause = ''\n",
77
+ "params = []\n",
78
+ "if source_section:\n",
79
+ " where_clause = 'WHERE d.source_section = %s'\n",
80
+ " params = [source_section]\n",
81
+ "\n",
82
+ "sql = f\"\"\"\n",
83
+ " SELECT d.id as document_id, d.source_section,\n",
84
+ " AVG(p.embedding) as embedding\n",
85
+ " FROM documents d\n",
86
+ " JOIN pages p ON p.document_id = d.id\n",
87
+ " {where_clause}\n",
88
+ " AND p.embedding IS NOT NULL\n",
89
+ " GROUP BY d.id, d.source_section\n",
90
+ "\"\"\"\n",
91
+ "doc_embeddings_df = fetch_df(sql, params or None)\n",
92
+ "print(f'Loaded embeddings for {len(doc_embeddings_df)} documents')"
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "code",
97
+ "execution_count": null,
98
+ "metadata": {},
99
+ "outputs": [],
100
+ "source": [
101
+ "# Also load concatenated page text per document for topic representation\n",
102
+ "text_sql = f\"\"\"\n",
103
+ " SELECT d.id as document_id,\n",
104
+ " STRING_AGG(p.ocr_text, ' ' ORDER BY p.page_number) as full_text\n",
105
+ " FROM documents d\n",
106
+ " JOIN pages p ON p.document_id = d.id\n",
107
+ " {where_clause}\n",
108
+ " AND p.ocr_text IS NOT NULL AND p.ocr_text != ''\n",
109
+ " GROUP BY d.id\n",
110
+ "\"\"\"\n",
111
+ "text_df = fetch_df(text_sql, params or None)\n",
112
+ "print(f'Loaded text for {len(text_df)} documents')\n",
113
+ "\n",
114
+ "# Merge\n",
115
+ "merged_df = doc_embeddings_df.merge(text_df, on='document_id', how='inner')\n",
116
+ "print(f'Documents with both embeddings and text: {len(merged_df)}')"
117
+ ]
118
+ },
119
+ {
120
+ "cell_type": "code",
121
+ "execution_count": null,
122
+ "metadata": {},
123
+ "outputs": [],
124
+ "source": [
125
+ "# Convert embeddings to numpy array\n",
126
+ "embeddings = np.stack(merged_df['embedding'].values)\n",
127
+ "docs_text = merged_df['full_text'].tolist()\n",
128
+ "doc_ids = merged_df['document_id'].tolist()\n",
129
+ "\n",
130
+ "print(f'Embeddings shape: {embeddings.shape}')\n",
131
+ "print(f'Documents: {len(docs_text)}')"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "execution_count": null,
137
+ "metadata": {},
138
+ "outputs": [],
139
+ "source": [
140
+ "# Sample if dataset is larger than sample_size\n",
141
+ "if len(docs_text) > sample_size:\n",
142
+ " print(f'Sampling {sample_size} documents from {len(docs_text)} for fitting...')\n",
143
+ " rng = np.random.RandomState(42)\n",
144
+ " sample_idx = rng.choice(len(docs_text), size=sample_size, replace=False)\n",
145
+ " sample_idx.sort()\n",
146
+ " fit_embeddings = embeddings[sample_idx]\n",
147
+ " fit_texts = [docs_text[i] for i in sample_idx]\n",
148
+ " fit_doc_ids = [doc_ids[i] for i in sample_idx]\n",
149
+ "else:\n",
150
+ " fit_embeddings = embeddings\n",
151
+ " fit_texts = docs_text\n",
152
+ " fit_doc_ids = doc_ids\n",
153
+ "\n",
154
+ "print(f'Fitting on {len(fit_texts)} documents')"
155
+ ]
156
+ },
157
+ {
158
+ "cell_type": "code",
159
+ "execution_count": null,
160
+ "metadata": {},
161
+ "outputs": [],
162
+ "source": [
163
+ "# BERTopic with pre-computed embeddings\n",
164
+ "from bertopic import BERTopic\n",
165
+ "from umap import UMAP\n",
166
+ "from hdbscan import HDBSCAN\n",
167
+ "from sklearn.feature_extraction.text import CountVectorizer\n",
168
+ "\n",
169
+ "umap_model = UMAP(\n",
170
+ " n_components=5,\n",
171
+ " n_neighbors=15,\n",
172
+ " metric='cosine',\n",
173
+ " random_state=42,\n",
174
+ ")\n",
175
+ "hdbscan_model = HDBSCAN(\n",
176
+ " min_cluster_size=min_topic_size,\n",
177
+ " metric='euclidean',\n",
178
+ " prediction_data=True,\n",
179
+ ")\n",
180
+ "vectorizer = CountVectorizer(\n",
181
+ " stop_words='english',\n",
182
+ " ngram_range=(1, 2),\n",
183
+ ")\n",
184
+ "\n",
185
+ "topic_model = BERTopic(\n",
186
+ " embedding_model=None, # pre-computed\n",
187
+ " umap_model=umap_model,\n",
188
+ " hdbscan_model=hdbscan_model,\n",
189
+ " vectorizer_model=vectorizer,\n",
190
+ " nr_topics=nr_topics if nr_topics != \"auto\" else None,\n",
191
+ " verbose=True,\n",
192
+ ")\n",
193
+ "\n",
194
+ "print('Fitting BERTopic model...')\n",
195
+ "topics, probs = topic_model.fit_transform(fit_texts, fit_embeddings)\n",
196
+ "print(f'Fit complete. Found {len(set(topics)) - (1 if -1 in topics else 0)} topics.')"
197
+ ]
198
+ },
199
+ {
200
+ "cell_type": "code",
201
+ "execution_count": null,
202
+ "metadata": {},
203
+ "outputs": [],
204
+ "source": [
205
+ "# If we sampled, transform the full dataset\n",
206
+ "if len(docs_text) > sample_size:\n",
207
+ " print('Transforming full dataset...')\n",
208
+ " all_topics, all_probs = topic_model.transform(docs_text, embeddings)\n",
209
+ "else:\n",
210
+ " all_topics = topics\n",
211
+ " all_probs = probs\n",
212
+ "\n",
213
+ "print(f'All documents assigned topics.')"
214
+ ]
215
+ },
216
+ {
217
+ "cell_type": "code",
218
+ "execution_count": null,
219
+ "metadata": {},
220
+ "outputs": [],
221
+ "source": [
222
+ "# Store topics in topics table\n",
223
+ "topic_info = topic_model.get_topic_info()\n",
224
+ "topic_rows = []\n",
225
+ "\n",
226
+ "for _, row in topic_info.iterrows():\n",
227
+ " topic_id = row['Topic']\n",
228
+ " if topic_id == -1:\n",
229
+ " continue # Skip outlier topic\n",
230
+ "\n",
231
+ " # Get top words for this topic\n",
232
+ " topic_words = topic_model.get_topic(topic_id)\n",
233
+ " keywords = [w for w, _ in topic_words[:10]] if topic_words else []\n",
234
+ " label = ', '.join(keywords[:5]) if keywords else f'Topic {topic_id}'\n",
235
+ "\n",
236
+ " topic_rows.append((\n",
237
+ " f'bertopic_{topic_id}', # topic_name\n",
238
+ " label, # topic_label\n",
239
+ " ','.join(keywords), # keywords\n",
240
+ " int(row['Count']), # document_count\n",
241
+ " source_section, # source_section\n",
242
+ " 'topic_modeling', # model_name\n",
243
+ " ))\n",
244
+ "\n",
245
+ "if topic_rows:\n",
246
+ " inserted = bulk_insert(\n",
247
+ " 'topics',\n",
248
+ " ['topic_name', 'topic_label', 'keywords', 'document_count', 'source_section', 'model_name'],\n",
249
+ " topic_rows,\n",
250
+ " on_conflict='DO NOTHING',\n",
251
+ " )\n",
252
+ " print(f'Inserted {inserted} topics')"
253
+ ]
254
+ },
255
+ {
256
+ "cell_type": "code",
257
+ "execution_count": null,
258
+ "metadata": {},
259
+ "outputs": [],
260
+ "source": [
261
+ "# Store document-topic assignments\n",
262
+ "assignment_rows = []\n",
263
+ "for i, (doc_id, topic_id) in enumerate(zip(doc_ids, all_topics)):\n",
264
+ " if topic_id == -1:\n",
265
+ " continue # Skip outlier assignments\n",
266
+ "\n",
267
+ " prob = float(all_probs[i]) if all_probs is not None and len(all_probs) > i else None\n",
268
+ " assignment_rows.append((\n",
269
+ " doc_id,\n",
270
+ " f'bertopic_{topic_id}',\n",
271
+ " prob,\n",
272
+ " 'topic_modeling',\n",
273
+ " ))\n",
274
+ "\n",
275
+ "if assignment_rows:\n",
276
+ " inserted = bulk_insert(\n",
277
+ " 'document_topics',\n",
278
+ " ['document_id', 'topic_name', 'probability', 'model_name'],\n",
279
+ " assignment_rows,\n",
280
+ " on_conflict='DO NOTHING',\n",
281
+ " )\n",
282
+ " print(f'Inserted {inserted} document-topic assignments')"
283
+ ]
284
+ },
285
+ {
286
+ "cell_type": "code",
287
+ "execution_count": null,
288
+ "metadata": {},
289
+ "outputs": [],
290
+ "source": [
291
+ "# Finish run\n",
292
+ "finish_run(run_id, documents_processed=len(doc_ids))\n",
293
+ "print(f'Run {run_id} completed.')"
294
+ ]
295
+ },
296
+ {
297
+ "cell_type": "code",
298
+ "execution_count": null,
299
+ "metadata": {},
300
+ "outputs": [],
301
+ "source": [
302
+ "# Summary\n",
303
+ "print('=== Topic Modeling Summary ===')\n",
304
+ "print(f'Source section: {source_section or \"all\"}')\n",
305
+ "print(f'Documents processed: {len(doc_ids)}')\n",
306
+ "n_topics = len(set(all_topics)) - (1 if -1 in all_topics else 0)\n",
307
+ "n_outliers = sum(1 for t in all_topics if t == -1)\n",
308
+ "print(f'Topics discovered: {n_topics}')\n",
309
+ "print(f'Outlier documents: {n_outliers} ({100*n_outliers/len(all_topics):.1f}%)')\n",
310
+ "\n",
311
+ "print('\\nTopic overview:')\n",
312
+ "for _, row in topic_info.head(20).iterrows():\n",
313
+ " topic_id = row['Topic']\n",
314
+ " topic_words = topic_model.get_topic(topic_id)\n",
315
+ " top_words = ', '.join([w for w, _ in (topic_words[:5] if topic_words else [])])\n",
316
+ " print(f' Topic {topic_id:3d}: {row[\"Count\"]:5d} docs | {top_words}')"
317
+ ]
318
+ }
319
+ ],
320
+ "metadata": {
321
+ "kernelspec": {
322
+ "display_name": "Python 3",
323
+ "language": "python",
324
+ "name": "python3"
325
+ },
326
+ "language_info": {
327
+ "name": "python",
328
+ "version": "3.10.0"
329
+ }
330
+ },
331
+ "nbformat": 4,
332
+ "nbformat_minor": 5
333
+ }