Update tools.py
Browse files
tools.py
CHANGED
|
@@ -46,10 +46,11 @@ import pandas as pd
|
|
| 46 |
import plotly.express as px
|
| 47 |
import plotly.graph_objects as go
|
| 48 |
import plotly.figure_factory as ff
|
| 49 |
-
from sklearn.cluster import AgglomerativeClustering
|
| 50 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 51 |
from sklearn.preprocessing import normalize
|
| 52 |
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
|
| 53 |
|
| 54 |
from langchain_core.tools import tool
|
| 55 |
from langchain_core.prompts import PromptTemplate
|
|
@@ -69,7 +70,9 @@ MISTRAL_API_KEY: str = os.environ.get("MISTRAL_API_KEY", "")
|
|
| 69 |
MODEL_NAME: str = "mistral-small-latest"
|
| 70 |
GROQ_API_KEY: str = os.environ.get("GROQ_API_KEY", "")
|
| 71 |
GROQ_MODEL_NAME: str = os.environ.get("GROQ_MODEL_NAME", "llama-3.3-70b-versatile")
|
| 72 |
-
|
|
|
|
|
|
|
| 73 |
BASE_DIR: Path = Path(__file__).resolve().parent
|
| 74 |
OUTPUT_DIR: Path = BASE_DIR / "outputs"
|
| 75 |
N_EVIDENCE: int = 5 # sentences kept per cluster centroid
|
|
@@ -78,10 +81,17 @@ RANDOM_SEED: int = 42
|
|
| 78 |
LLM_TIMEOUT_S: int = 45
|
| 79 |
LLM_MAX_RETRIES: int = 3
|
| 80 |
MAX_LABEL_CLUSTERS: int = 60
|
| 81 |
-
MIN_CLUSTER_SIZE_FOR_LABEL: int =
|
| 82 |
MAX_TOOL_RETURN_PREVIEW: int = 12
|
| 83 |
PROVIDER_RETRY_ATTEMPTS: int = 3
|
| 84 |
PROVIDER_RETRY_BASE_DELAY_S: float = 1.5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
# Run configurations — keys map to source columns
|
| 87 |
RUN_CONFIGS: dict[str, list[str]] = {
|
|
@@ -198,18 +208,33 @@ def _texts_for_candidates(df: pd.DataFrame, candidates: list[str]) -> tuple[list
|
|
| 198 |
|
| 199 |
|
| 200 |
def _embed(sentences: list[str]) -> np.ndarray:
|
| 201 |
-
"""Encode sentences to L2-normalised
|
| 202 |
-
model = SentenceTransformer(EMBED_MODEL)
|
| 203 |
raw = model.encode(sentences, show_progress_bar=False, batch_size=64)
|
| 204 |
return normalize(raw, norm="l2") # unit-norm -> cosine = dot product
|
| 205 |
|
| 206 |
|
| 207 |
-
def
|
| 208 |
-
|
|
|
|
|
|
|
|
|
|
| 209 |
metric="cosine",
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
).fit_predict(embeddings)
|
| 214 |
|
| 215 |
|
|
@@ -234,14 +259,14 @@ def _llm() -> ChatMistralAI:
|
|
| 234 |
)
|
| 235 |
|
| 236 |
|
| 237 |
-
def _llm_groq():
|
| 238 |
if ChatGroq is None:
|
| 239 |
raise RuntimeError(
|
| 240 |
"langchain-groq is not installed. Install dependencies from requirements.txt "
|
| 241 |
"to enable Groq topic-label verification."
|
| 242 |
)
|
| 243 |
return ChatGroq(
|
| 244 |
-
model=
|
| 245 |
api_key=GROQ_API_KEY,
|
| 246 |
temperature=0.2,
|
| 247 |
timeout=LLM_TIMEOUT_S,
|
|
@@ -249,8 +274,12 @@ def _llm_groq():
|
|
| 249 |
)
|
| 250 |
|
| 251 |
|
| 252 |
-
def
|
| 253 |
-
return bool(GROQ_API_KEY) and ChatGroq is not None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
|
| 255 |
|
| 256 |
def _to_float(value: object, fallback: float = 0.0) -> float:
|
|
@@ -345,7 +374,11 @@ def _chart_top_words(summaries: list[dict]) -> go.Figure:
|
|
| 345 |
|
| 346 |
|
| 347 |
def _chart_hierarchy(labels: list[int], embeddings: np.ndarray) -> go.Figure:
|
| 348 |
-
unique = sorted(set(labels))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
labels_arr = np.array(labels)
|
| 350 |
centroids = np.vstack([
|
| 351 |
_centroid(embeddings[labels_arr == lbl])
|
|
@@ -362,7 +395,11 @@ def _chart_hierarchy(labels: list[int], embeddings: np.ndarray) -> go.Figure:
|
|
| 362 |
|
| 363 |
|
| 364 |
def _chart_heatmap(labels: list[int], embeddings: np.ndarray) -> go.Figure:
|
| 365 |
-
unique = sorted(set(labels))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
labels_arr = np.array(labels)
|
| 367 |
centroids = np.vstack([
|
| 368 |
_centroid(embeddings[labels_arr == lbl])
|
|
@@ -460,21 +497,30 @@ def load_scopus_csv(filepath: str) -> dict:
|
|
| 460 |
# ============================================================================
|
| 461 |
|
| 462 |
@tool
|
| 463 |
-
def run_bertopic_discovery(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 464 |
"""
|
| 465 |
-
Embed sentences, cluster with
|
| 466 |
and generate four Plotly charts.
|
| 467 |
|
| 468 |
Saved artefacts
|
| 469 |
---------------
|
| 470 |
-
emb.npy : (N,
|
| 471 |
sent_labels.npy : (N,) int32 per-sentence cluster label [BUG 1 FIX]
|
| 472 |
summaries.json : list of cluster dicts with evidence sentences
|
| 473 |
|
| 474 |
Parameters
|
| 475 |
----------
|
| 476 |
run_key : str — "abstract" or "title" or "keywords"
|
| 477 |
-
threshold : float —
|
|
|
|
|
|
|
|
|
|
| 478 |
|
| 479 |
Returns
|
| 480 |
-------
|
|
@@ -527,8 +573,16 @@ def run_bertopic_discovery(run_key: str, threshold: float = DISTANCE_THRESH) ->
|
|
| 527 |
embeddings = _embed(sentences)
|
| 528 |
np.save(str(rdir / "emb.npy"), embeddings)
|
| 529 |
|
| 530 |
-
|
| 531 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 532 |
|
| 533 |
# FIX BUG 1 — persist per-sentence label array so Tool 4 can build
|
| 534 |
# correct cluster masks without any guesswork or scaffolding.
|
|
@@ -536,17 +590,39 @@ def run_bertopic_discovery(run_key: str, threshold: float = DISTANCE_THRESH) ->
|
|
| 536 |
|
| 537 |
labels_arr = np.array(labels)
|
| 538 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 539 |
def _cluster_summary(cid: int) -> dict:
|
| 540 |
mask = labels_arr == cid
|
| 541 |
c_emb = embeddings[mask]
|
|
|
|
| 542 |
c_sent = list(np.array(sentences)[mask])
|
| 543 |
ctroid = _centroid(c_emb)
|
| 544 |
top_idx = _top_k_indices(c_emb, ctroid, N_EVIDENCE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 545 |
return {
|
| 546 |
"cluster_id": int(cid),
|
| 547 |
"size": int(mask.sum()),
|
| 548 |
-
"cx": float(
|
| 549 |
-
"cy": float(
|
| 550 |
"evidence": list(np.array(c_sent)[top_idx]),
|
| 551 |
}
|
| 552 |
|
|
@@ -565,6 +641,9 @@ def run_bertopic_discovery(run_key: str, threshold: float = DISTANCE_THRESH) ->
|
|
| 565 |
"n_clusters": int(len(unique_ids)),
|
| 566 |
"n_sentences": int(len(sentences)),
|
| 567 |
"threshold": threshold,
|
|
|
|
|
|
|
|
|
|
| 568 |
"chart_paths": chart_paths,
|
| 569 |
"summaries_path": str(rdir / "summaries.json"),
|
| 570 |
"embeddings_path": str(rdir / "emb.npy"),
|
|
@@ -663,8 +742,22 @@ def label_topics_with_llm(run_key: str) -> dict:
|
|
| 663 |
"groq_confidence": 0.0,
|
| 664 |
"groq_reasoning": "",
|
| 665 |
"groq_niche": False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 666 |
"verification_done": False,
|
| 667 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 668 |
}
|
| 669 |
|
| 670 |
labelled = list(map(_label_one, selected))
|
|
@@ -679,6 +772,8 @@ def label_topics_with_llm(run_key: str) -> dict:
|
|
| 679 |
"confidence": r.get("confidence"),
|
| 680 |
"mistral_label": r.get("mistral_label", ""),
|
| 681 |
"groq_label": r.get("groq_label", ""),
|
|
|
|
|
|
|
| 682 |
"size": r.get("size"),
|
| 683 |
"niche": r.get("niche", False),
|
| 684 |
},
|
|
@@ -692,8 +787,8 @@ def label_topics_with_llm(run_key: str) -> dict:
|
|
| 692 |
"total_clusters": len(summaries),
|
| 693 |
"selected_clusters": len(selected),
|
| 694 |
"skipped_clusters": max(0, len(summaries) - len(selected)),
|
| 695 |
-
"groq_enabled":
|
| 696 |
-
"mode_note": "Single-model labeling complete (Mistral). Send VERIFY in Phase 2 to run Groq verification.",
|
| 697 |
"labels_preview": preview,
|
| 698 |
}
|
| 699 |
|
|
@@ -702,7 +797,7 @@ def label_topics_with_llm(run_key: str) -> dict:
|
|
| 702 |
def verify_topic_labels_with_groq(run_key: str) -> dict:
|
| 703 |
"""
|
| 704 |
Run Groq topic labeling for already-labeled topics and append comparison fields
|
| 705 |
-
into labels.json so UI review table can show
|
| 706 |
|
| 707 |
Parameters
|
| 708 |
----------
|
|
@@ -717,15 +812,16 @@ def verify_topic_labels_with_groq(run_key: str) -> dict:
|
|
| 717 |
labels_path = rdir / "labels.json"
|
| 718 |
summaries_path = rdir / "summaries.json"
|
| 719 |
|
| 720 |
-
if not
|
| 721 |
return {
|
| 722 |
"run_key": run_key,
|
| 723 |
"labels_path": str(labels_path),
|
| 724 |
"verified_count": 0,
|
| 725 |
"labels_preview": [],
|
| 726 |
"error": (
|
| 727 |
-
"GROQ_API_KEY is missing or langchain-groq is unavailable. "
|
| 728 |
-
"Set GROQ_API_KEY and
|
|
|
|
| 729 |
),
|
| 730 |
}
|
| 731 |
|
|
@@ -765,7 +861,8 @@ def verify_topic_labels_with_groq(run_key: str) -> dict:
|
|
| 765 |
labels_data,
|
| 766 |
))
|
| 767 |
|
| 768 |
-
|
|
|
|
| 769 |
|
| 770 |
def _evidence_block(summary: dict) -> str:
|
| 771 |
return "\n".join(
|
|
@@ -773,29 +870,35 @@ def verify_topic_labels_with_groq(run_key: str) -> dict:
|
|
| 773 |
for i, s in enumerate(summary.get("evidence", []))
|
| 774 |
)
|
| 775 |
|
| 776 |
-
def _label_with_groq(row: dict) -> tuple[int, dict]:
|
| 777 |
cid = int(row.get("cluster_id", -1))
|
| 778 |
summary = summary_by_id[cid]
|
| 779 |
-
|
| 780 |
"cluster_id": summary["cluster_id"],
|
| 781 |
"size": summary["size"],
|
| 782 |
"evidence": _evidence_block(summary),
|
| 783 |
-
}
|
| 784 |
-
|
|
|
|
|
|
|
| 785 |
|
| 786 |
groq_pairs = list(map(_label_with_groq, target_rows))
|
| 787 |
-
|
|
|
|
| 788 |
|
| 789 |
def _merge_row(row: dict) -> dict:
|
| 790 |
cid = int(row.get("cluster_id", -1))
|
| 791 |
-
|
| 792 |
-
|
|
|
|
|
|
|
| 793 |
mistral_label = str(row.get("mistral_label") or row.get("label", "")).strip()
|
| 794 |
-
|
|
|
|
| 795 |
is_agreement = (
|
| 796 |
-
|
| 797 |
-
|
| 798 |
-
|
| 799 |
)
|
| 800 |
|
| 801 |
return {
|
|
@@ -808,18 +911,30 @@ def verify_topic_labels_with_groq(run_key: str) -> dict:
|
|
| 808 |
),
|
| 809 |
"mistral_reasoning": row.get("mistral_reasoning") or row.get("reasoning", ""),
|
| 810 |
"mistral_niche": bool(row.get("mistral_niche", row.get("niche", False))),
|
| 811 |
-
"groq_label":
|
| 812 |
-
"groq_category":
|
| 813 |
-
"groq_confidence": _to_float(
|
| 814 |
-
"groq_reasoning":
|
| 815 |
-
"groq_niche": bool(
|
| 816 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 817 |
"verification_note": (
|
| 818 |
-
"Mistral and Groq labels match."
|
| 819 |
if is_agreement
|
| 820 |
-
else "
|
| 821 |
)
|
| 822 |
-
if
|
| 823 |
else "Groq labeling unavailable for this topic.",
|
| 824 |
}
|
| 825 |
|
|
@@ -832,13 +947,18 @@ def verify_topic_labels_with_groq(run_key: str) -> dict:
|
|
| 832 |
lambda r: {
|
| 833 |
"cluster_id": r.get("cluster_id"),
|
| 834 |
"mistral_label": r.get("mistral_label", ""),
|
| 835 |
-
"
|
|
|
|
| 836 |
"verification_note": r.get("verification_note", ""),
|
| 837 |
},
|
| 838 |
verified_rows[:MAX_TOOL_RETURN_PREVIEW],
|
| 839 |
))
|
| 840 |
|
| 841 |
-
verified_count = sum(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 842 |
|
| 843 |
return {
|
| 844 |
"run_key": run_key,
|
|
@@ -1044,7 +1164,7 @@ def verify_taxonomy_mapping_with_groq(run_key: str) -> dict:
|
|
| 1044 |
run_key, taxonomy_path, verification_path,
|
| 1045 |
verified_count, mapping_preview
|
| 1046 |
"""
|
| 1047 |
-
if not
|
| 1048 |
return {
|
| 1049 |
"run_key": run_key,
|
| 1050 |
"taxonomy_path": str(_run_dir(run_key) / "taxonomy_map.json"),
|
|
@@ -1088,7 +1208,7 @@ def verify_taxonomy_mapping_with_groq(run_key: str) -> dict:
|
|
| 1088 |
taxonomy_map = _load_json(taxonomy_path)
|
| 1089 |
taxonomy_str = "\n".join(f" - {cat}" for cat in PAJAIS_TAXONOMY)
|
| 1090 |
|
| 1091 |
-
chain_groq = _TAXONOMY_PROMPT | _llm_groq() | JsonOutputParser()
|
| 1092 |
|
| 1093 |
def _map_theme_with_groq(theme: dict) -> dict:
|
| 1094 |
return _invoke_with_retries(lambda: chain_groq.invoke({
|
|
|
|
| 46 |
import plotly.express as px
|
| 47 |
import plotly.graph_objects as go
|
| 48 |
import plotly.figure_factory as ff
|
|
|
|
| 49 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 50 |
from sklearn.preprocessing import normalize
|
| 51 |
from sentence_transformers import SentenceTransformer
|
| 52 |
+
import hdbscan
|
| 53 |
+
import umap
|
| 54 |
|
| 55 |
from langchain_core.tools import tool
|
| 56 |
from langchain_core.prompts import PromptTemplate
|
|
|
|
| 70 |
MODEL_NAME: str = "mistral-small-latest"
|
| 71 |
GROQ_API_KEY: str = os.environ.get("GROQ_API_KEY", "")
|
| 72 |
GROQ_MODEL_NAME: str = os.environ.get("GROQ_MODEL_NAME", "llama-3.3-70b-versatile")
|
| 73 |
+
GROQ_OLLAMA_MODEL_NAME: str = os.environ.get("GROQ_OLLAMA_MODEL_NAME", "llama-3.3-70b-versatile")
|
| 74 |
+
GROQ_GPT_MODEL_NAME: str = os.environ.get("GROQ_GPT_MODEL_NAME", "openai/gpt-oss-120b")
|
| 75 |
+
EMBED_MODEL: str = "allenai/specter2_base"
|
| 76 |
BASE_DIR: Path = Path(__file__).resolve().parent
|
| 77 |
OUTPUT_DIR: Path = BASE_DIR / "outputs"
|
| 78 |
N_EVIDENCE: int = 5 # sentences kept per cluster centroid
|
|
|
|
| 81 |
LLM_TIMEOUT_S: int = 45
|
| 82 |
LLM_MAX_RETRIES: int = 3
|
| 83 |
MAX_LABEL_CLUSTERS: int = 60
|
| 84 |
+
MIN_CLUSTER_SIZE_FOR_LABEL: int = 20
|
| 85 |
MAX_TOOL_RETURN_PREVIEW: int = 12
|
| 86 |
PROVIDER_RETRY_ATTEMPTS: int = 3
|
| 87 |
PROVIDER_RETRY_BASE_DELAY_S: float = 1.5
|
| 88 |
+
HDBSCAN_MIN_CLUSTER_SIZE: int = 20
|
| 89 |
+
HDBSCAN_MIN_SAMPLES: int = 5
|
| 90 |
+
HDBSCAN_MAX_CLUSTER_SIZE: int = 120
|
| 91 |
+
UMAP_N_NEIGHBORS: int = 15
|
| 92 |
+
UMAP_MIN_DIST: float = 0.0
|
| 93 |
+
UMAP_N_COMPONENTS_CLUSTER: int = 5
|
| 94 |
+
UMAP_N_COMPONENTS_VIZ: int = 2
|
| 95 |
|
| 96 |
# Run configurations — keys map to source columns
|
| 97 |
RUN_CONFIGS: dict[str, list[str]] = {
|
|
|
|
| 208 |
|
| 209 |
|
| 210 |
def _embed(sentences: list[str]) -> np.ndarray:
    """Return one unit-length embedding row per entry of ``sentences``.

    A ``SentenceTransformer`` is constructed from ``EMBED_MODEL`` (with
    ``trust_remote_code=True``) on every call. The sentences are encoded in
    batches of 64 with the progress bar disabled, and each resulting row is
    then rescaled to unit L2 norm.
    """
    encoder = SentenceTransformer(EMBED_MODEL, trust_remote_code=True)
    vectors = encoder.encode(sentences, batch_size=64, show_progress_bar=False)
    # With unit-norm rows, cosine similarity reduces to a plain dot product.
    return normalize(vectors, norm="l2")
|
| 215 |
|
| 216 |
|
| 217 |
+
def _umap_reduce(embeddings: np.ndarray, n_components: int) -> np.ndarray:
    """Project ``embeddings`` down to ``n_components`` dimensions with UMAP.

    Neighbourhood size and minimum distance are taken from the module-level
    ``UMAP_N_NEIGHBORS`` and ``UMAP_MIN_DIST`` settings, distances use the
    cosine metric, and ``RANDOM_SEED`` is passed as ``random_state``.

    Returns the result of ``fit_transform`` on the input rows.
    """
    umap_settings = {
        "n_neighbors": UMAP_N_NEIGHBORS,
        "min_dist": UMAP_MIN_DIST,
        "n_components": n_components,
        "metric": "cosine",
        "random_state": RANDOM_SEED,
    }
    projector = umap.UMAP(**umap_settings)
    return projector.fit_transform(embeddings)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def _cluster(embeddings: np.ndarray,
             min_cluster_size: int,
             max_cluster_size: int,
             min_samples: int) -> np.ndarray:
    """Assign a cluster label to every row of ``embeddings`` with HDBSCAN.

    The clusterer uses the Euclidean metric and the ``"eom"`` cluster
    selection method; the three size/sample arguments are forwarded
    unchanged. Callers in this file treat the label ``-1`` as noise.

    Returns the array produced by ``fit_predict``.
    """
    clusterer = hdbscan.HDBSCAN(
        metric="euclidean",
        cluster_selection_method="eom",
        min_samples=min_samples,
        min_cluster_size=min_cluster_size,
        max_cluster_size=max_cluster_size,
    )
    return clusterer.fit_predict(embeddings)
|
| 239 |
|
| 240 |
|
|
|
|
| 259 |
)
|
| 260 |
|
| 261 |
|
| 262 |
+
def _llm_groq(model_name: str):
|
| 263 |
if ChatGroq is None:
|
| 264 |
raise RuntimeError(
|
| 265 |
"langchain-groq is not installed. Install dependencies from requirements.txt "
|
| 266 |
"to enable Groq topic-label verification."
|
| 267 |
)
|
| 268 |
return ChatGroq(
|
| 269 |
+
model=model_name,
|
| 270 |
api_key=GROQ_API_KEY,
|
| 271 |
temperature=0.2,
|
| 272 |
timeout=LLM_TIMEOUT_S,
|
|
|
|
| 274 |
)
|
| 275 |
|
| 276 |
|
| 277 |
+
def _groq_ollama_enabled() -> bool:
    """Report whether the Groq-Ollama model can be used.

    True only when ``GROQ_API_KEY`` is non-empty, ``ChatGroq`` is not
    ``None``, and ``GROQ_OLLAMA_MODEL_NAME`` is non-empty.
    """
    if not GROQ_API_KEY or ChatGroq is None:
        return False
    return bool(GROQ_OLLAMA_MODEL_NAME)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def _groq_gpt_enabled() -> bool:
    """Report whether the Groq-GPT model can be used.

    True only when ``GROQ_API_KEY`` is non-empty, ``ChatGroq`` is not
    ``None``, and ``GROQ_GPT_MODEL_NAME`` is non-empty.
    """
    if not GROQ_API_KEY or ChatGroq is None:
        return False
    return bool(GROQ_GPT_MODEL_NAME)
|
| 283 |
|
| 284 |
|
| 285 |
def _to_float(value: object, fallback: float = 0.0) -> float:
|
|
|
|
| 374 |
|
| 375 |
|
| 376 |
def _chart_hierarchy(labels: list[int], embeddings: np.ndarray) -> go.Figure:
|
| 377 |
+
unique = sorted(filter(lambda v: v != -1, set(labels)))
|
| 378 |
+
if not unique:
|
| 379 |
+
fig = go.Figure()
|
| 380 |
+
fig.update_layout(title="Cluster Hierarchy", template="plotly_dark")
|
| 381 |
+
return fig
|
| 382 |
labels_arr = np.array(labels)
|
| 383 |
centroids = np.vstack([
|
| 384 |
_centroid(embeddings[labels_arr == lbl])
|
|
|
|
| 395 |
|
| 396 |
|
| 397 |
def _chart_heatmap(labels: list[int], embeddings: np.ndarray) -> go.Figure:
|
| 398 |
+
unique = sorted(filter(lambda v: v != -1, set(labels)))
|
| 399 |
+
if not unique:
|
| 400 |
+
fig = go.Figure()
|
| 401 |
+
fig.update_layout(title="Cluster Similarity Heatmap", template="plotly_dark")
|
| 402 |
+
return fig
|
| 403 |
labels_arr = np.array(labels)
|
| 404 |
centroids = np.vstack([
|
| 405 |
_centroid(embeddings[labels_arr == lbl])
|
|
|
|
| 497 |
# ============================================================================
|
| 498 |
|
| 499 |
@tool
|
| 500 |
+
def run_bertopic_discovery(
|
| 501 |
+
run_key: str,
|
| 502 |
+
threshold: float = DISTANCE_THRESH,
|
| 503 |
+
min_cluster_size: int = HDBSCAN_MIN_CLUSTER_SIZE,
|
| 504 |
+
max_cluster_size: int = HDBSCAN_MAX_CLUSTER_SIZE,
|
| 505 |
+
min_samples: int = HDBSCAN_MIN_SAMPLES,
|
| 506 |
+
) -> dict:
|
| 507 |
"""
|
| 508 |
+
Embed sentences, cluster with UMAP + HDBSCAN, extract evidence,
|
| 509 |
and generate four Plotly charts.
|
| 510 |
|
| 511 |
Saved artefacts
|
| 512 |
---------------
|
| 513 |
+
emb.npy : (N, D) float32 L2-normalised embeddings
|
| 514 |
sent_labels.npy : (N,) int32 per-sentence cluster label [BUG 1 FIX]
|
| 515 |
summaries.json : list of cluster dicts with evidence sentences
|
| 516 |
|
| 517 |
Parameters
|
| 518 |
----------
|
| 519 |
run_key : str — "abstract" or "title" or "keywords"
|
| 520 |
+
threshold : float — legacy arg (ignored by HDBSCAN)
|
| 521 |
+
min_cluster_size : int — HDBSCAN minimum cluster size
|
| 522 |
+
max_cluster_size : int — HDBSCAN maximum cluster size
|
| 523 |
+
min_samples : int — HDBSCAN min_samples
|
| 524 |
|
| 525 |
Returns
|
| 526 |
-------
|
|
|
|
| 573 |
embeddings = _embed(sentences)
|
| 574 |
np.save(str(rdir / "emb.npy"), embeddings)
|
| 575 |
|
| 576 |
+
cluster_space = _umap_reduce(embeddings, UMAP_N_COMPONENTS_CLUSTER)
|
| 577 |
+
umap_2d = _umap_reduce(embeddings, UMAP_N_COMPONENTS_VIZ)
|
| 578 |
+
|
| 579 |
+
labels = _cluster(
|
| 580 |
+
cluster_space,
|
| 581 |
+
min_cluster_size=min_cluster_size,
|
| 582 |
+
max_cluster_size=max_cluster_size,
|
| 583 |
+
min_samples=min_samples,
|
| 584 |
+
).tolist()
|
| 585 |
+
unique_ids = sorted(filter(lambda v: v != -1, set(labels)))
|
| 586 |
|
| 587 |
# FIX BUG 1 — persist per-sentence label array so Tool 4 can build
|
| 588 |
# correct cluster masks without any guesswork or scaffolding.
|
|
|
|
| 590 |
|
| 591 |
labels_arr = np.array(labels)
|
| 592 |
|
| 593 |
+
if not unique_ids:
|
| 594 |
+
_save_json(rdir / "summaries.json", [])
|
| 595 |
+
return {
|
| 596 |
+
"run_key": run_key,
|
| 597 |
+
"n_clusters": 0,
|
| 598 |
+
"n_sentences": int(len(sentences)),
|
| 599 |
+
"threshold": threshold,
|
| 600 |
+
"min_cluster_size": int(min_cluster_size),
|
| 601 |
+
"max_cluster_size": int(max_cluster_size),
|
| 602 |
+
"min_samples": int(min_samples),
|
| 603 |
+
"chart_paths": {},
|
| 604 |
+
"summaries_path": str(rdir / "summaries.json"),
|
| 605 |
+
"embeddings_path": str(rdir / "emb.npy"),
|
| 606 |
+
"error": "HDBSCAN produced no clusters (all points labeled as noise).",
|
| 607 |
+
}
|
| 608 |
+
|
| 609 |
def _cluster_summary(cid: int) -> dict:
    """Build the summary record for cluster ``cid``.

    Selects the rows whose label equals ``cid`` from the enclosing scope's
    embeddings, 2-D UMAP coordinates and sentences, computes the cluster
    centroid, and keeps the ``N_EVIDENCE`` sentences chosen by
    ``_top_k_indices``.

    Returns a dict with ``cluster_id``, ``size``, the mean 2-D position
    (``cx``, ``cy``) and the selected ``evidence`` sentences.
    """
    member_mask = labels_arr == cid
    member_emb = embeddings[member_mask]
    member_xy = umap_2d[member_mask]
    member_sentences = list(np.array(sentences)[member_mask])
    center = _centroid(member_emb)
    evidence_idx = _top_k_indices(member_emb, center, N_EVIDENCE)

    # Mean 2-D UMAP position; fall back to the origin for an empty selection.
    if member_xy.shape[0] > 0:
        xy = member_xy.mean(axis=0)
    else:
        xy = np.zeros(UMAP_N_COMPONENTS_VIZ, dtype=np.float32)

    evidence = list(np.array(member_sentences)[evidence_idx])
    return {
        "cluster_id": int(cid),
        "size": int(member_mask.sum()),
        "cx": float(xy[0]),
        "cy": float(xy[1]),
        "evidence": evidence,
    }
|
| 628 |
|
|
|
|
| 641 |
"n_clusters": int(len(unique_ids)),
|
| 642 |
"n_sentences": int(len(sentences)),
|
| 643 |
"threshold": threshold,
|
| 644 |
+
"min_cluster_size": int(min_cluster_size),
|
| 645 |
+
"max_cluster_size": int(max_cluster_size),
|
| 646 |
+
"min_samples": int(min_samples),
|
| 647 |
"chart_paths": chart_paths,
|
| 648 |
"summaries_path": str(rdir / "summaries.json"),
|
| 649 |
"embeddings_path": str(rdir / "emb.npy"),
|
|
|
|
| 742 |
"groq_confidence": 0.0,
|
| 743 |
"groq_reasoning": "",
|
| 744 |
"groq_niche": False,
|
| 745 |
+
"groq_ollama_label": "",
|
| 746 |
+
"groq_ollama_category": "",
|
| 747 |
+
"groq_ollama_confidence": 0.0,
|
| 748 |
+
"groq_ollama_reasoning": "",
|
| 749 |
+
"groq_ollama_niche": False,
|
| 750 |
+
"groq_gpt_label": "",
|
| 751 |
+
"groq_gpt_category": "",
|
| 752 |
+
"groq_gpt_confidence": 0.0,
|
| 753 |
+
"groq_gpt_reasoning": "",
|
| 754 |
+
"groq_gpt_niche": False,
|
| 755 |
"verification_done": False,
|
| 756 |
+
"verification_done_ollama": False,
|
| 757 |
+
"verification_done_gpt": False,
|
| 758 |
+
"verification_note": (
|
| 759 |
+
"Run VERIFY in Phase 2 to compare with Groq-Ollama and Groq-GPT labels."
|
| 760 |
+
),
|
| 761 |
}
|
| 762 |
|
| 763 |
labelled = list(map(_label_one, selected))
|
|
|
|
| 772 |
"confidence": r.get("confidence"),
|
| 773 |
"mistral_label": r.get("mistral_label", ""),
|
| 774 |
"groq_label": r.get("groq_label", ""),
|
| 775 |
+
"groq_ollama_label": r.get("groq_ollama_label", r.get("groq_label", "")),
|
| 776 |
+
"groq_gpt_label": r.get("groq_gpt_label", ""),
|
| 777 |
"size": r.get("size"),
|
| 778 |
"niche": r.get("niche", False),
|
| 779 |
},
|
|
|
|
| 787 |
"total_clusters": len(summaries),
|
| 788 |
"selected_clusters": len(selected),
|
| 789 |
"skipped_clusters": max(0, len(summaries) - len(selected)),
|
| 790 |
+
"groq_enabled": _groq_ollama_enabled() and _groq_gpt_enabled(),
|
| 791 |
+
"mode_note": "Single-model labeling complete (Mistral). Send VERIFY in Phase 2 to run Groq-Ollama and Groq-GPT verification.",
|
| 792 |
"labels_preview": preview,
|
| 793 |
}
|
| 794 |
|
|
|
|
| 797 |
def verify_topic_labels_with_groq(run_key: str) -> dict:
|
| 798 |
"""
|
| 799 |
Run Groq topic labeling for already-labeled topics and append comparison fields
|
| 800 |
+
into labels.json so UI review table can show Mistral vs Groq-Ollama vs Groq-GPT labels.
|
| 801 |
|
| 802 |
Parameters
|
| 803 |
----------
|
|
|
|
| 812 |
labels_path = rdir / "labels.json"
|
| 813 |
summaries_path = rdir / "summaries.json"
|
| 814 |
|
| 815 |
+
if not _groq_ollama_enabled() or not _groq_gpt_enabled():
|
| 816 |
return {
|
| 817 |
"run_key": run_key,
|
| 818 |
"labels_path": str(labels_path),
|
| 819 |
"verified_count": 0,
|
| 820 |
"labels_preview": [],
|
| 821 |
"error": (
|
| 822 |
+
"GROQ_API_KEY or Groq model config is missing, or langchain-groq is unavailable. "
|
| 823 |
+
"Set GROQ_API_KEY and GROQ_GPT_MODEL_NAME (and optionally GROQ_OLLAMA_MODEL_NAME) "
|
| 824 |
+
"and install requirements to use VERIFY."
|
| 825 |
),
|
| 826 |
}
|
| 827 |
|
|
|
|
| 861 |
labels_data,
|
| 862 |
))
|
| 863 |
|
| 864 |
+
chain_groq_ollama = _LABEL_PROMPT | _llm_groq(GROQ_OLLAMA_MODEL_NAME) | JsonOutputParser()
|
| 865 |
+
chain_groq_gpt = _LABEL_PROMPT | _llm_groq(GROQ_GPT_MODEL_NAME) | JsonOutputParser()
|
| 866 |
|
| 867 |
def _evidence_block(summary: dict) -> str:
|
| 868 |
return "\n".join(
|
|
|
|
| 870 |
for i, s in enumerate(summary.get("evidence", []))
|
| 871 |
)
|
| 872 |
|
| 873 |
+
def _label_with_groq(row: dict) -> tuple[int, dict, dict]:
    """Ask both Groq chains to label the cluster referenced by ``row``.

    The cluster id is read from ``row`` (defaulting to ``-1``) and its
    summary is fetched from ``summary_by_id`` by plain indexing, so a
    missing id is not handled here. The same payload is sent first to the
    Groq-Ollama chain and then to the Groq-GPT chain, each through
    ``_invoke_with_retries``.

    Returns ``(cluster_id, groq_ollama_result, groq_gpt_result)``.
    """
    cluster_id = int(row.get("cluster_id", -1))
    cluster_summary = summary_by_id[cluster_id]
    request = dict(
        cluster_id=cluster_summary["cluster_id"],
        size=cluster_summary["size"],
        evidence=_evidence_block(cluster_summary),
    )

    def _ask(chain):
        # Wrap the call in a thunk so the retry helper can re-run it.
        return _invoke_with_retries(lambda: chain.invoke(request))

    ollama_result = _ask(chain_groq_ollama)
    gpt_result = _ask(chain_groq_gpt)
    return cluster_id, ollama_result, gpt_result
|
| 884 |
|
| 885 |
groq_pairs = list(map(_label_with_groq, target_rows))
|
| 886 |
+
groq_ollama_by_id = {cid: data for cid, data, _ in groq_pairs}
|
| 887 |
+
groq_gpt_by_id = {cid: data for cid, _, data in groq_pairs}
|
| 888 |
|
| 889 |
def _merge_row(row: dict) -> dict:
|
| 890 |
cid = int(row.get("cluster_id", -1))
|
| 891 |
+
groq_ollama = groq_ollama_by_id.get(cid, {})
|
| 892 |
+
groq_gpt = groq_gpt_by_id.get(cid, {})
|
| 893 |
+
has_groq_ollama = bool(groq_ollama)
|
| 894 |
+
has_groq_gpt = bool(groq_gpt)
|
| 895 |
mistral_label = str(row.get("mistral_label") or row.get("label", "")).strip()
|
| 896 |
+
groq_ollama_label = str(groq_ollama.get("label", "")).strip()
|
| 897 |
+
groq_gpt_label = str(groq_gpt.get("label", "")).strip()
|
| 898 |
is_agreement = (
|
| 899 |
+
all([mistral_label, groq_ollama_label, groq_gpt_label])
|
| 900 |
+
and mistral_label.lower() == groq_ollama_label.lower()
|
| 901 |
+
and mistral_label.lower() == groq_gpt_label.lower()
|
| 902 |
)
|
| 903 |
|
| 904 |
return {
|
|
|
|
| 911 |
),
|
| 912 |
"mistral_reasoning": row.get("mistral_reasoning") or row.get("reasoning", ""),
|
| 913 |
"mistral_niche": bool(row.get("mistral_niche", row.get("niche", False))),
|
| 914 |
+
"groq_label": groq_ollama_label,
|
| 915 |
+
"groq_category": groq_ollama.get("category", ""),
|
| 916 |
+
"groq_confidence": _to_float(groq_ollama.get("confidence"), 0.0),
|
| 917 |
+
"groq_reasoning": groq_ollama.get("reasoning", ""),
|
| 918 |
+
"groq_niche": bool(groq_ollama.get("niche", False)),
|
| 919 |
+
"groq_ollama_label": groq_ollama_label,
|
| 920 |
+
"groq_ollama_category": groq_ollama.get("category", ""),
|
| 921 |
+
"groq_ollama_confidence": _to_float(groq_ollama.get("confidence"), 0.0),
|
| 922 |
+
"groq_ollama_reasoning": groq_ollama.get("reasoning", ""),
|
| 923 |
+
"groq_ollama_niche": bool(groq_ollama.get("niche", False)),
|
| 924 |
+
"groq_gpt_label": groq_gpt_label,
|
| 925 |
+
"groq_gpt_category": groq_gpt.get("category", ""),
|
| 926 |
+
"groq_gpt_confidence": _to_float(groq_gpt.get("confidence"), 0.0),
|
| 927 |
+
"groq_gpt_reasoning": groq_gpt.get("reasoning", ""),
|
| 928 |
+
"groq_gpt_niche": bool(groq_gpt.get("niche", False)),
|
| 929 |
+
"verification_done": has_groq_ollama and has_groq_gpt,
|
| 930 |
+
"verification_done_ollama": has_groq_ollama,
|
| 931 |
+
"verification_done_gpt": has_groq_gpt,
|
| 932 |
"verification_note": (
|
| 933 |
+
"Mistral, Groq-Ollama, and Groq-GPT labels match."
|
| 934 |
if is_agreement
|
| 935 |
+
else "Model labels differ. Review before approval."
|
| 936 |
)
|
| 937 |
+
if has_groq_ollama and has_groq_gpt
|
| 938 |
else "Groq labeling unavailable for this topic.",
|
| 939 |
}
|
| 940 |
|
|
|
|
| 947 |
lambda r: {
|
| 948 |
"cluster_id": r.get("cluster_id"),
|
| 949 |
"mistral_label": r.get("mistral_label", ""),
|
| 950 |
+
"groq_ollama_label": r.get("groq_ollama_label", r.get("groq_label", "")),
|
| 951 |
+
"groq_gpt_label": r.get("groq_gpt_label", ""),
|
| 952 |
"verification_note": r.get("verification_note", ""),
|
| 953 |
},
|
| 954 |
verified_rows[:MAX_TOOL_RETURN_PREVIEW],
|
| 955 |
))
|
| 956 |
|
| 957 |
+
verified_count = sum(
|
| 958 |
+
1
|
| 959 |
+
for row in verified_rows
|
| 960 |
+
if row.get("groq_ollama_label") and row.get("groq_gpt_label")
|
| 961 |
+
)
|
| 962 |
|
| 963 |
return {
|
| 964 |
"run_key": run_key,
|
|
|
|
| 1164 |
run_key, taxonomy_path, verification_path,
|
| 1165 |
verified_count, mapping_preview
|
| 1166 |
"""
|
| 1167 |
+
if not _groq_ollama_enabled():
|
| 1168 |
return {
|
| 1169 |
"run_key": run_key,
|
| 1170 |
"taxonomy_path": str(_run_dir(run_key) / "taxonomy_map.json"),
|
|
|
|
| 1208 |
taxonomy_map = _load_json(taxonomy_path)
|
| 1209 |
taxonomy_str = "\n".join(f" - {cat}" for cat in PAJAIS_TAXONOMY)
|
| 1210 |
|
| 1211 |
+
chain_groq = _TAXONOMY_PROMPT | _llm_groq(GROQ_OLLAMA_MODEL_NAME) | JsonOutputParser()
|
| 1212 |
|
| 1213 |
def _map_theme_with_groq(theme: dict) -> dict:
|
| 1214 |
return _invoke_with_retries(lambda: chain_groq.invoke({
|