datamatters24 committed on
Commit
43bc5c8
·
verified ·
1 Parent(s): 8f695bc

Upload notebooks/03_topic_classification/32_keyword_extraction.ipynb with huggingface_hub

Browse files
notebooks/03_topic_classification/32_keyword_extraction.ipynb ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# 32 - Keyword Extraction\n",
8
+ "\n",
9
+ "Pipeline notebook for TF-IDF keyword extraction from document OCR text.\n",
10
+ "\n",
11
+ "Concatenates page-level OCR text per document, fits a TF-IDF vectorizer across the\n",
12
+ "corpus, and extracts the top-K keywords per document. Results stored in the\n",
13
+ "`document_keywords` table."
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": null,
19
+ "metadata": {
20
+ "tags": [
21
+ "parameters"
22
+ ]
23
+ },
24
+ "outputs": [],
25
+ "source": [
26
+ "# Parameters\n",
27
+ "source_section = None\n",
28
+ "top_k = 20\n",
29
+ "batch_size = 5000"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "import sys\n",
39
+ "sys.path.insert(0, '/opt/epstein_env/research')\n",
40
+ "\n",
41
+ "import numpy as np\n",
42
+ "import pandas as pd\n",
43
+ "from sklearn.feature_extraction.text import TfidfVectorizer\n",
44
+ "from collections import Counter, defaultdict\n",
45
+ "from tqdm.auto import tqdm\n",
46
+ "\n",
47
+ "from research_lib.db import fetch_df, bulk_insert\n",
48
+ "from research_lib.incremental import start_run, finish_run, get_unprocessed_documents"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": null,
54
+ "metadata": {},
55
+ "outputs": [],
56
+ "source": [
57
+ "# Start run\n",
58
+ "run_id = start_run(\n",
59
+ " 'keyword_extraction',\n",
60
+ " source_section=source_section,\n",
61
+ " parameters={'top_k': top_k, 'batch_size': batch_size},\n",
62
+ ")\n",
63
+ "print(f'Started run {run_id}')"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": null,
69
+ "metadata": {},
70
+ "outputs": [],
71
+ "source": [
72
+ "# Load concatenated page text per document\n",
73
+ "where_clause = ''\n",
74
+ "params = []\n",
75
+ "if source_section:\n",
76
+ " where_clause = 'WHERE d.source_section = %s'\n",
77
+ " params = [source_section]\n",
78
+ "\n",
79
+ "sql = f\"\"\"\n",
80
+ " SELECT d.id as document_id, d.source_section,\n",
81
+ " STRING_AGG(p.ocr_text, ' ' ORDER BY p.page_number) as full_text\n",
82
+ " FROM documents d\n",
83
+ " JOIN pages p ON p.document_id = d.id\n",
84
+ " {where_clause}\n",
85
+ " AND p.ocr_text IS NOT NULL AND p.ocr_text != ''\n",
86
+ " GROUP BY d.id, d.source_section\n",
87
+ " ORDER BY d.id\n",
88
+ "\"\"\"\n",
89
+ "docs_df = fetch_df(sql, params or None)\n",
90
+ "print(f'Loaded text for {len(docs_df)} documents')"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": null,
96
+ "metadata": {},
97
+ "outputs": [],
98
+ "source": [
99
+ "# Fit TF-IDF vectorizer\n",
100
+ "print('Fitting TF-IDF vectorizer...')\n",
101
+ "tfidf = TfidfVectorizer(\n",
102
+ " max_features=50000,\n",
103
+ " max_df=0.95,\n",
104
+ " min_df=2,\n",
105
+ " stop_words='english',\n",
106
+ " ngram_range=(1, 2),\n",
107
+ " sublinear_tf=True,\n",
108
+ " dtype=np.float32,\n",
109
+ ")\n",
110
+ "\n",
111
+ "tfidf_matrix = tfidf.fit_transform(docs_df['full_text'].fillna(''))\n",
112
+ "feature_names = np.array(tfidf.get_feature_names_out())\n",
113
+ "print(f'TF-IDF matrix shape: {tfidf_matrix.shape}')\n",
114
+ "print(f'Vocabulary size: {len(feature_names)}')"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": null,
120
+ "metadata": {},
121
+ "outputs": [],
122
+ "source": [
123
+ "# Extract top_k keywords per document in batches\n",
124
+ "all_rows = []\n",
125
+ "doc_ids = docs_df['document_id'].tolist()\n",
126
+ "\n",
127
+ "n_batches = (len(doc_ids) + batch_size - 1) // batch_size\n",
128
+ "for batch_idx in tqdm(range(n_batches), desc='Extracting keywords'):\n",
129
+ " start = batch_idx * batch_size\n",
130
+ " end = min(start + batch_size, len(doc_ids))\n",
131
+ "\n",
132
+ " batch_matrix = tfidf_matrix[start:end]\n",
133
+ " batch_doc_ids = doc_ids[start:end]\n",
134
+ "\n",
135
+ " for i in range(batch_matrix.shape[0]):\n",
136
+ " row = batch_matrix.getrow(i)\n",
137
+ " if row.nnz == 0:\n",
138
+ " continue\n",
139
+ "\n",
140
+ " # Get top_k indices by TF-IDF score\n",
141
+ " data = row.toarray().flatten()\n",
142
+ " top_indices = data.argsort()[::-1][:top_k]\n",
143
+ "\n",
144
+ " for rank, idx in enumerate(top_indices, 1):\n",
145
+ " score = float(data[idx])\n",
146
+ " if score <= 0:\n",
147
+ " break\n",
148
+ " keyword = feature_names[idx]\n",
149
+ " all_rows.append((\n",
150
+ " batch_doc_ids[i],\n",
151
+ " keyword,\n",
152
+ " score,\n",
153
+ " rank,\n",
154
+ " ))\n",
155
+ "\n",
156
+ "print(f'Total keyword entries: {len(all_rows)}')"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": null,
162
+ "metadata": {},
163
+ "outputs": [],
164
+ "source": [
165
+ "# Insert into document_keywords table\n",
166
+ "if all_rows:\n",
167
+ " # Insert in chunks to avoid memory issues\n",
168
+ " chunk_size = 50000\n",
169
+ " total_inserted = 0\n",
170
+ " for i in tqdm(range(0, len(all_rows), chunk_size), desc='Inserting keywords'):\n",
171
+ " chunk = all_rows[i:i + chunk_size]\n",
172
+ " inserted = bulk_insert(\n",
173
+ " 'document_keywords',\n",
174
+ " ['document_id', 'keyword', 'tfidf_score', 'rank'],\n",
175
+ " chunk,\n",
176
+ " on_conflict='DO NOTHING',\n",
177
+ " )\n",
178
+ " total_inserted += inserted\n",
179
+ "\n",
180
+ " print(f'Inserted {total_inserted} keyword rows')\n",
181
+ "else:\n",
182
+ " print('No keywords to insert.')"
183
+ ]
184
+ },
185
+ {
186
+ "cell_type": "code",
187
+ "execution_count": null,
188
+ "metadata": {},
189
+ "outputs": [],
190
+ "source": [
191
+ "# Finish run\n",
192
+ "finish_run(run_id, documents_processed=len(doc_ids))\n",
193
+ "print(f'Run {run_id} completed.')"
194
+ ]
195
+ },
196
+ {
197
+ "cell_type": "code",
198
+ "execution_count": null,
199
+ "metadata": {},
200
+ "outputs": [],
201
+ "source": [
202
+ "# Top 20 global keywords per collection\n",
203
+ "print('=== Keyword Extraction Summary ===')\n",
204
+ "print(f'Source section: {source_section or \"all\"}')\n",
205
+ "print(f'Documents processed: {len(doc_ids)}')\n",
206
+ "print(f'Total keyword entries: {len(all_rows)}')\n",
207
+ "\n",
208
+ "# Aggregate keywords by collection\n",
209
+ "keyword_scores_by_section = defaultdict(lambda: Counter())\n",
210
+ "for i, row in docs_df.iterrows():\n",
211
+ " section = row['source_section']\n",
212
+ " doc_idx = i\n",
213
+ " tfidf_row = tfidf_matrix.getrow(doc_idx)\n",
214
+ " if tfidf_row.nnz > 0:\n",
215
+ " data = tfidf_row.toarray().flatten()\n",
216
+ " top_indices = data.argsort()[::-1][:5]\n",
217
+ " for idx in top_indices:\n",
218
+ " if data[idx] > 0:\n",
219
+ " keyword_scores_by_section[section][feature_names[idx]] += data[idx]\n",
220
+ "\n",
221
+ "print('\\nTop 20 keywords per collection:')\n",
222
+ "for section in sorted(keyword_scores_by_section.keys()):\n",
223
+ " print(f'\\n {section}:')\n",
224
+ " for keyword, score in keyword_scores_by_section[section].most_common(20):\n",
225
+ " print(f' {keyword:30s} {score:.2f}')"
226
+ ]
227
+ }
228
+ ],
229
+ "metadata": {
230
+ "kernelspec": {
231
+ "display_name": "Python 3",
232
+ "language": "python",
233
+ "name": "python3"
234
+ },
235
+ "language_info": {
236
+ "name": "python",
237
+ "version": "3.10.0"
238
+ }
239
+ },
240
+ "nbformat": 4,
241
+ "nbformat_minor": 5
242
+ }