datamatters24 committed on
Commit
fc3a6a0
·
verified ·
1 Parent(s): 42338b9

Upload notebooks/03_topic_classification/31_document_clustering.ipynb with huggingface_hub

Browse files
notebooks/03_topic_classification/31_document_clustering.ipynb ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 31 - Document Clustering\n",
    "\n",
    "Pipeline notebook for K-Means document clustering using pre-computed embeddings.\n",
    "\n",
    "Loads document-level embeddings (averaged page embeddings), runs MiniBatchKMeans,\n",
    "and stores cluster assignments in the `document_features` table."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Parameters\n",
    "source_section = None\n",
    "n_clusters = 20\n",
    "batch_size = 50000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(0, '/opt/epstein_env/research')\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.cluster import MiniBatchKMeans\n",
    "from sklearn.metrics import silhouette_score\n",
    "from collections import Counter\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from research_lib.db import fetch_df, upsert_feature\n",
    "from research_lib.incremental import start_run, finish_run, get_unprocessed_documents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start run\n",
    "run_id = start_run(\n",
    "    'document_clustering',\n",
    "    source_section=source_section,\n",
    "    parameters={'n_clusters': n_clusters, 'batch_size': batch_size},\n",
    ")\n",
    "print(f'Started run {run_id}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load document-level embeddings (average of page embeddings).\n",
    "# Build the filter as an explicit WHERE clause: previously the\n",
    "# `p.embedding IS NOT NULL` predicate dangled after an empty\n",
    "# where_clause and only parsed because it attached to the JOIN's ON\n",
    "# condition when source_section was None.\n",
    "conditions = ['p.embedding IS NOT NULL']\n",
    "params = []\n",
    "if source_section:\n",
    "    conditions.append('d.source_section = %s')\n",
    "    params = [source_section]\n",
    "where_clause = 'WHERE ' + ' AND '.join(conditions)\n",
    "\n",
    "sql = f\"\"\"\n",
    "    SELECT d.id as document_id, d.source_section,\n",
    "           AVG(p.embedding) as embedding\n",
    "    FROM documents d\n",
    "    JOIN pages p ON p.document_id = d.id\n",
    "    {where_clause}\n",
    "    GROUP BY d.id, d.source_section\n",
    "    ORDER BY d.id\n",
    "\"\"\"\n",
    "doc_df = fetch_df(sql, params or None)\n",
    "print(f'Loaded embeddings for {len(doc_df)} documents')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert embeddings to an (n_documents, dim) float32 matrix.\n",
    "# Fail fast with a clear message instead of letting np.stack raise a\n",
    "# cryptic ValueError on an empty result set.\n",
    "if len(doc_df) == 0:\n",
    "    raise RuntimeError('No documents with embeddings found; nothing to cluster.')\n",
    "\n",
    "embeddings = np.stack(doc_df['embedding'].values).astype(np.float32)\n",
    "doc_ids = doc_df['document_id'].tolist()\n",
    "\n",
    "print(f'Embeddings shape: {embeddings.shape}')\n",
    "\n",
    "# Adjust n_clusters if we have fewer documents\n",
    "actual_n_clusters = min(n_clusters, len(doc_ids))\n",
    "if actual_n_clusters < n_clusters:\n",
    "    print(f'Adjusted n_clusters from {n_clusters} to {actual_n_clusters} (fewer documents than clusters)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run MiniBatchKMeans clustering\n",
    "print(f'Running MiniBatchKMeans with {actual_n_clusters} clusters...')\n",
    "kmeans = MiniBatchKMeans(\n",
    "    n_clusters=actual_n_clusters,\n",
    "    batch_size=batch_size,\n",
    "    random_state=42,\n",
    "    n_init=3,\n",
    "    max_iter=300,\n",
    "    verbose=1,\n",
    ")\n",
    "cluster_labels = kmeans.fit_predict(embeddings)\n",
    "print(f'Clustering complete. Inertia: {kmeans.inertia_:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute silhouette score on a bounded sample.\n",
    "# silhouette_score already samples internally via sample_size, so the\n",
    "# previous manual pre-sampling pass for large datasets was redundant;\n",
    "# one call with a capped sample_size covers both size regimes.\n",
    "# Silhouette is undefined for fewer than 2 clusters, so guard that case.\n",
    "if actual_n_clusters < 2:\n",
    "    sil_score = float('nan')\n",
    "    print('Silhouette score undefined for fewer than 2 clusters')\n",
    "else:\n",
    "    print('Computing silhouette score...')\n",
    "    sil_score = silhouette_score(\n",
    "        embeddings,\n",
    "        cluster_labels,\n",
    "        metric='cosine',\n",
    "        sample_size=min(10000, len(doc_ids)),\n",
    "        random_state=42,\n",
    "    )\n",
    "    print(f'Silhouette score: {sil_score:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store cluster assignments in document_features\n",
    "rows = [\n",
    "    (\n",
    "        doc_id,\n",
    "        'cluster_id',\n",
    "        str(int(cluster_label)),\n",
    "        None,  # feature_json\n",
    "    )\n",
    "    for doc_id, cluster_label in zip(doc_ids, cluster_labels)\n",
    "]\n",
    "\n",
    "print(f'Upserting {len(rows)} cluster assignments...')\n",
    "upserted = upsert_feature(\n",
    "    'document_features',\n",
    "    unique_cols=['document_id', 'feature_name'],\n",
    "    data_cols=['feature_value', 'feature_json'],\n",
    "    rows=rows,\n",
    ")\n",
    "print(f'Upserted {upserted} rows')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Finish run\n",
    "finish_run(run_id, documents_processed=len(doc_ids))\n",
    "print(f'Run {run_id} completed.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary: cluster sizes\n",
    "print('=== Document Clustering Summary ===')\n",
    "print(f'Source section: {source_section or \"all\"}')\n",
    "print(f'Documents clustered: {len(doc_ids)}')\n",
    "print(f'Number of clusters: {actual_n_clusters}')\n",
    "print(f'Silhouette score: {sil_score:.4f}')\n",
    "print(f'Inertia: {kmeans.inertia_:.2f}')\n",
    "\n",
    "cluster_counts = Counter(cluster_labels)\n",
    "print('\\nCluster sizes (sorted by size):')\n",
    "for cluster_id, count in sorted(cluster_counts.items(), key=lambda x: x[1], reverse=True):\n",
    "    pct = 100 * count / len(doc_ids)\n",
    "    print(f'  Cluster {cluster_id:3d}: {count:6d} documents ({pct:.1f}%)')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}