Spaces:
Sleeping
Sleeping
Update tools.py
Browse files
tools.py
CHANGED
|
@@ -1,522 +1,427 @@
|
|
| 1 |
"""
|
| 2 |
tools.py
|
| 3 |
--------
|
| 4 |
-
Topic
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
import re
|
| 8 |
import logging
|
| 9 |
import pandas as pd
|
|
|
|
| 10 |
from typing import Optional
|
|
|
|
| 11 |
|
| 12 |
-
from bertopic import BERTopic
|
| 13 |
from sentence_transformers import SentenceTransformer
|
| 14 |
from umap import UMAP
|
| 15 |
-
from hdbscan import HDBSCAN
|
| 16 |
-
from
|
|
|
|
| 17 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 18 |
-
import
|
| 19 |
-
from nltk.corpus import stopwords
|
| 20 |
-
import nltk
|
| 21 |
-
from sklearn.feature_extraction.text import CountVectorizer
|
| 22 |
-
from collections import defaultdict, Counter
|
| 23 |
|
| 24 |
# ---------------------------------------------------------------------------
|
| 25 |
# Logging
|
| 26 |
# ---------------------------------------------------------------------------
|
| 27 |
logging.basicConfig(level=logging.INFO, format="%(levelname)s | %(message)s")
|
| 28 |
logger = logging.getLogger(__name__)
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
# ---------------------------------------------------------------------------
|
| 32 |
-
#
|
| 33 |
-
# ---------------------------------------------------------------------------
|
| 34 |
-
def _ensure_nltk_stopwords() -> None:
    """Fetch the NLTK "stopwords" corpus when the English list cannot be read.

    The English stop-word list is probed first; only a ``LookupError`` from
    that probe triggers a (quiet) download of the corpus.
    """
    corpus_missing = False
    try:
        stopwords.words("english")
    except LookupError:
        corpus_missing = True

    if corpus_missing:
        nltk.download("stopwords", quiet=True)
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
# ---------------------------------------------------------------------------
|
| 42 |
-
# Data Loading
|
| 43 |
# ---------------------------------------------------------------------------
|
| 44 |
def load_csv(filepath: str) -> pd.DataFrame:
    """Read a CSV file and return it with lower-cased column names.

    Args:
        filepath: Path handed straight to ``pandas.read_csv``.

    Returns:
        The loaded frame; its column labels are lower-cased.

    Raises:
        ValueError: If, after lower-casing, "title" or "abstract" is absent.
    """
    frame = pd.read_csv(filepath)
    lowered = frame.columns.str.lower()

    # Validate against the lower-cased labels so "Title" / "ABSTRACT" pass.
    missing = {"title", "abstract"} - set(lowered)
    if missing:
        raise ValueError(f"CSV is missing required column(s): {missing}")

    frame.columns = lowered
    logger.info("Loaded %d rows from '%s'.", len(frame), filepath)
    return frame
|
| 54 |
|
| 55 |
|
| 56 |
# ---------------------------------------------------------------------------
|
| 57 |
-
#
|
| 58 |
# ---------------------------------------------------------------------------
|
| 59 |
-
def
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
for raw in texts.fillna(""):
|
| 65 |
-
text = raw.lower()
|
| 66 |
-
text = re.sub(r"[^a-z\s]", " ", text)
|
| 67 |
-
tokens = text.split()
|
| 68 |
-
tokens = [t for t in tokens if t not in stop_words and len(t) > 1]
|
| 69 |
-
cleaned.append(" ".join(tokens))
|
| 70 |
-
|
| 71 |
-
logger.info("Preprocessed %d documents.", len(cleaned))
|
| 72 |
-
return cleaned
|
| 73 |
|
| 74 |
|
| 75 |
# ---------------------------------------------------------------------------
|
| 76 |
-
#
|
| 77 |
# ---------------------------------------------------------------------------
|
| 78 |
-
def
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
random_state=42,
|
| 88 |
-
)
|
| 89 |
|
| 90 |
-
# Updated HDBSCAN constraints
|
| 91 |
-
hdbscan_model = HDBSCAN(
|
| 92 |
-
min_cluster_size=5,
|
| 93 |
-
min_samples=3,
|
| 94 |
-
metric="euclidean",
|
| 95 |
-
cluster_selection_method="eom",
|
| 96 |
-
prediction_data=True,
|
| 97 |
-
)
|
| 98 |
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
)
|
| 106 |
-
logger.info("BERTopic model created with HDBSCAN (min_cluster_size=5, min_samples=3).")
|
| 107 |
-
return model
|
| 108 |
|
| 109 |
|
| 110 |
# ---------------------------------------------------------------------------
|
| 111 |
-
#
|
| 112 |
# ---------------------------------------------------------------------------
|
| 113 |
-
def
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
|
| 121 |
-
def
|
| 122 |
-
|
| 123 |
-
doc_indices: list[int],
|
| 124 |
-
embeddings: np.ndarray,
|
| 125 |
-
topics: list[int],
|
| 126 |
-
next_id: int,
|
| 127 |
-
) -> int:
|
| 128 |
-
"""Split an oversized cluster into 2 sub-clusters via KMeans. Returns next available ID."""
|
| 129 |
-
if len(doc_indices) < 10: # Minimum threshold to split
|
| 130 |
-
return next_id
|
| 131 |
-
sub_embs = embeddings[doc_indices]
|
| 132 |
-
km = KMeans(n_clusters=2, random_state=42, n_init=5)
|
| 133 |
-
labels = km.fit_predict(sub_embs)
|
| 134 |
-
new_id = next_id
|
| 135 |
-
for local_idx, global_idx in enumerate(doc_indices):
|
| 136 |
-
if labels[local_idx] == 1: # half goes to a new cluster ID
|
| 137 |
-
topics[global_idx] = new_id
|
| 138 |
-
logger.info("Split large cluster %d → kept %d, created %d.", topic_id, topic_id, new_id)
|
| 139 |
-
return next_id + 1
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
def _merge_small_cluster(
    topic_id: int,
    doc_indices: list[int],
    cluster_centroids: dict[int, np.ndarray],
    topics: list[int],
    similarity_threshold: float = 0.5,
) -> bool:
    """Fold one cluster into its most similar neighbour when similar enough.

    The centroid of ``topic_id`` is compared (cosine similarity) with every
    other centroid in ``cluster_centroids``. If the best similarity reaches
    ``similarity_threshold``, every index in ``doc_indices`` is relabelled in
    ``topics`` (mutated in place) with the neighbour's id.

    Returns:
        True when the documents were relabelled, False otherwise (no
        centroids, unknown ``topic_id``, no other cluster, or similarity
        below the threshold).
    """
    if not (cluster_centroids and topic_id in cluster_centroids):
        return False

    candidates = [cid for cid in cluster_centroids if cid != topic_id]
    if not candidates:
        return False

    query = cluster_centroids[topic_id].reshape(1, -1)
    stacked = np.vstack([cluster_centroids[cid] for cid in candidates])
    scores = cosine_similarity(query, stacked)[0]

    winner = int(np.argmax(scores))
    top_score = scores[winner]

    # Written as a negated ">=" so a NaN score is rejected, as before.
    if not (top_score >= similarity_threshold):
        return False

    target = candidates[winner]
    for position in doc_indices:
        topics[position] = target
    logger.info("Merged small cluster %d → cluster %d (sim=%.2f).", topic_id, target, top_score)
    return True
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
def balance_clusters(
|
| 171 |
-
topics: list[int],
|
| 172 |
-
documents: list[str],
|
| 173 |
-
embedding_model: SentenceTransformer,
|
| 174 |
-
embeddings: Optional[np.ndarray] = None,
|
| 175 |
-
) -> list[int]:
|
| 176 |
-
"""
|
| 177 |
-
Enforce cluster size limits: MIN=5, MAX=30.
|
| 178 |
-
"""
|
| 179 |
try:
|
| 180 |
-
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
-
topics = list(topics)
|
| 184 |
-
MIN_CLUSTER_SIZE = 5
|
| 185 |
-
MAX_CLUSTER_SIZE = 30
|
| 186 |
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
if tid != -1:
|
| 195 |
-
cluster_docs.setdefault(tid, []).append(idx)
|
| 196 |
-
|
| 197 |
-
centroids: dict[int, np.ndarray] = {
|
| 198 |
-
tid: embeddings[idxs].mean(axis=0)
|
| 199 |
-
for tid, idxs in cluster_docs.items()
|
| 200 |
-
}
|
| 201 |
-
|
| 202 |
-
next_id = max(sizes.keys()) + 1 if sizes else 0
|
| 203 |
-
changed = False
|
| 204 |
-
|
| 205 |
-
# Split oversized clusters
|
| 206 |
-
for tid, size in list(sizes.items()):
|
| 207 |
-
if size > MAX_CLUSTER_SIZE:
|
| 208 |
-
old_next_id = next_id
|
| 209 |
-
next_id = _split_large_cluster(
|
| 210 |
-
tid, cluster_docs[tid], embeddings, topics, next_id
|
| 211 |
-
)
|
| 212 |
-
if next_id > old_next_id:
|
| 213 |
-
changed = True
|
| 214 |
-
|
| 215 |
-
# Merge undersized clusters
|
| 216 |
-
sizes = _get_cluster_sizes(topics)
|
| 217 |
-
for tid, size in list(sizes.items()):
|
| 218 |
-
if size < MIN_CLUSTER_SIZE and tid in cluster_docs:
|
| 219 |
-
if _merge_small_cluster(tid, cluster_docs[tid], centroids, topics, similarity_threshold=0.5):
|
| 220 |
-
changed = True
|
| 221 |
-
|
| 222 |
-
if not changed:
|
| 223 |
-
break
|
| 224 |
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
return topics
|
| 229 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
max_clusters: int = 30,
|
| 236 |
-
) -> list[int]:
|
| 237 |
-
"""Iteratively split or merge to keep total clusters between 15 and 30."""
|
| 238 |
-
topics = list(topics)
|
| 239 |
-
|
| 240 |
-
while True:
|
| 241 |
-
unique_clusters = [t for t in set(topics) if t != -1]
|
| 242 |
-
count = len(unique_clusters)
|
| 243 |
-
|
| 244 |
-
if min_clusters <= count <= max_clusters:
|
| 245 |
-
break
|
| 246 |
-
|
| 247 |
-
cluster_docs: dict[int, list[int]] = {}
|
| 248 |
-
for idx, tid in enumerate(topics):
|
| 249 |
-
if tid != -1:
|
| 250 |
-
cluster_docs.setdefault(tid, []).append(idx)
|
| 251 |
-
|
| 252 |
-
if not cluster_docs:
|
| 253 |
-
break
|
| 254 |
-
|
| 255 |
-
centroids: dict[int, np.ndarray] = {
|
| 256 |
-
tid: embeddings[idxs].mean(axis=0)
|
| 257 |
-
for tid, idxs in cluster_docs.items()
|
| 258 |
-
}
|
| 259 |
-
|
| 260 |
-
if count > max_clusters:
|
| 261 |
-
# Merge two closest clusters
|
| 262 |
-
ids = list(centroids.keys())
|
| 263 |
-
c_matrix = np.vstack([centroids[tid] for tid in ids])
|
| 264 |
-
sim_matrix = cosine_similarity(c_matrix)
|
| 265 |
-
np.fill_diagonal(sim_matrix, -1)
|
| 266 |
-
|
| 267 |
-
i, j = np.unravel_index(np.argmax(sim_matrix), sim_matrix.shape)
|
| 268 |
-
tid_i, tid_j = ids[i], ids[j]
|
| 269 |
-
|
| 270 |
-
# Merge tid_i into tid_j
|
| 271 |
-
for idx in cluster_docs[tid_i]:
|
| 272 |
-
topics[idx] = tid_j
|
| 273 |
-
logger.info("Reduced clusters: Merged %d and %d (count: %d -> %d)", tid_i, tid_j, count, count-1)
|
| 274 |
-
|
| 275 |
-
elif count < min_clusters:
|
| 276 |
-
# Split largest cluster
|
| 277 |
-
sizes = _get_cluster_sizes(topics)
|
| 278 |
-
largest_tid = max(sizes, key=sizes.get)
|
| 279 |
-
next_id = max(unique_clusters) + 1
|
| 280 |
-
_split_large_cluster(largest_tid, cluster_docs[largest_tid], embeddings, topics, next_id)
|
| 281 |
-
logger.info("Increased clusters: Split %d (count: %d -> %d)", largest_tid, count, count+1)
|
| 282 |
-
|
| 283 |
-
final_count = len([t for t in set(topics) if t != -1])
|
| 284 |
-
logger.info("Final cluster count: %d", final_count)
|
| 285 |
-
print(f"Final cluster count: {final_count}")
|
| 286 |
-
|
| 287 |
-
return topics
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
def get_top_3_central_docs(
    topics: list[int],
    embeddings: np.ndarray,
    documents: list[str],
) -> dict[int, list[str]]:
    """Pick, per topic, up to three documents nearest the topic centroid.

    Documents labelled ``-1`` are ignored. For every remaining topic the mean
    embedding of its members is taken as the centroid, members are ranked by
    cosine similarity to it, and the three best (most similar first) are
    returned as document strings.
    """
    members_by_topic: dict[int, list[int]] = {}
    for position, label in enumerate(topics):
        if label == -1:
            continue
        members_by_topic.setdefault(label, []).append(position)

    chosen = {}
    for label, members in members_by_topic.items():
        vectors = embeddings[members]
        center = vectors.mean(axis=0).reshape(1, -1)
        closeness = cosine_similarity(center, vectors)[0]

        # argsort is ascending: take the last three and reverse them.
        ranked = np.argsort(closeness)
        best_first = ranked[-3:][::-1]
        chosen[label] = [documents[members[local]] for local in best_first]

    return chosen
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
def rebuild_topic_keywords(
|
| 315 |
-
topics: list[int],
|
| 316 |
-
documents: list[str],
|
| 317 |
-
) -> dict[int, list[tuple[str, float]]]:
|
| 318 |
-
"""
|
| 319 |
-
Rebuild topic keywords based on updated cluster assignments using CountVectorizer.
|
| 320 |
-
Skips clusters with fewer than 3 documents.
|
| 321 |
-
"""
|
| 322 |
-
cluster_docs: dict = defaultdict(list)
|
| 323 |
-
for doc, t in zip(documents, topics):
|
| 324 |
-
if t != -1:
|
| 325 |
-
cluster_docs[t].append(doc)
|
| 326 |
-
|
| 327 |
-
topic_keywords = {}
|
| 328 |
-
for topic_id, docs in cluster_docs.items():
|
| 329 |
-
if len(docs) < 2:
|
| 330 |
-
continue
|
| 331 |
-
vectorizer = CountVectorizer(stop_words='english', max_features=50)
|
| 332 |
-
try:
|
| 333 |
-
X = vectorizer.fit_transform(docs)
|
| 334 |
-
words = vectorizer.get_feature_names_out()
|
| 335 |
-
scores = X.sum(axis=0).A1
|
| 336 |
-
top_idx = scores.argsort()[::-1][:10]
|
| 337 |
-
topic_keywords[topic_id] = [
|
| 338 |
-
(words[i], float(scores[i])) for i in top_idx
|
| 339 |
-
]
|
| 340 |
-
except Exception as e:
|
| 341 |
-
logger.warning("rebuild_topic_keywords failed for topic %d: %s", topic_id, e)
|
| 342 |
|
| 343 |
-
|
|
|
|
|
|
|
|
|
|
| 344 |
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
"""
|
| 352 |
-
Reassign outlier documents (topic == -1) to the nearest cluster centroid
|
| 353 |
-
if cosine similarity >= similarity_threshold AND cluster size < MAX_CLUSTER_SIZE.
|
| 354 |
-
Otherwise keep as -1.
|
| 355 |
-
"""
|
| 356 |
-
topics = list(topics)
|
| 357 |
-
MAX_CLUSTER_SIZE = 100 # Per instructor spec: max 100 papers per cluster
|
| 358 |
-
|
| 359 |
-
# Build centroid map and current sizes
|
| 360 |
-
cluster_docs: dict[int, list[int]] = {}
|
| 361 |
-
current_sizes: dict[int, int] = {}
|
| 362 |
-
for idx, tid in enumerate(topics):
|
| 363 |
-
if tid != -1:
|
| 364 |
-
cluster_docs.setdefault(tid, []).append(idx)
|
| 365 |
-
current_sizes[tid] = current_sizes.get(tid, 0) + 1
|
| 366 |
-
|
| 367 |
-
if not cluster_docs:
|
| 368 |
-
return topics
|
| 369 |
-
|
| 370 |
-
cluster_ids = list(cluster_docs.keys())
|
| 371 |
-
centroids = np.vstack([
|
| 372 |
-
embeddings[cluster_docs[tid]].mean(axis=0)
|
| 373 |
-
for tid in cluster_ids
|
| 374 |
-
]) # shape: (n_clusters, embed_dim)
|
| 375 |
-
|
| 376 |
-
outlier_indices = [idx for idx, tid in enumerate(topics) if tid == -1]
|
| 377 |
-
reassigned = 0
|
| 378 |
-
|
| 379 |
-
for idx in outlier_indices:
|
| 380 |
-
doc_emb = embeddings[idx].reshape(1, -1)
|
| 381 |
-
sims = cosine_similarity(doc_emb, centroids)[0] # (n_clusters,)
|
| 382 |
-
best_i = int(np.argmax(sims))
|
| 383 |
-
|
| 384 |
-
target_tid = cluster_ids[best_i]
|
| 385 |
-
if sims[best_i] >= similarity_threshold and current_sizes.get(target_tid, 0) < MAX_CLUSTER_SIZE:
|
| 386 |
-
topics[idx] = target_tid
|
| 387 |
-
current_sizes[target_tid] = current_sizes.get(target_tid, 0) + 1
|
| 388 |
-
reassigned += 1
|
| 389 |
-
|
| 390 |
-
logger.info(
|
| 391 |
-
"Outlier reassignment: %d / %d outliers reassigned (threshold=%.2f, max_size=%d).",
|
| 392 |
-
reassigned, len(outlier_indices), similarity_threshold, MAX_CLUSTER_SIZE
|
| 393 |
-
)
|
| 394 |
-
return topics
|
| 395 |
|
| 396 |
|
| 397 |
# ---------------------------------------------------------------------------
|
| 398 |
-
#
|
| 399 |
# ---------------------------------------------------------------------------
|
| 400 |
-
def
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
) -> dict:
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 453 |
|
| 454 |
|
| 455 |
# ---------------------------------------------------------------------------
|
| 456 |
-
#
|
| 457 |
# ---------------------------------------------------------------------------
|
| 458 |
-
def
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
) -> dict:
|
| 462 |
-
|
| 463 |
-
df = load_csv(filepath)
|
| 464 |
-
|
| 465 |
-
# Combined column
|
| 466 |
-
df["combined"] = df["title"].fillna("") + ". " + df["abstract"].fillna("")
|
| 467 |
-
clean_docs = preprocess_text(df["combined"])
|
| 468 |
-
|
| 469 |
-
# New embedding model
|
| 470 |
-
embedding_model = SentenceTransformer("allenai/specter2_base")
|
| 471 |
-
|
| 472 |
-
model = build_bertopic_model(embedding_model, min_topic_size=min_topic_size)
|
| 473 |
-
results = extract_topics(model, clean_docs, embedding_model)
|
| 474 |
-
|
| 475 |
-
return {
|
| 476 |
-
"documents": results
|
| 477 |
-
}
|
| 478 |
-
|
| 479 |
|
| 480 |
|
| 481 |
# ---------------------------------------------------------------------------
|
| 482 |
-
#
|
| 483 |
# ---------------------------------------------------------------------------
|
| 484 |
-
def
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 489 |
|
| 490 |
-
keywords: dict = data["topic_keywords"]
|
| 491 |
-
freq: dict = data["topic_freq"]
|
| 492 |
|
| 493 |
-
|
| 494 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 496 |
|
| 497 |
-
for topic_id, words in sorted(keywords.items()):
|
| 498 |
-
count = freq.get(topic_id, 0)
|
| 499 |
-
kw_str = ", ".join(w for w, _ in words[:top_n_keywords])
|
| 500 |
-
print(f"\n Topic {topic_id:>3} | docs: {count:>4}")
|
| 501 |
-
print(f" Keywords : {kw_str}")
|
| 502 |
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 506 |
|
| 507 |
|
| 508 |
# ---------------------------------------------------------------------------
|
| 509 |
-
#
|
| 510 |
# ---------------------------------------------------------------------------
|
| 511 |
-
|
| 512 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 513 |
|
| 514 |
-
if len(sys.argv) < 2:
|
| 515 |
-
print("Usage: python tools.py <path_to_csv> [min_topic_size]")
|
| 516 |
-
sys.exit(1)
|
| 517 |
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
tools.py
|
| 3 |
--------
|
| 4 |
+
Topic-modelling pipeline: SPECTER-2 → UMAP → HDBSCAN
|
| 5 |
+
with multi-objective Bayesian optimisation over UMAP + HDBSCAN
|
| 6 |
+
parameters (§3.1–§3.6 of the methodology guide).
|
| 7 |
+
|
| 8 |
+
No BERTopic wrapper — bare UMAP + HDBSCAN on SPECTER-2 embeddings.
|
| 9 |
"""
|
| 10 |
|
| 11 |
import re
|
| 12 |
import logging
|
| 13 |
import pandas as pd
|
| 14 |
+
import numpy as np
|
| 15 |
from typing import Optional
|
| 16 |
+
from collections import Counter, defaultdict
|
| 17 |
|
|
|
|
| 18 |
from sentence_transformers import SentenceTransformer
|
| 19 |
from umap import UMAP
|
| 20 |
+
from hdbscan import HDBSCAN
|
| 21 |
+
from keybert import KeyBERT
|
| 22 |
+
from sklearn.metrics import adjusted_rand_score
|
| 23 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 24 |
+
import optuna
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
# ---------------------------------------------------------------------------
|
| 27 |
# Logging
|
| 28 |
# ---------------------------------------------------------------------------
|
| 29 |
logging.basicConfig(level=logging.INFO, format="%(levelname)s | %(message)s")
|
| 30 |
logger = logging.getLogger(__name__)
|
| 31 |
+
optuna.logging.set_verbosity(optuna.logging.WARNING)
|
| 32 |
|
| 33 |
|
| 34 |
# ---------------------------------------------------------------------------
|
| 35 |
+
# Data Loading (unchanged)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
# ---------------------------------------------------------------------------
|
| 37 |
def load_csv(filepath: str) -> pd.DataFrame:
    """Load a CSV, lower-case its column names and check required columns.

    Args:
        filepath: Path handed straight to ``pandas.read_csv``.

    Returns:
        The loaded frame with lower-cased column labels.

    Raises:
        ValueError: If the "title" or "abstract" column is absent.
    """
    frame = pd.read_csv(filepath)
    frame.columns = frame.columns.str.lower()

    missing = {"title", "abstract"} - set(frame.columns)
    if not missing:
        logger.info("Loaded %d rows from '%s'.", len(frame), filepath)
        return frame

    raise ValueError(f"CSV missing column(s): {missing}")
|
| 46 |
|
| 47 |
|
| 48 |
# ---------------------------------------------------------------------------
|
| 49 |
+
# §3.1 — Input unit: title + abstract concatenation
|
| 50 |
# ---------------------------------------------------------------------------
|
| 51 |
+
def prepare_documents(df: pd.DataFrame) -> list[str]:
    """Build one "title. abstract" string per row (§3.1 input unit).

    Missing titles or abstracts are replaced by the empty string before the
    two columns are joined with ". ".
    """
    titles = df["title"].fillna("")
    abstracts = df["abstract"].fillna("")
    combined = titles + ". " + abstracts
    documents = combined.tolist()
    logger.info("Prepared %d title+abstract documents.", len(documents))
    return documents
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
|
| 58 |
# ---------------------------------------------------------------------------
|
| 59 |
+
# §3.1 — Embed with SPECTER-2
|
| 60 |
# ---------------------------------------------------------------------------
|
| 61 |
+
def embed_documents(
    docs: list[str],
    model_name: str = "allenai/specter2_base",
) -> np.ndarray:
    """Encode the documents with a SentenceTransformer model (§3.3: no tuning).

    Args:
        docs: Texts to embed.
        model_name: Model identifier passed to ``SentenceTransformer``;
            defaults to the SPECTER-2 base checkpoint.

    Returns:
        Whatever ``model.encode`` yields for ``docs`` (annotated as an
        ndarray; its ``.shape`` is logged).
    """
    encoder = SentenceTransformer(model_name)
    encode_options = {"show_progress_bar": True, "batch_size": 32}
    vectors = encoder.encode(docs, **encode_options)
    logger.info("Embedded %d docs → %s", len(docs), vectors.shape)
    return vectors
|
|
|
|
|
|
|
| 70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
+
# ---------------------------------------------------------------------------
|
| 73 |
+
# §3.2 — Cluster discipline checks
|
| 74 |
+
# ---------------------------------------------------------------------------
|
| 75 |
+
def check_discipline(
    labels: np.ndarray,
    n_docs: int,
    max_mass: float = 0.25,
    min_cluster_size: int = 5,
) -> dict:
    """Check the two hard cluster constraints: max-mass and min-size (§3.2).

    Args:
        labels: One cluster label per document; ``-1`` marks noise.
        n_docs: Total number of documents, used as the denominator of the
            mass fraction.
        max_mass: Largest allowed share of ``n_docs`` in a single cluster
            (default 0.25, i.e. 25 %).
        min_cluster_size: Smallest allowed cluster size (default 5).

    Returns:
        A dict with ``max_mass_pct`` (rounded to 4 places), ``max_mass_ok``,
        ``min_size``, ``min_size_ok``, ``n_clusters``, ``n_noise`` and
        ``cluster_sizes`` (label -> size, sorted by label). When no
        non-noise cluster exists both checks fail and ``cluster_sizes`` is
        empty, so the result has the same keys on every path.

    Raises:
        ZeroDivisionError: If ``n_docs`` is 0 while at least one cluster
            exists.
    """
    counts = Counter(int(label) for label in labels)
    n_noise = counts.get(-1, 0)
    sizes = {cid: counts[cid] for cid in sorted(counts) if cid != -1}

    if not sizes:
        # Nothing but noise: report failure with the full key set.
        return dict(
            max_mass_pct=0.0,
            max_mass_ok=False,
            min_size=0,
            min_size_ok=False,
            n_clusters=0,
            n_noise=n_noise,
            cluster_sizes={},
        )

    max_mass_pct = max(sizes.values()) / n_docs
    smallest = min(sizes.values())

    return dict(
        max_mass_pct=round(max_mass_pct, 4),
        max_mass_ok=max_mass_pct <= max_mass,
        min_size=int(smallest),
        min_size_ok=smallest >= min_cluster_size,
        n_clusters=len(sizes),
        n_noise=n_noise,
        cluster_sizes=sizes,
    )
|
|
|
|
|
|
|
| 97 |
|
| 98 |
|
| 99 |
# ---------------------------------------------------------------------------
|
| 100 |
+
# §3.4 — Quality metrics
|
| 101 |
# ---------------------------------------------------------------------------
|
| 102 |
+
def compute_persistence(clusterer: HDBSCAN) -> float:
    """Return the mean of the clusterer's ``cluster_persistence_`` values.

    This is a best-effort metric: when the attribute is missing, ``None`` or
    empty, or when reading/averaging it raises, ``0.0`` is returned instead
    of propagating the error. Failures are logged rather than silently
    discarded so that a systematically zero persistence can be diagnosed.

    Args:
        clusterer: Object expected to expose ``cluster_persistence_``
            (presumably a fitted HDBSCAN instance — confirm with callers).

    Returns:
        The average persistence as a float, or ``0.0`` when unavailable.
    """
    try:
        persistence = getattr(clusterer, "cluster_persistence_", None)
        if persistence is not None and len(persistence) > 0:
            return float(np.mean(persistence))
    except Exception as exc:  # deliberate best-effort: never abort the caller
        logger.warning("Cluster persistence unavailable: %s", exc)
    return 0.0
|
| 111 |
|
| 112 |
|
| 113 |
+
def compute_dbcv(reduced: np.ndarray, labels: np.ndarray) -> float:
    """Score a clustering with hdbscan's ``validity_index`` (DBCV).

    Args:
        reduced: Point coordinates; cast to float64 before scoring.
        labels: Cluster label per point, ``-1`` meaning noise.

    Returns:
        The validity index as a float, or ``-1.0`` when fewer than two
        non-noise clusters exist or when anything in the computation
        (including the hdbscan import) raises; the latter is logged.
    """
    try:
        from hdbscan.validity import validity_index

        real_clusters = set(labels) - {-1}
        if len(real_clusters) >= 2:
            points = reduced.astype(np.float64)
            return float(validity_index(points, labels))
        return -1.0
    except Exception as err:
        logger.warning("DBCV failed: %s", err)
        return -1.0
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def compute_stability(embeddings: np.ndarray, params: dict,
                      n_seeds: int = 5) -> float:
    """Cluster-recurrence stability via pairwise ARI across seeds (§3.4).

    The UMAP -> HDBSCAN pipeline is re-run ``n_seeds`` times, varying only
    the UMAP ``random_state`` (``seed_index * 7 + 1``). The adjusted Rand
    score is computed for every unordered pair of runs and averaged.

    Args:
        embeddings: Input vectors given to ``UMAP.fit_transform``.
        params: Must provide ``n_neighbors``, ``n_components``,
            ``min_cluster_size``, ``min_samples``, ``csm`` and ``cse``.
        n_seeds: Number of repeated runs.

    Returns:
        Mean pairwise ARI, or ``0.0`` when there are no pairs to compare.
    """
    def _labels_for(seed_index: int):
        # One full reduce-and-cluster pass for a single seed.
        reducer = UMAP(
            n_neighbors=params["n_neighbors"],
            n_components=params["n_components"],
            min_dist=0.0,
            metric="cosine",
            random_state=seed_index * 7 + 1,
        )
        low_dim = reducer.fit_transform(embeddings)
        clusterer = HDBSCAN(
            min_cluster_size=params["min_cluster_size"],
            min_samples=params["min_samples"],
            metric="euclidean",
            cluster_selection_method=params["csm"],
            cluster_selection_epsilon=params["cse"],
        )
        return clusterer.fit_predict(low_dim)

    runs = [_labels_for(seed_index) for seed_index in range(n_seeds)]

    scores = [
        adjusted_rand_score(first, second)
        for position, first in enumerate(runs)
        for second in runs[position + 1:]
    ]
    if not scores:
        return 0.0
    return float(np.mean(scores))
|
| 148 |
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
+
# ---------------------------------------------------------------------------
|
| 151 |
+
# §3.4 — Bayesian optimisation objective
|
| 152 |
+
# ---------------------------------------------------------------------------
|
| 153 |
+
def _objective(trial, embeddings, n_docs):
    """Run one Optuna trial of the UMAP + HDBSCAN search.

    Samples the six search parameters from ``trial``, reduces ``embeddings``
    with UMAP (fixed ``random_state=42``), clusters with HDBSCAN and checks
    the result with ``check_discipline``. The sampled parameters, the
    discipline report, the labels and a "pass" flag are stored as trial user
    attributes; persistence and DBCV are stored only for passing trials.

    Returns:
        ``(persistence, dbcv, stability_placeholder)``. A trial that fails
        either discipline check returns ``(0.0, -1.0, 0.0)``; a passing
        trial returns its persistence, its DBCV and the constant ``0.5``
        (real stability is computed only for the winning trial).
    """
    # Bounds of the min_cluster_size range scale with the corpus size.
    mcs_floor = max(5, int(0.01 * n_docs))
    mcs_ceiling = max(20, int(0.05 * n_docs))

    # The suggest_* calls keep their original order.
    params = {}
    params["n_neighbors"] = trial.suggest_categorical(
        "n_neighbors", [5, 10, 15, 30, 50])
    params["n_components"] = trial.suggest_int("n_components", 5, 10)
    params["min_cluster_size"] = trial.suggest_int(
        "min_cluster_size", mcs_floor, mcs_ceiling)
    params["min_samples"] = trial.suggest_int(
        "min_samples", 1, params["min_cluster_size"])
    params["csm"] = trial.suggest_categorical("csm", ["eom", "leaf"])
    params["cse"] = trial.suggest_float("cse", 0.0, 0.3, step=0.05)

    reducer = UMAP(
        n_neighbors=params["n_neighbors"],
        n_components=params["n_components"],
        min_dist=0.0,
        metric="cosine",
        random_state=42,
    )
    reduced = reducer.fit_transform(embeddings)

    clusterer = HDBSCAN(
        min_cluster_size=params["min_cluster_size"],
        min_samples=params["min_samples"],
        metric="euclidean",
        cluster_selection_method=params["csm"],
        cluster_selection_epsilon=params["cse"],
        allow_single_cluster=False,
        gen_min_span_tree=True,
    )
    labels = clusterer.fit_predict(reduced)

    discipline = check_discipline(labels, n_docs)
    trial.set_user_attr("params", params)
    trial.set_user_attr("discipline", discipline)
    trial.set_user_attr("labels", labels.tolist())

    passed = bool(discipline["max_mass_ok"] and discipline["min_size_ok"])
    trial.set_user_attr("pass", passed)
    if not passed:
        # Hard-constraint violation → worst scores
        return 0.0, -1.0, 0.0

    persistence = compute_persistence(clusterer)
    dbcv = compute_dbcv(reduced, labels)
    trial.set_user_attr("persistence", persistence)
    trial.set_user_attr("dbcv", dbcv)
    return persistence, dbcv, 0.5  # stability computed only for winner
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
|
| 196 |
|
| 197 |
# ---------------------------------------------------------------------------
|
| 198 |
+
# §3.4 — Run the full Bayesian loop
|
| 199 |
# ---------------------------------------------------------------------------
|
| 200 |
+
def run_bayesian_optimisation(
    embeddings: np.ndarray,
    n_trials: int = 50,
    progress_callback=None,
) -> dict:
    """Run the multi-objective Optuna search over UMAP + HDBSCAN parameters.

    Trials are executed one at a time so that an early-stop rule can be
    evaluated after each: once at least 20 trials have run, the search stops
    when the three most recent discipline-passing trials all have a
    persistence within 5 % of the best passing persistence (§3.6).

    The winner is the passing trial with the highest persistence, ties
    broken by DBCV. If no trial passes, the last trial is used and a warning
    is logged. Stability is then computed for the winner's parameters.

    Args:
        embeddings: Vectors to cluster; ``len(embeddings)`` is the document
            count handed to the objective.
        n_trials: Upper bound on the number of trials; must be >= 1.
        progress_callback: Optional callable invoked after every trial as
            ``progress_callback(trial_number + 1, n_trials, entry)``.

    Returns:
        A dict with ``best_params``, ``best_labels`` (ndarray),
        ``best_trial``, ``persistence``, ``dbcv``, ``stability``,
        ``discipline``, ``trial_log`` and ``n_trials_run``.

    Raises:
        ValueError: If ``n_trials`` is less than 1 (previously this surfaced
            as an IndexError when selecting the last trial).
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be at least 1, got {n_trials}")

    # Early-stop rule (§3.6).
    min_trials_before_stop = 20
    window = 3
    tolerance = 0.05

    n_docs = len(embeddings)
    study = optuna.create_study(
        directions=["maximize", "maximize", "maximize"],
        sampler=optuna.samplers.TPESampler(seed=42, multivariate=True),
        study_name="specter2_umap_hdbscan",
    )
    trial_log = []

    def _record(study, trial):
        # Flatten the trial's user attributes into one log entry.
        disc = trial.user_attrs.get("discipline", {})
        entry = dict(
            trial=trial.number,
            params=trial.user_attrs.get("params", {}),
            discipline_pass=trial.user_attrs.get("pass", False),
            persistence=trial.user_attrs.get("persistence", 0.0),
            dbcv=trial.user_attrs.get("dbcv", -1.0),
            n_clusters=disc.get("n_clusters", 0),
            max_mass_pct=disc.get("max_mass_pct", 0.0),
            min_size=disc.get("min_size", 0),
            n_noise=disc.get("n_noise", 0),
            values=list(trial.values) if trial.values else [],
        )
        trial_log.append(entry)
        if progress_callback:
            progress_callback(trial.number + 1, n_trials, entry)

    def _run_trial(trial):
        # Built once instead of re-creating a lambda on every iteration.
        return _objective(trial, embeddings, n_docs)

    def _converged() -> bool:
        # True when the latest passing trials sit close to the best one.
        passing = [e for e in trial_log if e["discipline_pass"]]
        if len(passing) < window:
            return False
        best_p = max(e["persistence"] for e in passing)
        if best_p <= 0:
            return False
        return all(
            abs(e["persistence"] - best_p) / best_p < tolerance
            for e in passing[-window:]
        )

    for i in range(n_trials):
        study.optimize(
            _run_trial,
            n_trials=1, callbacks=[_record], show_progress_bar=False,
        )
        if i + 1 >= min_trials_before_stop and _converged():
            logger.info("Converged at trial %d.", i + 1)
            break

    # Select best passing trial (max persistence, then DBCV)
    passing_trials = [t for t in study.trials
                      if t.user_attrs.get("pass", False)]
    if passing_trials:
        best = max(passing_trials, key=lambda t: (t.values[0], t.values[1]))
    else:
        logger.warning("No trial passed discipline — using last trial.")
        best = study.trials[-1]

    best_params = best.user_attrs["params"]
    labels = np.array(best.user_attrs["labels"])
    stability = compute_stability(embeddings, best_params, n_seeds=5)

    return dict(
        best_params=best_params, best_labels=labels,
        best_trial=best.number,
        persistence=best.user_attrs.get("persistence", 0.0),
        dbcv=best.user_attrs.get("dbcv", -1.0),
        stability=stability,
        discipline=best.user_attrs.get("discipline", {}),
        trial_log=trial_log,
        n_trials_run=len(trial_log),
    )
|
| 270 |
|
| 271 |
|
| 272 |
# ---------------------------------------------------------------------------
|
| 273 |
+
# §3.1 — 2-D UMAP for visualisation
|
| 274 |
# ---------------------------------------------------------------------------
|
| 275 |
+
def compute_2d_umap(embeddings: np.ndarray, seed: int = 42) -> np.ndarray:
    """Project document embeddings onto two dimensions for plotting.

    Args:
        embeddings: Embedding matrix handed to UMAP's ``fit_transform``.
        seed: Value passed to UMAP as ``random_state``.

    Returns:
        The 2-component UMAP projection of ``embeddings``.
    """
    projector = UMAP(
        n_neighbors=15,
        n_components=2,
        min_dist=0.1,
        metric="cosine",
        random_state=seed,
    )
    return projector.fit_transform(embeddings)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
|
| 279 |
|
| 280 |
# ---------------------------------------------------------------------------
|
| 281 |
+
# §3.1 — KeyBERT keyphrase extraction per cluster (3–5 phrases)
|
| 282 |
# ---------------------------------------------------------------------------
|
| 283 |
+
def extract_keyphrases(docs: list[str], labels: np.ndarray,
                       top_n: int = 5) -> dict:
    """Extract KeyBERT keyphrases for every non-noise cluster.

    All documents of a cluster are joined into one text and passed to
    KeyBERT with 1-3 word n-grams, English stop words and MMR
    diversification (diversity 0.5).

    Args:
        docs: Document texts, aligned position-by-position with ``labels``.
        labels: Cluster label per document; ``-1`` (noise) is ignored.
        top_n: Number of keyphrases requested per cluster.

    Returns:
        Mapping ``cluster_id -> extract_keywords(...) result``. A cluster
        whose extraction raises is logged and mapped to an empty list.
    """
    extractor = KeyBERT(model="all-MiniLM-L6-v2")

    # Bucket the texts by cluster, leaving noise documents out.
    texts_by_cluster = defaultdict(list)
    for text, cluster in zip(docs, labels):
        if cluster == -1:
            continue
        texts_by_cluster[int(cluster)].append(text)

    keyphrases = {}
    for cluster_id, texts in texts_by_cluster.items():
        try:
            phrases = extractor.extract_keywords(
                " ".join(texts),
                keyphrase_ngram_range=(1, 3),
                stop_words="english",
                top_n=top_n,
                use_mmr=True,
                diversity=0.5,
            )
        except Exception as err:
            # Best effort: one failing cluster must not abort the others.
            logger.warning("KeyBERT cluster %d: %s", cluster_id, err)
            phrases = []
        keyphrases[cluster_id] = phrases
    return keyphrases
|
| 301 |
|
|
|
|
|
|
|
| 302 |
|
| 303 |
+
# ---------------------------------------------------------------------------
|
| 304 |
+
# §3.1 — Strong / weak member counts via HDBSCAN probabilities
|
| 305 |
+
# ---------------------------------------------------------------------------
|
| 306 |
+
def strong_weak_members(labels: np.ndarray,
                        probabilities: np.ndarray) -> dict:
    """Count strong and weak members of each cluster.

    A document is "strong" when its membership probability is at least
    0.5 and "weak" otherwise. Documents labelled ``-1`` are not counted.

    Args:
        labels: Cluster label per document.
        probabilities: Membership strength per document, aligned with
            ``labels``.

    Returns:
        Mapping ``cluster_id -> {"strong": int, "weak": int}``.
    """
    tally: dict = {}
    for cluster_label, strength in zip(labels, probabilities):
        if cluster_label == -1:
            continue
        counts = tally.setdefault(int(cluster_label), {"strong": 0, "weak": 0})
        bucket = "strong" if strength >= 0.5 else "weak"
        counts[bucket] += 1
    return tally
|
| 318 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
|
| 320 |
+
# ---------------------------------------------------------------------------
|
| 321 |
+
# §3.2 — Outlier reduction: reassign noise to nearest cluster (≤ 25 %)
|
| 322 |
+
# ---------------------------------------------------------------------------
|
| 323 |
+
def outlier_reduction(labels: np.ndarray, reduced: np.ndarray,
                      n_docs: int) -> np.ndarray:
    """Reassign noise points to their nearest cluster, subject to a size cap.

    Each noise point (label ``-1``) is offered to clusters in order of
    increasing Euclidean distance to the cluster centroid in ``reduced``
    space. It joins the first cluster that currently holds fewer than
    ``int(0.25 * n_docs)`` documents; if every cluster is full it stays
    noise. Centroids are computed once from the original members and are
    not updated as points are moved.

    Args:
        labels: Cluster label per document; not modified.
        reduced: Per-document coordinates used for centroids and distances.
        n_docs: Corpus size from which the 25 % cap is derived.

    Returns:
        A copy of ``labels`` with the reassignments applied.
    """
    result = labels.copy()
    size_limit = int(0.25 * n_docs)

    # Single pass: collect member rows per cluster and the noise rows.
    members = defaultdict(list)
    noise_rows = []
    for row, lab in enumerate(result):
        if lab == -1:
            noise_rows.append(row)
        else:
            members[int(lab)].append(row)

    if not members:
        return result

    cluster_ids = list(members)
    centres = np.vstack(
        [reduced[members[cid]].mean(axis=0) for cid in cluster_ids]
    )
    sizes = {cid: len(rows) for cid, rows in members.items()}

    reassigned = 0
    for row in noise_rows:
        gaps = np.linalg.norm(centres - reduced[row], axis=1)
        # Nearest cluster that still has room, if any.
        target = next(
            (cluster_ids[pos] for pos in np.argsort(gaps)
             if sizes[cluster_ids[pos]] < size_limit),
            None,
        )
        if target is None:
            continue
        result[row] = target
        sizes[target] += 1
        reassigned += 1

    logger.info("Outlier reduction: %d / %d noise reassigned.",
                reassigned, len(noise_rows))
    return result
|
| 348 |
|
| 349 |
|
| 350 |
# ---------------------------------------------------------------------------
|
| 351 |
+
# Representative docs (top-3 by centroid proximity)
|
| 352 |
# ---------------------------------------------------------------------------
|
| 353 |
+
def get_representative_docs(labels, embeddings, docs, top_n=3):
    """Return the documents closest to each cluster's centroid.

    For every non-noise cluster the mean embedding of its members is
    compared (cosine similarity) against each member, and the ``top_n``
    most similar documents are returned, most similar first.

    Args:
        labels: Cluster label per document; ``-1`` (noise) is skipped.
        embeddings: Array whose row ``i`` is the embedding of ``docs[i]``;
            it is indexed with a list of row positions.
        docs: Documents aligned with ``labels``.
        top_n: Maximum number of representatives per cluster. Values
            ``<= 0`` yield an empty list for every cluster.

    Returns:
        Mapping ``cluster_id -> list of representative documents``.
    """
    members = defaultdict(list)
    for row, lab in enumerate(labels):
        if lab != -1:
            members[int(lab)].append(row)

    # ``[-0:]`` selects the whole array and a negative ``top_n`` selects an
    # arbitrary tail, so a non-positive request is answered explicitly.
    if top_n <= 0:
        return {cid: [] for cid in members}

    out = {}
    for cid, rows in members.items():
        vectors = embeddings[rows]
        centroid = vectors.mean(axis=0).reshape(1, -1)
        sims = cosine_similarity(centroid, vectors)[0]
        # argsort is ascending: take the tail, then flip to best-first.
        ranked = np.argsort(sims)[-top_n:][::-1]
        out[cid] = [docs[rows[pos]] for pos in ranked]
    return out
|
| 365 |
|
|
|
|
|
|
|
|
|
|
| 366 |
|
| 367 |
+
# ---------------------------------------------------------------------------
|
| 368 |
+
# High-level pipeline entry point
|
| 369 |
+
# ---------------------------------------------------------------------------
|
| 370 |
+
def run_topic_modeling(filepath: str, n_trials: int = 50,
                       progress_callback=None) -> dict:
    """Run the complete topic-modelling pipeline on a CSV file.

    Steps, in order: load and prepare documents, embed them, search
    UMAP/HDBSCAN hyper-parameters, refit the winning configuration,
    reassign noise points, then derive per-cluster membership counts,
    a 2-D projection, keyphrases, representative documents and the
    final discipline check.

    Args:
        filepath: Path of the CSV handed to ``load_csv``.
        n_trials: Trial budget forwarded to ``run_bayesian_optimisation``.
        progress_callback: Forwarded unchanged to
            ``run_bayesian_optimisation``.

    Returns:
        Dict with keys ``documents``, ``labels``, ``keyphrases``,
        ``representative_docs``, ``membership``, ``umap_2d``,
        ``discipline``, ``best_params``, ``metrics``, ``trial_log``,
        ``n_trials_run``, ``best_trial``, ``n_docs`` and ``embeddings``.
    """
    # Corpus -> documents -> embeddings.
    frame = load_csv(filepath)
    docs = prepare_documents(frame)
    n_docs = len(docs)
    embeddings = embed_documents(docs)

    # Hyper-parameter search; keep the winning parameters and labels.
    search = run_bayesian_optimisation(embeddings, n_trials, progress_callback)
    best_params = search["best_params"]
    labels = search["best_labels"]

    # Refit the winning configuration: the fitted clusterer supplies the
    # membership probabilities and the reduced space is needed below.
    reducer = UMAP(
        n_neighbors=best_params["n_neighbors"],
        n_components=best_params["n_components"],
        min_dist=0.0,
        metric="cosine",
        random_state=42,
    )
    reduced = reducer.fit_transform(embeddings)
    clusterer = HDBSCAN(
        min_cluster_size=best_params["min_cluster_size"],
        min_samples=best_params["min_samples"],
        metric="euclidean",
        cluster_selection_method=best_params["csm"],
        cluster_selection_epsilon=best_params["cse"],
        allow_single_cluster=False,
    )
    clusterer.fit(reduced)

    # Move noise points into their nearest cluster where the cap allows.
    labels = outlier_reduction(labels, reduced, n_docs)

    # Per-cluster artefacts, all based on the post-reduction labels.
    membership = strong_weak_members(labels, clusterer.probabilities_)
    projection = compute_2d_umap(embeddings)
    keyphrases = extract_keyphrases(docs, labels)
    representatives = get_representative_docs(labels, embeddings, docs)
    discipline = check_discipline(labels, n_docs)

    return {
        "documents": docs,
        "labels": labels.tolist(),
        "keyphrases": keyphrases,
        "representative_docs": representatives,
        "membership": membership,
        "umap_2d": projection.tolist(),
        "discipline": discipline,
        "best_params": best_params,
        "metrics": {
            "persistence": search["persistence"],
            "dbcv": search["dbcv"],
            "stability": search["stability"],
        },
        "trial_log": search["trial_log"],
        "n_trials_run": search["n_trials_run"],
        "best_trial": search["best_trial"],
        "n_docs": n_docs,
        "embeddings": embeddings,
    }
|