anujjuna commited on
Commit
7cbf97d
·
verified ·
1 Parent(s): a0a0a64

Update tools.py

Browse files
Files changed (1) hide show
  1. tools.py +81 -19
tools.py CHANGED
@@ -18,7 +18,6 @@ from collections import Counter, defaultdict
18
  from sentence_transformers import SentenceTransformer
19
  from umap import UMAP
20
  from hdbscan import HDBSCAN
21
- from keybert import KeyBERT
22
  from sklearn.metrics import adjusted_rand_score
23
  from sklearn.metrics.pairwise import cosine_similarity
24
  import optuna
@@ -56,15 +55,20 @@ def prepare_documents(df: pd.DataFrame) -> list[str]:
56
 
57
 
58
  # ---------------------------------------------------------------------------
59
- # §3.1 — Embed with SPECTER-2
60
  # ---------------------------------------------------------------------------
 
 
61
  def embed_documents(
62
  docs: list[str],
63
  model_name: str = "allenai/specter2_base",
64
  ) -> np.ndarray:
65
  """Embed with SPECTER-2. Deterministic — no tuning (§3.3)."""
66
- model = SentenceTransformer(model_name)
67
- embeddings = model.encode(docs, show_progress_bar=True, batch_size=32)
 
 
 
68
  logger.info("Embedded %d docs → %s", len(docs), embeddings.shape)
69
  return embeddings
70
 
@@ -110,6 +114,19 @@ def compute_persistence(clusterer: HDBSCAN) -> float:
110
  return 0.0
111
 
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  def compute_dbcv(reduced: np.ndarray, labels: np.ndarray) -> float:
114
  """Density-Based Cluster Validity index."""
115
  try:
@@ -124,14 +141,15 @@ def compute_dbcv(reduced: np.ndarray, labels: np.ndarray) -> float:
124
 
125
 
126
  def compute_stability(embeddings: np.ndarray, params: dict,
127
- n_seeds: int = 5) -> float:
128
- """Cluster-recurrence stability via pairwise ARI across seeds (§3.4)."""
 
129
  all_labels = []
130
  for s in range(n_seeds):
131
  u = UMAP(n_neighbors=params["n_neighbors"],
132
  n_components=params["n_components"],
133
  min_dist=0.0, metric="cosine",
134
- random_state=s * 7 + 1)
135
  red = u.fit_transform(embeddings)
136
  h = HDBSCAN(min_cluster_size=params["min_cluster_size"],
137
  min_samples=params["min_samples"],
@@ -167,7 +185,8 @@ def _objective(trial, embeddings, n_docs):
167
  min_cluster_size=mcs, min_samples=ms, csm=csm, cse=cse)
168
 
169
  u = UMAP(n_neighbors=n_neighbors, n_components=n_components,
170
- min_dist=0.0, metric="cosine", random_state=42)
 
171
  red = u.fit_transform(embeddings)
172
 
173
  h = HDBSCAN(min_cluster_size=mcs, min_samples=ms, metric="euclidean",
@@ -235,7 +254,7 @@ def run_bayesian_optimisation(
235
  )
236
  # §3.6 convergence: 3 consecutive passing within 5 % of best
237
  passing = [e for e in trial_log if e["discipline_pass"]]
238
- if len(passing) >= 3 and i >= 19:
239
  best_p = max(e["persistence"] for e in passing)
240
  if best_p > 0:
241
  last3 = passing[-3:]
@@ -255,7 +274,7 @@ def run_bayesian_optimisation(
255
 
256
  bp = best.user_attrs["params"]
257
  labels = np.array(best.user_attrs["labels"])
258
- stability = compute_stability(embeddings, bp, n_seeds=5)
259
 
260
  return dict(
261
  best_params=bp, best_labels=labels,
@@ -274,28 +293,37 @@ def run_bayesian_optimisation(
274
  # ---------------------------------------------------------------------------
275
  def compute_2d_umap(embeddings: np.ndarray, seed: int = 42) -> np.ndarray:
276
  return UMAP(n_neighbors=15, n_components=2, min_dist=0.1,
277
- metric="cosine", random_state=seed).fit_transform(embeddings)
 
278
 
279
 
280
  # ---------------------------------------------------------------------------
281
- # §3.1 — KeyBERT keyphrase extraction per cluster (3–5 phrases)
 
282
  # ---------------------------------------------------------------------------
283
  def extract_keyphrases(docs: list[str], labels: np.ndarray,
284
  top_n: int = 5) -> dict:
285
- kw = KeyBERT(model="all-MiniLM-L6-v2")
286
  cluster_docs = defaultdict(list)
287
  for doc, lab in zip(docs, labels):
288
  if lab != -1:
289
  cluster_docs[int(lab)].append(doc)
290
  out = {}
291
  for cid, cdocs in cluster_docs.items():
 
 
 
292
  try:
293
- out[cid] = kw.extract_keywords(
294
- " ".join(cdocs), keyphrase_ngram_range=(1, 3),
295
- stop_words="english", top_n=top_n,
296
- use_mmr=True, diversity=0.5)
 
 
 
 
297
  except Exception as e:
298
- logger.warning("KeyBERT cluster %d: %s", cid, e)
299
  out[cid] = []
300
  return out
301
 
@@ -364,6 +392,35 @@ def get_representative_docs(labels, embeddings, docs, top_n=3):
364
  return out
365
 
366
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
  # ---------------------------------------------------------------------------
368
  # High-level pipeline entry point
369
  # ---------------------------------------------------------------------------
@@ -390,9 +447,13 @@ def run_topic_modeling(filepath: str, n_trials: int = 50,
390
  min_samples=bp["min_samples"], metric="euclidean",
391
  cluster_selection_method=bp["csm"],
392
  cluster_selection_epsilon=bp["cse"],
393
- allow_single_cluster=False)
 
394
  h.fit(red)
395
 
 
 
 
396
  # 5. Outlier reduction (§3.2 — clusters < 5 reassigned)
397
  labels = outlier_reduction(labels, red, n_docs)
398
 
@@ -416,6 +477,7 @@ def run_topic_modeling(filepath: str, n_trials: int = 50,
416
  keyphrases=keyphrases, representative_docs=rep_docs,
417
  membership=sw, umap_2d=umap_2d.tolist(),
418
  discipline=disc, best_params=bp,
 
419
  metrics=dict(persistence=opt["persistence"],
420
  dbcv=opt["dbcv"],
421
  stability=opt["stability"]),
 
18
  from sentence_transformers import SentenceTransformer
19
  from umap import UMAP
20
  from hdbscan import HDBSCAN
 
21
  from sklearn.metrics import adjusted_rand_score
22
  from sklearn.metrics.pairwise import cosine_similarity
23
  import optuna
 
55
 
56
 
57
  # ---------------------------------------------------------------------------
58
+ # §3.1 — Embed with SPECTER-2 (cached model for speed)
59
  # ---------------------------------------------------------------------------
60
+ _MODEL_CACHE = {}
61
+
62
  def embed_documents(
63
  docs: list[str],
64
  model_name: str = "allenai/specter2_base",
65
  ) -> np.ndarray:
66
  """Embed with SPECTER-2. Deterministic — no tuning (§3.3)."""
67
+ if model_name not in _MODEL_CACHE:
68
+ logger.info("Loading %s (first time, will be cached)…", model_name)
69
+ _MODEL_CACHE[model_name] = SentenceTransformer(model_name)
70
+ model = _MODEL_CACHE[model_name]
71
+ embeddings = model.encode(docs, show_progress_bar=True, batch_size=64)
72
  logger.info("Embedded %d docs → %s", len(docs), embeddings.shape)
73
  return embeddings
74
 
 
114
  return 0.0
115
 
116
 
117
+ def per_cluster_persistence(clusterer: HDBSCAN, labels: np.ndarray) -> dict:
118
+ """Map each cluster ID to its persistence score (§8)."""
119
+ try:
120
+ p = getattr(clusterer, "cluster_persistence_", None)
121
+ if p is None or len(p) == 0:
122
+ return {}
123
+ unique = sorted(set(int(l) for l in labels if l != -1))
124
+ return {cid: float(p[i]) if i < len(p) else 0.0
125
+ for i, cid in enumerate(unique)}
126
+ except Exception:
127
+ return {}
128
+
129
+
130
  def compute_dbcv(reduced: np.ndarray, labels: np.ndarray) -> float:
131
  """Density-Based Cluster Validity index."""
132
  try:
 
141
 
142
 
143
  def compute_stability(embeddings: np.ndarray, params: dict,
144
+ n_seeds: int = 3) -> float:
145
+ """Cluster-recurrence stability via pairwise ARI across seeds (§3.4).
146
+ Uses 3 seeds by default for speed (spec allows 3–5)."""
147
  all_labels = []
148
  for s in range(n_seeds):
149
  u = UMAP(n_neighbors=params["n_neighbors"],
150
  n_components=params["n_components"],
151
  min_dist=0.0, metric="cosine",
152
+ random_state=s * 7 + 1, low_memory=True)
153
  red = u.fit_transform(embeddings)
154
  h = HDBSCAN(min_cluster_size=params["min_cluster_size"],
155
  min_samples=params["min_samples"],
 
185
  min_cluster_size=mcs, min_samples=ms, csm=csm, cse=cse)
186
 
187
  u = UMAP(n_neighbors=n_neighbors, n_components=n_components,
188
+ min_dist=0.0, metric="cosine", random_state=42,
189
+ low_memory=True)
190
  red = u.fit_transform(embeddings)
191
 
192
  h = HDBSCAN(min_cluster_size=mcs, min_samples=ms, metric="euclidean",
 
254
  )
255
  # §3.6 convergence: 3 consecutive passing within 5 % of best
256
  passing = [e for e in trial_log if e["discipline_pass"]]
257
+ if len(passing) >= 3 and i >= 9: # allow early stop after 10 trials
258
  best_p = max(e["persistence"] for e in passing)
259
  if best_p > 0:
260
  last3 = passing[-3:]
 
274
 
275
  bp = best.user_attrs["params"]
276
  labels = np.array(best.user_attrs["labels"])
277
+ stability = compute_stability(embeddings, bp, n_seeds=3)
278
 
279
  return dict(
280
  best_params=bp, best_labels=labels,
 
293
  # ---------------------------------------------------------------------------
294
  def compute_2d_umap(embeddings: np.ndarray, seed: int = 42) -> np.ndarray:
295
  return UMAP(n_neighbors=15, n_components=2, min_dist=0.1,
296
+ metric="cosine", random_state=seed,
297
+ low_memory=True).fit_transform(embeddings)
298
 
299
 
300
  # ---------------------------------------------------------------------------
301
+ # §3.1 — TF-IDF keyphrase extraction per cluster (3–5 phrases)
302
+ # Fast alternative to KeyBERT — no extra model download needed.
303
  # ---------------------------------------------------------------------------
304
  def extract_keyphrases(docs: list[str], labels: np.ndarray,
305
  top_n: int = 5) -> dict:
306
+ from sklearn.feature_extraction.text import TfidfVectorizer
307
  cluster_docs = defaultdict(list)
308
  for doc, lab in zip(docs, labels):
309
  if lab != -1:
310
  cluster_docs[int(lab)].append(doc)
311
  out = {}
312
  for cid, cdocs in cluster_docs.items():
313
+ if len(cdocs) < 2:
314
+ out[cid] = []
315
+ continue
316
  try:
317
+ tfidf = TfidfVectorizer(
318
+ stop_words="english", max_features=200,
319
+ ngram_range=(1, 3), max_df=0.9, min_df=1)
320
+ X = tfidf.fit_transform(cdocs)
321
+ terms = tfidf.get_feature_names_out()
322
+ scores = X.sum(axis=0).A1
323
+ top_idx = scores.argsort()[::-1][:top_n]
324
+ out[cid] = [(terms[i], float(scores[i])) for i in top_idx]
325
  except Exception as e:
326
+ logger.warning("Keyphrase extraction cluster %d: %s", cid, e)
327
  out[cid] = []
328
  return out
329
 
 
392
  return out
393
 
394
 
395
+ # ---------------------------------------------------------------------------
396
+ # §9 — RQ2 / RQ3 mismatch table
397
+ # ---------------------------------------------------------------------------
398
+ def build_mismatch_table(keyphrases: dict, cluster_labels: dict) -> list:
399
+ """Compare cluster keyphrases against assigned labels to flag mismatches.
400
+ Returns rows for a mismatch table (§9)."""
401
+ rows = []
402
+ for cid in sorted(keyphrases.keys()):
403
+ kps = keyphrases.get(cid, [])
404
+ kp_terms = [k[0] if isinstance(k, tuple) else k for k in kps[:5]]
405
+ label = cluster_labels.get(cid, "")
406
+ # Check overlap between label words and keyphrase terms
407
+ label_words = set(label.lower().split())
408
+ kp_words = set(" ".join(kp_terms).lower().split())
409
+ overlap = label_words & kp_words
410
+ noise = {"the","and","for","with","using","based","from","in","of","a","to"}
411
+ overlap -= noise
412
+ match_pct = len(overlap) / max(len(label_words - noise), 1)
413
+ status = "MATCH" if match_pct >= 0.3 else "MISMATCH"
414
+ rows.append(dict(
415
+ cluster=cid, label=label,
416
+ keyphrases=", ".join(kp_terms),
417
+ overlap=", ".join(overlap) if overlap else "—",
418
+ match_pct=round(match_pct * 100),
419
+ status=status,
420
+ ))
421
+ return rows
422
+
423
+
424
  # ---------------------------------------------------------------------------
425
  # High-level pipeline entry point
426
  # ---------------------------------------------------------------------------
 
447
  min_samples=bp["min_samples"], metric="euclidean",
448
  cluster_selection_method=bp["csm"],
449
  cluster_selection_epsilon=bp["cse"],
450
+ allow_single_cluster=False,
451
+ gen_min_span_tree=True)
452
  h.fit(red)
453
 
454
+ # Per-cluster persistence (§8)
455
+ cluster_pers = per_cluster_persistence(h, labels)
456
+
457
  # 5. Outlier reduction (§3.2 — clusters < 5 reassigned)
458
  labels = outlier_reduction(labels, red, n_docs)
459
 
 
477
  keyphrases=keyphrases, representative_docs=rep_docs,
478
  membership=sw, umap_2d=umap_2d.tolist(),
479
  discipline=disc, best_params=bp,
480
+ cluster_persistence=cluster_pers,
481
  metrics=dict(persistence=opt["persistence"],
482
  dbcv=opt["dbcv"],
483
  stability=opt["stability"]),