anujjuna commited on
Commit
c8c01fa
·
verified ·
1 Parent(s): 8c6e466

Update tools.py

Browse files
Files changed (1) hide show
  1. tools.py +98 -36
tools.py CHANGED
@@ -2,15 +2,23 @@
2
  tools.py
3
  --------
4
  Topic modeling module using BERTopic for analyzing research paper abstracts and titles.
5
- Heavy imports are lazy-loaded inside functions to stay within 2GB RAM on free HF Spaces.
6
  """
7
 
8
  import re
9
  import logging
10
  import pandas as pd
11
- import numpy as np
12
  from typing import Optional
13
 
 
 
 
 
 
 
 
 
 
 
14
  # ---------------------------------------------------------------------------
15
  # Logging
16
  # ---------------------------------------------------------------------------
@@ -22,8 +30,6 @@ logger = logging.getLogger(__name__)
22
  # Setup
23
  # ---------------------------------------------------------------------------
24
  def _ensure_nltk_stopwords() -> None:
25
- from nltk.corpus import stopwords
26
- import nltk
27
  try:
28
  stopwords.words("english")
29
  except LookupError:
@@ -39,6 +45,7 @@ def load_csv(filepath: str) -> pd.DataFrame:
39
  missing = required_cols - set(df.columns.str.lower())
40
  if missing:
41
  raise ValueError(f"CSV is missing required column(s): {missing}")
 
42
  df.columns = df.columns.str.lower()
43
  logger.info("Loaded %d rows from '%s'.", len(df), filepath)
44
  return df
@@ -48,7 +55,6 @@ def load_csv(filepath: str) -> pd.DataFrame:
48
  # Preprocessing
49
  # ---------------------------------------------------------------------------
50
  def preprocess_text(texts: pd.Series) -> list[str]:
51
- from nltk.corpus import stopwords
52
  _ensure_nltk_stopwords()
53
  stop_words = set(stopwords.words("english"))
54
 
@@ -67,10 +73,9 @@ def preprocess_text(texts: pd.Series) -> list[str]:
67
  # ---------------------------------------------------------------------------
68
  # Model Construction
69
  # ---------------------------------------------------------------------------
70
- def build_bertopic_model(embedding_model, min_topic_size: int = 5):
71
- from bertopic import BERTopic
72
- from umap import UMAP
73
- from hdbscan import HDBSCAN
74
 
75
  umap_model = UMAP(
76
  n_neighbors=15,
@@ -80,6 +85,8 @@ def build_bertopic_model(embedding_model, min_topic_size: int = 5):
80
  random_state=42,
81
  )
82
 
 
 
83
  hdbscan_model = HDBSCAN(
84
  min_cluster_size=max(min_topic_size, 5),
85
  min_samples=2,
@@ -95,7 +102,7 @@ def build_bertopic_model(embedding_model, min_topic_size: int = 5):
95
  min_topic_size=max(min_topic_size, 5),
96
  verbose=False,
97
  )
98
- logger.info("BERTopic model created (min_cluster_size=%d).", max(min_topic_size, 5))
99
  return model
100
 
101
 
@@ -110,8 +117,14 @@ def _get_cluster_sizes(topics: list[int]) -> dict[int, int]:
110
  return sizes
111
 
112
 
113
- def _split_large_cluster(topic_id, doc_indices, embeddings, topics, next_id):
114
- from sklearn.cluster import KMeans
 
 
 
 
 
 
115
  if len(doc_indices) < 4:
116
  return next_id
117
  sub_embs = embeddings[doc_indices]
@@ -119,14 +132,19 @@ def _split_large_cluster(topic_id, doc_indices, embeddings, topics, next_id):
119
  labels = km.fit_predict(sub_embs)
120
  new_id = next_id
121
  for local_idx, global_idx in enumerate(doc_indices):
122
- if labels[local_idx] == 1:
123
  topics[global_idx] = new_id
124
  logger.info("Split large cluster %d → kept %d, created %d.", topic_id, topic_id, new_id)
125
  return next_id + 1
126
 
127
 
128
- def _merge_small_cluster(topic_id, doc_indices, cluster_centroids, topics):
129
- from sklearn.metrics.pairwise import cosine_similarity
 
 
 
 
 
130
  if not cluster_centroids:
131
  return
132
  src_centroid = cluster_centroids[topic_id].reshape(1, -1)
@@ -141,9 +159,25 @@ def _merge_small_cluster(topic_id, doc_indices, cluster_centroids, topics):
141
  logger.info("Merged small cluster %d → cluster %d.", topic_id, nearest)
142
 
143
 
144
- def balance_clusters(topics, documents, embedding_model, large_factor=2.0, small_threshold=3):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  try:
 
146
  embeddings = embedding_model.encode(documents, show_progress_bar=False)
 
147
  topics = list(topics)
148
  sizes = _get_cluster_sizes(topics)
149
  if not sizes:
@@ -153,51 +187,67 @@ def balance_clusters(topics, documents, embedding_model, large_factor=2.0, small
153
  median_size = float(np.median(counts))
154
  large_cutoff = large_factor * median_size
155
 
 
156
  cluster_docs: dict[int, list[int]] = {}
157
  for idx, tid in enumerate(topics):
158
  if tid != -1:
159
  cluster_docs.setdefault(tid, []).append(idx)
160
 
161
- centroids = {
 
162
  tid: embeddings[idxs].mean(axis=0)
163
  for tid, idxs in cluster_docs.items()
164
  }
165
 
166
  next_id = max(sizes.keys()) + 1
167
 
 
168
  for tid, size in list(sizes.items()):
169
  if size > large_cutoff:
170
- next_id = _split_large_cluster(tid, cluster_docs[tid], embeddings, topics, next_id)
 
 
171
 
 
172
  sizes = _get_cluster_sizes(topics)
173
  cluster_docs = {}
174
  for idx, tid in enumerate(topics):
175
  if tid != -1:
176
  cluster_docs.setdefault(tid, []).append(idx)
177
 
 
178
  for tid, size in list(sizes.items()):
179
  if size < small_threshold and tid in cluster_docs:
180
  _merge_small_cluster(tid, cluster_docs[tid], centroids, topics)
181
 
182
  return topics
183
  except Exception as e:
184
- logger.error("Cluster balancing error: %s", e)
185
  raise e
186
 
187
 
188
  # ---------------------------------------------------------------------------
189
  # Topic Extraction
190
  # ---------------------------------------------------------------------------
191
- def extract_topics(model, documents, embedding_model, label="documents") -> dict:
 
 
 
 
 
 
192
  valid_docs = [d if d.strip() else "empty" for d in documents]
 
193
  topics, _ = model.fit_transform(valid_docs)
194
 
 
 
195
  try:
196
  topics = balance_clusters(topics, valid_docs, embedding_model)
197
  except Exception as e:
198
- logger.error("Cluster balancing failed (using original topics): %s", e)
199
 
200
- topic_info = model.get_topic_info()
201
 
202
  topic_keywords: dict[int, list[tuple[str, float]]] = {}
203
  for topic_id in topic_info["Topic"].tolist():
@@ -207,9 +257,16 @@ def extract_topics(model, documents, embedding_model, label="documents") -> dict
207
  if words:
208
  topic_keywords[topic_id] = words
209
 
210
- topic_freq = topic_info.set_index("Topic")["Count"].to_dict()
 
 
 
 
 
 
 
 
211
 
212
- logger.info("Extracted %d topic(s) from %s.", len(topic_keywords), label)
213
  return {
214
  "topics": topics,
215
  "topic_info": topic_info,
@@ -219,28 +276,30 @@ def extract_topics(model, documents, embedding_model, label="documents") -> dict
219
 
220
 
221
  # ---------------------------------------------------------------------------
222
- # High-Level Pipeline — ALL heavy imports live here
223
  # ---------------------------------------------------------------------------
224
- def run_topic_modeling(filepath: str, min_topic_size: int = 5) -> dict:
225
- # Heavy imports deferred to here so app.py startup stays lightweight
226
- from sentence_transformers import SentenceTransformer
227
- from bertopic import BERTopic # noqa: F401 (ensures bertopic is cached)
228
 
229
  df = load_csv(filepath)
 
230
  clean_abstracts = preprocess_text(df["abstract"])
231
  clean_titles = preprocess_text(df["title"])
232
 
 
233
  embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
234
 
235
  abstract_model = build_bertopic_model(embedding_model, min_topic_size=min_topic_size)
236
- title_model = build_bertopic_model(embedding_model, min_topic_size=min_topic_size)
237
 
238
  abstract_results = extract_topics(abstract_model, clean_abstracts, embedding_model, label="abstracts")
239
- title_results = extract_topics(title_model, clean_titles, embedding_model, label="titles")
240
 
241
  return {
242
  "abstracts": abstract_results,
243
- "titles": title_results,
244
  }
245
 
246
 
@@ -253,15 +312,15 @@ def print_results(results: dict, top_n_keywords: int = 10) -> None:
253
  print(f" Topic Modeling Results – {section.upper()}")
254
  print(f"{'='*60}")
255
 
256
- keywords = data["topic_keywords"]
257
- freq = data["topic_freq"]
258
 
259
  if not keywords:
260
  print(" No topics found.")
261
  continue
262
 
263
  for topic_id, words in sorted(keywords.items()):
264
- count = freq.get(topic_id, 0)
265
  kw_str = ", ".join(w for w, _ in words[:top_n_keywords])
266
  print(f"\n Topic {topic_id:>3} | docs: {count:>4}")
267
  print(f" Keywords : {kw_str}")
@@ -276,10 +335,13 @@ def print_results(results: dict, top_n_keywords: int = 10) -> None:
276
  # ---------------------------------------------------------------------------
277
  if __name__ == "__main__":
278
  import sys
 
279
  if len(sys.argv) < 2:
280
  print("Usage: python tools.py <path_to_csv> [min_topic_size]")
281
  sys.exit(1)
 
282
  csv_path = sys.argv[1]
283
  mts = int(sys.argv[2]) if len(sys.argv) > 2 else 5
 
284
  output = run_topic_modeling(csv_path, min_topic_size=mts)
285
- print_results(output)
 
2
  tools.py
3
  --------
4
  Topic modeling module using BERTopic for analyzing research paper abstracts and titles.
 
5
  """
6
 
7
  import re
8
  import logging
9
  import pandas as pd
 
10
  from typing import Optional
11
 
12
+ from bertopic import BERTopic
13
+ from sentence_transformers import SentenceTransformer
14
+ from umap import UMAP
15
+ from hdbscan import HDBSCAN # --- Cluster Balancing Logic ---
16
+ from sklearn.cluster import KMeans
17
+ from sklearn.metrics.pairwise import cosine_similarity
18
+ import numpy as np
19
+ from nltk.corpus import stopwords
20
+ import nltk
21
+
22
  # ---------------------------------------------------------------------------
23
  # Logging
24
  # ---------------------------------------------------------------------------
 
30
  # Setup
31
  # ---------------------------------------------------------------------------
32
  def _ensure_nltk_stopwords() -> None:
 
 
33
  try:
34
  stopwords.words("english")
35
  except LookupError:
 
45
  missing = required_cols - set(df.columns.str.lower())
46
  if missing:
47
  raise ValueError(f"CSV is missing required column(s): {missing}")
48
+
49
  df.columns = df.columns.str.lower()
50
  logger.info("Loaded %d rows from '%s'.", len(df), filepath)
51
  return df
 
55
  # Preprocessing
56
  # ---------------------------------------------------------------------------
57
  def preprocess_text(texts: pd.Series) -> list[str]:
 
58
  _ensure_nltk_stopwords()
59
  stop_words = set(stopwords.words("english"))
60
 
 
73
  # ---------------------------------------------------------------------------
74
  # Model Construction
75
  # ---------------------------------------------------------------------------
76
+ def build_bertopic_model(embedding_model: SentenceTransformer, min_topic_size: int = 5) -> BERTopic:
77
+ # --- Cluster Balancing Logic ---
78
+ # (embedding_model is passed explicitly from run_topic_modeling)
 
79
 
80
  umap_model = UMAP(
81
  n_neighbors=15,
 
85
  random_state=42,
86
  )
87
 
88
+ # Tuned HDBSCAN: smaller min_cluster_size allows more granular clusters;
89
+ # reduced min_samples makes the model less strict about noise.
90
  hdbscan_model = HDBSCAN(
91
  min_cluster_size=max(min_topic_size, 5),
92
  min_samples=2,
 
102
  min_topic_size=max(min_topic_size, 5),
103
  verbose=False,
104
  )
105
+ logger.info("BERTopic model created with tuned HDBSCAN (min_cluster_size=%d).", max(min_topic_size, 5))
106
  return model
107
 
108
 
 
117
  return sizes
118
 
119
 
120
+ def _split_large_cluster(
121
+ topic_id: int,
122
+ doc_indices: list[int],
123
+ embeddings: np.ndarray,
124
+ topics: list[int],
125
+ next_id: int,
126
+ ) -> int:
127
+ """Split an oversized cluster into 2 sub-clusters via KMeans. Returns next available ID."""
128
  if len(doc_indices) < 4:
129
  return next_id
130
  sub_embs = embeddings[doc_indices]
 
132
  labels = km.fit_predict(sub_embs)
133
  new_id = next_id
134
  for local_idx, global_idx in enumerate(doc_indices):
135
+ if labels[local_idx] == 1: # half goes to a new cluster ID
136
  topics[global_idx] = new_id
137
  logger.info("Split large cluster %d → kept %d, created %d.", topic_id, topic_id, new_id)
138
  return next_id + 1
139
 
140
 
141
+ def _merge_small_cluster(
142
+ topic_id: int,
143
+ doc_indices: list[int],
144
+ cluster_centroids: dict[int, np.ndarray],
145
+ topics: list[int],
146
+ ) -> None:
147
+ """Merge a tiny cluster into the nearest cluster by cosine similarity."""
148
  if not cluster_centroids:
149
  return
150
  src_centroid = cluster_centroids[topic_id].reshape(1, -1)
 
159
  logger.info("Merged small cluster %d → cluster %d.", topic_id, nearest)
160
 
161
 
162
+ def balance_clusters(
163
+ topics: list[int],
164
+ documents: list[str],
165
+ embedding_model: SentenceTransformer,
166
+ large_factor: float = 2.0,
167
+ small_threshold: int = 3,
168
+ ) -> list[int]:
169
+ """
170
+ --- Cluster Balancing Logic ---
171
+ Post-process HDBSCAN topic assignments to reduce extreme size imbalance.
172
+
173
+ - Splits clusters > large_factor × median size (via KMeans sub-split).
174
+ - Merges clusters < small_threshold into their nearest neighbour.
175
+ Does NOT enforce equal sizes.
176
+ """
177
  try:
178
+ # Ensure balance_clusters actually runs and uses embedding_model.encode
179
  embeddings = embedding_model.encode(documents, show_progress_bar=False)
180
+
181
  topics = list(topics)
182
  sizes = _get_cluster_sizes(topics)
183
  if not sizes:
 
187
  median_size = float(np.median(counts))
188
  large_cutoff = large_factor * median_size
189
 
190
+ # Build per-cluster document index lists
191
  cluster_docs: dict[int, list[int]] = {}
192
  for idx, tid in enumerate(topics):
193
  if tid != -1:
194
  cluster_docs.setdefault(tid, []).append(idx)
195
 
196
+ # Compute centroids for merge step
197
+ centroids: dict[int, np.ndarray] = {
198
  tid: embeddings[idxs].mean(axis=0)
199
  for tid, idxs in cluster_docs.items()
200
  }
201
 
202
  next_id = max(sizes.keys()) + 1
203
 
204
+ # Split oversized clusters
205
  for tid, size in list(sizes.items()):
206
  if size > large_cutoff:
207
+ next_id = _split_large_cluster(
208
+ tid, cluster_docs[tid], embeddings, topics, next_id
209
+ )
210
 
211
+ # Re-compute sizes after splits for merge step
212
  sizes = _get_cluster_sizes(topics)
213
  cluster_docs = {}
214
  for idx, tid in enumerate(topics):
215
  if tid != -1:
216
  cluster_docs.setdefault(tid, []).append(idx)
217
 
218
+ # Merge undersized clusters
219
  for tid, size in list(sizes.items()):
220
  if size < small_threshold and tid in cluster_docs:
221
  _merge_small_cluster(tid, cluster_docs[tid], centroids, topics)
222
 
223
  return topics
224
  except Exception as e:
225
+ print("Cluster balancing error:", e)
226
  raise e
227
 
228
 
229
  # ---------------------------------------------------------------------------
230
  # Topic Extraction
231
  # ---------------------------------------------------------------------------
232
+ def extract_topics(
233
+ model: BERTopic,
234
+ documents: list[str],
235
+ embedding_model: SentenceTransformer,
236
+ label: str = "documents",
237
+ ) -> dict:
238
+
239
  valid_docs = [d if d.strip() else "empty" for d in documents]
240
+
241
  topics, _ = model.fit_transform(valid_docs)
242
 
243
+ # --- Cluster Balancing Logic ---
244
+ # Attempt to balance clusters but move ahead if it fails
245
  try:
246
  topics = balance_clusters(topics, valid_docs, embedding_model)
247
  except Exception as e:
248
+ logger.error("Cluster balancing failed (moving ahead with original topics): %s", e)
249
 
250
+ topic_info: pd.DataFrame = model.get_topic_info()
251
 
252
  topic_keywords: dict[int, list[tuple[str, float]]] = {}
253
  for topic_id in topic_info["Topic"].tolist():
 
257
  if words:
258
  topic_keywords[topic_id] = words
259
 
260
+ topic_freq: dict[int, int] = (
261
+ topic_info.set_index("Topic")["Count"].to_dict()
262
+ )
263
+
264
+ logger.info(
265
+ "Extracted %d topic(s) from %s.",
266
+ len(topic_keywords),
267
+ label,
268
+ )
269
 
 
270
  return {
271
  "topics": topics,
272
  "topic_info": topic_info,
 
276
 
277
 
278
  # ---------------------------------------------------------------------------
279
+ # High-Level Pipeline
280
  # ---------------------------------------------------------------------------
281
+ def run_topic_modeling(
282
+ filepath: str,
283
+ min_topic_size: int = 5,
284
+ ) -> dict:
285
 
286
  df = load_csv(filepath)
287
+
288
  clean_abstracts = preprocess_text(df["abstract"])
289
  clean_titles = preprocess_text(df["title"])
290
 
291
+ # Create embedding model once to be shared across steps
292
  embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
293
 
294
  abstract_model = build_bertopic_model(embedding_model, min_topic_size=min_topic_size)
295
+ title_model = build_bertopic_model(embedding_model, min_topic_size=min_topic_size)
296
 
297
  abstract_results = extract_topics(abstract_model, clean_abstracts, embedding_model, label="abstracts")
298
+ title_results = extract_topics(title_model, clean_titles, embedding_model, label="titles")
299
 
300
  return {
301
  "abstracts": abstract_results,
302
+ "titles": title_results,
303
  }
304
 
305
 
 
312
  print(f" Topic Modeling Results – {section.upper()}")
313
  print(f"{'='*60}")
314
 
315
+ keywords: dict = data["topic_keywords"]
316
+ freq: dict = data["topic_freq"]
317
 
318
  if not keywords:
319
  print(" No topics found.")
320
  continue
321
 
322
  for topic_id, words in sorted(keywords.items()):
323
+ count = freq.get(topic_id, 0)
324
  kw_str = ", ".join(w for w, _ in words[:top_n_keywords])
325
  print(f"\n Topic {topic_id:>3} | docs: {count:>4}")
326
  print(f" Keywords : {kw_str}")
 
335
  # ---------------------------------------------------------------------------
336
  if __name__ == "__main__":
337
  import sys
338
+
339
  if len(sys.argv) < 2:
340
  print("Usage: python tools.py <path_to_csv> [min_topic_size]")
341
  sys.exit(1)
342
+
343
  csv_path = sys.argv[1]
344
  mts = int(sys.argv[2]) if len(sys.argv) > 2 else 5
345
+
346
  output = run_topic_modeling(csv_path, min_topic_size=mts)
347
+ print_results(output)