atharvthite05 committed on
Commit
60bc73f
·
verified ·
1 Parent(s): f376c71

Update tools.py

Browse files
Files changed (1) hide show
  1. tools.py +176 -56
tools.py CHANGED
@@ -46,10 +46,11 @@ import pandas as pd
46
  import plotly.express as px
47
  import plotly.graph_objects as go
48
  import plotly.figure_factory as ff
49
- from sklearn.cluster import AgglomerativeClustering
50
  from sklearn.metrics.pairwise import cosine_similarity
51
  from sklearn.preprocessing import normalize
52
  from sentence_transformers import SentenceTransformer
 
 
53
 
54
  from langchain_core.tools import tool
55
  from langchain_core.prompts import PromptTemplate
@@ -69,7 +70,9 @@ MISTRAL_API_KEY: str = os.environ.get("MISTRAL_API_KEY", "")
69
  MODEL_NAME: str = "mistral-small-latest"
70
  GROQ_API_KEY: str = os.environ.get("GROQ_API_KEY", "")
71
  GROQ_MODEL_NAME: str = os.environ.get("GROQ_MODEL_NAME", "llama-3.3-70b-versatile")
72
- EMBED_MODEL: str = "all-MiniLM-L6-v2"
 
 
73
  BASE_DIR: Path = Path(__file__).resolve().parent
74
  OUTPUT_DIR: Path = BASE_DIR / "outputs"
75
  N_EVIDENCE: int = 5 # sentences kept per cluster centroid
@@ -78,10 +81,17 @@ RANDOM_SEED: int = 42
78
  LLM_TIMEOUT_S: int = 45
79
  LLM_MAX_RETRIES: int = 3
80
  MAX_LABEL_CLUSTERS: int = 60
81
- MIN_CLUSTER_SIZE_FOR_LABEL: int = 3
82
  MAX_TOOL_RETURN_PREVIEW: int = 12
83
  PROVIDER_RETRY_ATTEMPTS: int = 3
84
  PROVIDER_RETRY_BASE_DELAY_S: float = 1.5
 
 
 
 
 
 
 
85
 
86
  # Run configurations — keys map to source columns
87
  RUN_CONFIGS: dict[str, list[str]] = {
@@ -198,18 +208,33 @@ def _texts_for_candidates(df: pd.DataFrame, candidates: list[str]) -> tuple[list
198
 
199
 
200
  def _embed(sentences: list[str]) -> np.ndarray:
201
- """Encode sentences to L2-normalised 384-d vectors."""
202
- model = SentenceTransformer(EMBED_MODEL)
203
  raw = model.encode(sentences, show_progress_bar=False, batch_size=64)
204
  return normalize(raw, norm="l2") # unit-norm -> cosine = dot product
205
 
206
 
207
- def _cluster(embeddings: np.ndarray, threshold: float) -> np.ndarray:
208
- return AgglomerativeClustering(
 
 
 
209
  metric="cosine",
210
- linkage="average",
211
- distance_threshold=threshold,
212
- n_clusters=None,
 
 
 
 
 
 
 
 
 
 
 
 
213
  ).fit_predict(embeddings)
214
 
215
 
@@ -234,14 +259,14 @@ def _llm() -> ChatMistralAI:
234
  )
235
 
236
 
237
- def _llm_groq():
238
  if ChatGroq is None:
239
  raise RuntimeError(
240
  "langchain-groq is not installed. Install dependencies from requirements.txt "
241
  "to enable Groq topic-label verification."
242
  )
243
  return ChatGroq(
244
- model=GROQ_MODEL_NAME,
245
  api_key=GROQ_API_KEY,
246
  temperature=0.2,
247
  timeout=LLM_TIMEOUT_S,
@@ -249,8 +274,12 @@ def _llm_groq():
249
  )
250
 
251
 
252
- def _groq_enabled() -> bool:
253
- return bool(GROQ_API_KEY) and ChatGroq is not None
 
 
 
 
254
 
255
 
256
  def _to_float(value: object, fallback: float = 0.0) -> float:
@@ -345,7 +374,11 @@ def _chart_top_words(summaries: list[dict]) -> go.Figure:
345
 
346
 
347
  def _chart_hierarchy(labels: list[int], embeddings: np.ndarray) -> go.Figure:
348
- unique = sorted(set(labels))
 
 
 
 
349
  labels_arr = np.array(labels)
350
  centroids = np.vstack([
351
  _centroid(embeddings[labels_arr == lbl])
@@ -362,7 +395,11 @@ def _chart_hierarchy(labels: list[int], embeddings: np.ndarray) -> go.Figure:
362
 
363
 
364
  def _chart_heatmap(labels: list[int], embeddings: np.ndarray) -> go.Figure:
365
- unique = sorted(set(labels))
 
 
 
 
366
  labels_arr = np.array(labels)
367
  centroids = np.vstack([
368
  _centroid(embeddings[labels_arr == lbl])
@@ -460,21 +497,30 @@ def load_scopus_csv(filepath: str) -> dict:
460
  # ============================================================================
461
 
462
  @tool
463
- def run_bertopic_discovery(run_key: str, threshold: float = DISTANCE_THRESH) -> dict:
 
 
 
 
 
 
464
  """
465
- Embed sentences, cluster with AgglomerativeClustering, extract evidence,
466
  and generate four Plotly charts.
467
 
468
  Saved artefacts
469
  ---------------
470
- emb.npy : (N, 384) float32 L2-normalised embeddings
471
  sent_labels.npy : (N,) int32 per-sentence cluster label [BUG 1 FIX]
472
  summaries.json : list of cluster dicts with evidence sentences
473
 
474
  Parameters
475
  ----------
476
  run_key : str — "abstract" or "title" or "keywords"
477
- threshold : float — cosine distance threshold for AgglomerativeClustering
 
 
 
478
 
479
  Returns
480
  -------
@@ -527,8 +573,16 @@ def run_bertopic_discovery(run_key: str, threshold: float = DISTANCE_THRESH) ->
527
  embeddings = _embed(sentences)
528
  np.save(str(rdir / "emb.npy"), embeddings)
529
 
530
- labels = _cluster(embeddings, threshold).tolist()
531
- unique_ids = sorted(set(labels))
 
 
 
 
 
 
 
 
532
 
533
  # FIX BUG 1 — persist per-sentence label array so Tool 4 can build
534
  # correct cluster masks without any guesswork or scaffolding.
@@ -536,17 +590,39 @@ def run_bertopic_discovery(run_key: str, threshold: float = DISTANCE_THRESH) ->
536
 
537
  labels_arr = np.array(labels)
538
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
539
  def _cluster_summary(cid: int) -> dict:
540
  mask = labels_arr == cid
541
  c_emb = embeddings[mask]
 
542
  c_sent = list(np.array(sentences)[mask])
543
  ctroid = _centroid(c_emb)
544
  top_idx = _top_k_indices(c_emb, ctroid, N_EVIDENCE)
 
 
 
 
 
545
  return {
546
  "cluster_id": int(cid),
547
  "size": int(mask.sum()),
548
- "cx": float(ctroid[0]),
549
- "cy": float(ctroid[1]),
550
  "evidence": list(np.array(c_sent)[top_idx]),
551
  }
552
 
@@ -565,6 +641,9 @@ def run_bertopic_discovery(run_key: str, threshold: float = DISTANCE_THRESH) ->
565
  "n_clusters": int(len(unique_ids)),
566
  "n_sentences": int(len(sentences)),
567
  "threshold": threshold,
 
 
 
568
  "chart_paths": chart_paths,
569
  "summaries_path": str(rdir / "summaries.json"),
570
  "embeddings_path": str(rdir / "emb.npy"),
@@ -663,8 +742,22 @@ def label_topics_with_llm(run_key: str) -> dict:
663
  "groq_confidence": 0.0,
664
  "groq_reasoning": "",
665
  "groq_niche": False,
 
 
 
 
 
 
 
 
 
 
666
  "verification_done": False,
667
- "verification_note": "Run VERIFY in Phase 2 to compare with Groq labels.",
 
 
 
 
668
  }
669
 
670
  labelled = list(map(_label_one, selected))
@@ -679,6 +772,8 @@ def label_topics_with_llm(run_key: str) -> dict:
679
  "confidence": r.get("confidence"),
680
  "mistral_label": r.get("mistral_label", ""),
681
  "groq_label": r.get("groq_label", ""),
 
 
682
  "size": r.get("size"),
683
  "niche": r.get("niche", False),
684
  },
@@ -692,8 +787,8 @@ def label_topics_with_llm(run_key: str) -> dict:
692
  "total_clusters": len(summaries),
693
  "selected_clusters": len(selected),
694
  "skipped_clusters": max(0, len(summaries) - len(selected)),
695
- "groq_enabled": _groq_enabled(),
696
- "mode_note": "Single-model labeling complete (Mistral). Send VERIFY in Phase 2 to run Groq verification.",
697
  "labels_preview": preview,
698
  }
699
 
@@ -702,7 +797,7 @@ def label_topics_with_llm(run_key: str) -> dict:
702
  def verify_topic_labels_with_groq(run_key: str) -> dict:
703
  """
704
  Run Groq topic labeling for already-labeled topics and append comparison fields
705
- into labels.json so UI review table can show both Mistral and Groq labels.
706
 
707
  Parameters
708
  ----------
@@ -717,15 +812,16 @@ def verify_topic_labels_with_groq(run_key: str) -> dict:
717
  labels_path = rdir / "labels.json"
718
  summaries_path = rdir / "summaries.json"
719
 
720
- if not _groq_enabled():
721
  return {
722
  "run_key": run_key,
723
  "labels_path": str(labels_path),
724
  "verified_count": 0,
725
  "labels_preview": [],
726
  "error": (
727
- "GROQ_API_KEY is missing or langchain-groq is unavailable. "
728
- "Set GROQ_API_KEY and install requirements to use VERIFY."
 
729
  ),
730
  }
731
 
@@ -765,7 +861,8 @@ def verify_topic_labels_with_groq(run_key: str) -> dict:
765
  labels_data,
766
  ))
767
 
768
- chain_groq = _LABEL_PROMPT | _llm_groq() | JsonOutputParser()
 
769
 
770
  def _evidence_block(summary: dict) -> str:
771
  return "\n".join(
@@ -773,29 +870,35 @@ def verify_topic_labels_with_groq(run_key: str) -> dict:
773
  for i, s in enumerate(summary.get("evidence", []))
774
  )
775
 
776
- def _label_with_groq(row: dict) -> tuple[int, dict]:
777
  cid = int(row.get("cluster_id", -1))
778
  summary = summary_by_id[cid]
779
- result = _invoke_with_retries(lambda: chain_groq.invoke({
780
  "cluster_id": summary["cluster_id"],
781
  "size": summary["size"],
782
  "evidence": _evidence_block(summary),
783
- }))
784
- return cid, result
 
 
785
 
786
  groq_pairs = list(map(_label_with_groq, target_rows))
787
- groq_by_id = {cid: data for cid, data in groq_pairs}
 
788
 
789
  def _merge_row(row: dict) -> dict:
790
  cid = int(row.get("cluster_id", -1))
791
- groq = groq_by_id.get(cid, {})
792
- has_groq = bool(groq)
 
 
793
  mistral_label = str(row.get("mistral_label") or row.get("label", "")).strip()
794
- groq_label = str(groq.get("label", "")).strip()
 
795
  is_agreement = (
796
- mistral_label.lower() == groq_label.lower()
797
- if has_groq and mistral_label and groq_label
798
- else False
799
  )
800
 
801
  return {
@@ -808,18 +911,30 @@ def verify_topic_labels_with_groq(run_key: str) -> dict:
808
  ),
809
  "mistral_reasoning": row.get("mistral_reasoning") or row.get("reasoning", ""),
810
  "mistral_niche": bool(row.get("mistral_niche", row.get("niche", False))),
811
- "groq_label": groq.get("label", ""),
812
- "groq_category": groq.get("category", ""),
813
- "groq_confidence": _to_float(groq.get("confidence"), 0.0),
814
- "groq_reasoning": groq.get("reasoning", ""),
815
- "groq_niche": bool(groq.get("niche", False)),
816
- "verification_done": has_groq,
 
 
 
 
 
 
 
 
 
 
 
 
817
  "verification_note": (
818
- "Mistral and Groq labels match."
819
  if is_agreement
820
- else "Mistral and Groq labels differ. Review before approval."
821
  )
822
- if has_groq
823
  else "Groq labeling unavailable for this topic.",
824
  }
825
 
@@ -832,13 +947,18 @@ def verify_topic_labels_with_groq(run_key: str) -> dict:
832
  lambda r: {
833
  "cluster_id": r.get("cluster_id"),
834
  "mistral_label": r.get("mistral_label", ""),
835
- "groq_label": r.get("groq_label", ""),
 
836
  "verification_note": r.get("verification_note", ""),
837
  },
838
  verified_rows[:MAX_TOOL_RETURN_PREVIEW],
839
  ))
840
 
841
- verified_count = sum(1 for row in verified_rows if row.get("groq_label"))
 
 
 
 
842
 
843
  return {
844
  "run_key": run_key,
@@ -1044,7 +1164,7 @@ def verify_taxonomy_mapping_with_groq(run_key: str) -> dict:
1044
  run_key, taxonomy_path, verification_path,
1045
  verified_count, mapping_preview
1046
  """
1047
- if not _groq_enabled():
1048
  return {
1049
  "run_key": run_key,
1050
  "taxonomy_path": str(_run_dir(run_key) / "taxonomy_map.json"),
@@ -1088,7 +1208,7 @@ def verify_taxonomy_mapping_with_groq(run_key: str) -> dict:
1088
  taxonomy_map = _load_json(taxonomy_path)
1089
  taxonomy_str = "\n".join(f" - {cat}" for cat in PAJAIS_TAXONOMY)
1090
 
1091
- chain_groq = _TAXONOMY_PROMPT | _llm_groq() | JsonOutputParser()
1092
 
1093
  def _map_theme_with_groq(theme: dict) -> dict:
1094
  return _invoke_with_retries(lambda: chain_groq.invoke({
 
46
  import plotly.express as px
47
  import plotly.graph_objects as go
48
  import plotly.figure_factory as ff
 
49
  from sklearn.metrics.pairwise import cosine_similarity
50
  from sklearn.preprocessing import normalize
51
  from sentence_transformers import SentenceTransformer
52
+ import hdbscan
53
+ import umap
54
 
55
  from langchain_core.tools import tool
56
  from langchain_core.prompts import PromptTemplate
 
70
  MODEL_NAME: str = "mistral-small-latest"
71
  GROQ_API_KEY: str = os.environ.get("GROQ_API_KEY", "")
72
  GROQ_MODEL_NAME: str = os.environ.get("GROQ_MODEL_NAME", "llama-3.3-70b-versatile")
73
+ GROQ_OLLAMA_MODEL_NAME: str = os.environ.get("GROQ_OLLAMA_MODEL_NAME", "llama-3.3-70b-versatile")
74
+ GROQ_GPT_MODEL_NAME: str = os.environ.get("GROQ_GPT_MODEL_NAME", "openai/gpt-oss-120b")
75
+ EMBED_MODEL: str = "allenai/specter2_base"
76
  BASE_DIR: Path = Path(__file__).resolve().parent
77
  OUTPUT_DIR: Path = BASE_DIR / "outputs"
78
  N_EVIDENCE: int = 5 # sentences kept per cluster centroid
 
81
  LLM_TIMEOUT_S: int = 45
82
  LLM_MAX_RETRIES: int = 3
83
  MAX_LABEL_CLUSTERS: int = 60
84
+ MIN_CLUSTER_SIZE_FOR_LABEL: int = 20
85
  MAX_TOOL_RETURN_PREVIEW: int = 12
86
  PROVIDER_RETRY_ATTEMPTS: int = 3
87
  PROVIDER_RETRY_BASE_DELAY_S: float = 1.5
88
+ HDBSCAN_MIN_CLUSTER_SIZE: int = 20
89
+ HDBSCAN_MIN_SAMPLES: int = 5
90
+ HDBSCAN_MAX_CLUSTER_SIZE: int = 120
91
+ UMAP_N_NEIGHBORS: int = 15
92
+ UMAP_MIN_DIST: float = 0.0
93
+ UMAP_N_COMPONENTS_CLUSTER: int = 5
94
+ UMAP_N_COMPONENTS_VIZ: int = 2
95
 
96
  # Run configurations — keys map to source columns
97
  RUN_CONFIGS: dict[str, list[str]] = {
 
208
 
209
 
210
  def _embed(sentences: list[str]) -> np.ndarray:
211
+ """Encode sentences to L2-normalised SPECTER2 vectors."""
212
+ model = SentenceTransformer(EMBED_MODEL, trust_remote_code=True)
213
  raw = model.encode(sentences, show_progress_bar=False, batch_size=64)
214
  return normalize(raw, norm="l2") # unit-norm -> cosine = dot product
215
 
216
 
217
def _umap_reduce(embeddings: np.ndarray, n_components: int) -> np.ndarray:
    """Project embeddings into an ``n_components``-dimensional UMAP space.

    A new cosine-metric UMAP reducer is configured from the module-level
    UMAP_N_NEIGHBORS and UMAP_MIN_DIST settings, seeded with RANDOM_SEED,
    and fitted on ``embeddings``.

    Parameters
    ----------
    embeddings   : np.ndarray — vectors to reduce
    n_components : int — dimensionality of the output space

    Returns
    -------
    The result of ``fit_transform`` on ``embeddings``.
    """
    umap_settings = {
        "n_neighbors": UMAP_N_NEIGHBORS,
        "min_dist": UMAP_MIN_DIST,
        "n_components": n_components,
        "metric": "cosine",
        "random_state": RANDOM_SEED,
    }
    projector = umap.UMAP(**umap_settings)
    return projector.fit_transform(embeddings)
226
+
227
+
228
def _cluster(
    embeddings: np.ndarray,
    min_cluster_size: int,
    max_cluster_size: int,
    min_samples: int,
) -> np.ndarray:
    """Cluster ``embeddings`` with HDBSCAN and return one label per row.

    The clusterer uses the euclidean metric and the "eom" cluster selection
    method. Callers in this file treat the label -1 as noise and exclude it
    from the set of cluster ids.

    Parameters
    ----------
    embeddings       : np.ndarray — points to cluster
    min_cluster_size : int — HDBSCAN minimum cluster size
    max_cluster_size : int — HDBSCAN maximum cluster size
    min_samples      : int — HDBSCAN min_samples

    Returns
    -------
    The result of ``fit_predict`` on ``embeddings``.
    """
    clusterer = hdbscan.HDBSCAN(
        metric="euclidean",
        cluster_selection_method="eom",
        min_cluster_size=min_cluster_size,
        max_cluster_size=max_cluster_size,
        min_samples=min_samples,
    )
    point_labels = clusterer.fit_predict(embeddings)
    return point_labels
239
 
240
 
 
259
  )
260
 
261
 
262
+ def _llm_groq(model_name: str):
263
  if ChatGroq is None:
264
  raise RuntimeError(
265
  "langchain-groq is not installed. Install dependencies from requirements.txt "
266
  "to enable Groq topic-label verification."
267
  )
268
  return ChatGroq(
269
+ model=model_name,
270
  api_key=GROQ_API_KEY,
271
  temperature=0.2,
272
  timeout=LLM_TIMEOUT_S,
 
274
  )
275
 
276
 
277
def _groq_ollama_enabled() -> bool:
    """Report whether the Groq-Ollama verifier can be used.

    True only when GROQ_API_KEY is non-empty, the ChatGroq class is
    available (not None), and GROQ_OLLAMA_MODEL_NAME is non-empty.
    """
    if not GROQ_API_KEY or ChatGroq is None:
        return False
    return bool(GROQ_OLLAMA_MODEL_NAME)
279
+
280
+
281
def _groq_gpt_enabled() -> bool:
    """Report whether the Groq-GPT verifier can be used.

    True only when GROQ_API_KEY is non-empty, the ChatGroq class is
    available (not None), and GROQ_GPT_MODEL_NAME is non-empty.
    """
    if not GROQ_API_KEY or ChatGroq is None:
        return False
    return bool(GROQ_GPT_MODEL_NAME)
283
 
284
 
285
  def _to_float(value: object, fallback: float = 0.0) -> float:
 
374
 
375
 
376
  def _chart_hierarchy(labels: list[int], embeddings: np.ndarray) -> go.Figure:
377
+ unique = sorted(filter(lambda v: v != -1, set(labels)))
378
+ if not unique:
379
+ fig = go.Figure()
380
+ fig.update_layout(title="Cluster Hierarchy", template="plotly_dark")
381
+ return fig
382
  labels_arr = np.array(labels)
383
  centroids = np.vstack([
384
  _centroid(embeddings[labels_arr == lbl])
 
395
 
396
 
397
  def _chart_heatmap(labels: list[int], embeddings: np.ndarray) -> go.Figure:
398
+ unique = sorted(filter(lambda v: v != -1, set(labels)))
399
+ if not unique:
400
+ fig = go.Figure()
401
+ fig.update_layout(title="Cluster Similarity Heatmap", template="plotly_dark")
402
+ return fig
403
  labels_arr = np.array(labels)
404
  centroids = np.vstack([
405
  _centroid(embeddings[labels_arr == lbl])
 
497
  # ============================================================================
498
 
499
  @tool
500
+ def run_bertopic_discovery(
501
+ run_key: str,
502
+ threshold: float = DISTANCE_THRESH,
503
+ min_cluster_size: int = HDBSCAN_MIN_CLUSTER_SIZE,
504
+ max_cluster_size: int = HDBSCAN_MAX_CLUSTER_SIZE,
505
+ min_samples: int = HDBSCAN_MIN_SAMPLES,
506
+ ) -> dict:
507
  """
508
+ Embed sentences, cluster with UMAP + HDBSCAN, extract evidence,
509
  and generate four Plotly charts.
510
 
511
  Saved artefacts
512
  ---------------
513
+ emb.npy : (N, D) float32 L2-normalised embeddings
514
  sent_labels.npy : (N,) int32 per-sentence cluster label [BUG 1 FIX]
515
  summaries.json : list of cluster dicts with evidence sentences
516
 
517
  Parameters
518
  ----------
519
  run_key : str — "abstract" or "title" or "keywords"
520
+ threshold : float — legacy arg (ignored by HDBSCAN)
521
+ min_cluster_size : int — HDBSCAN minimum cluster size
522
+ max_cluster_size : int — HDBSCAN maximum cluster size
523
+ min_samples : int — HDBSCAN min_samples
524
 
525
  Returns
526
  -------
 
573
  embeddings = _embed(sentences)
574
  np.save(str(rdir / "emb.npy"), embeddings)
575
 
576
+ cluster_space = _umap_reduce(embeddings, UMAP_N_COMPONENTS_CLUSTER)
577
+ umap_2d = _umap_reduce(embeddings, UMAP_N_COMPONENTS_VIZ)
578
+
579
+ labels = _cluster(
580
+ cluster_space,
581
+ min_cluster_size=min_cluster_size,
582
+ max_cluster_size=max_cluster_size,
583
+ min_samples=min_samples,
584
+ ).tolist()
585
+ unique_ids = sorted(filter(lambda v: v != -1, set(labels)))
586
 
587
  # FIX BUG 1 — persist per-sentence label array so Tool 4 can build
588
  # correct cluster masks without any guesswork or scaffolding.
 
590
 
591
  labels_arr = np.array(labels)
592
 
593
+ if not unique_ids:
594
+ _save_json(rdir / "summaries.json", [])
595
+ return {
596
+ "run_key": run_key,
597
+ "n_clusters": 0,
598
+ "n_sentences": int(len(sentences)),
599
+ "threshold": threshold,
600
+ "min_cluster_size": int(min_cluster_size),
601
+ "max_cluster_size": int(max_cluster_size),
602
+ "min_samples": int(min_samples),
603
+ "chart_paths": {},
604
+ "summaries_path": str(rdir / "summaries.json"),
605
+ "embeddings_path": str(rdir / "emb.npy"),
606
+ "error": "HDBSCAN produced no clusters (all points labeled as noise).",
607
+ }
608
+
609
  def _cluster_summary(cid: int) -> dict:
610
  mask = labels_arr == cid
611
  c_emb = embeddings[mask]
612
+ c_umap = umap_2d[mask]
613
  c_sent = list(np.array(sentences)[mask])
614
  ctroid = _centroid(c_emb)
615
  top_idx = _top_k_indices(c_emb, ctroid, N_EVIDENCE)
616
+ coords = (
617
+ c_umap.mean(axis=0)
618
+ if c_umap.shape[0] > 0
619
+ else np.zeros(UMAP_N_COMPONENTS_VIZ, dtype=np.float32)
620
+ )
621
  return {
622
  "cluster_id": int(cid),
623
  "size": int(mask.sum()),
624
+ "cx": float(coords[0]),
625
+ "cy": float(coords[1]),
626
  "evidence": list(np.array(c_sent)[top_idx]),
627
  }
628
 
 
641
  "n_clusters": int(len(unique_ids)),
642
  "n_sentences": int(len(sentences)),
643
  "threshold": threshold,
644
+ "min_cluster_size": int(min_cluster_size),
645
+ "max_cluster_size": int(max_cluster_size),
646
+ "min_samples": int(min_samples),
647
  "chart_paths": chart_paths,
648
  "summaries_path": str(rdir / "summaries.json"),
649
  "embeddings_path": str(rdir / "emb.npy"),
 
742
  "groq_confidence": 0.0,
743
  "groq_reasoning": "",
744
  "groq_niche": False,
745
+ "groq_ollama_label": "",
746
+ "groq_ollama_category": "",
747
+ "groq_ollama_confidence": 0.0,
748
+ "groq_ollama_reasoning": "",
749
+ "groq_ollama_niche": False,
750
+ "groq_gpt_label": "",
751
+ "groq_gpt_category": "",
752
+ "groq_gpt_confidence": 0.0,
753
+ "groq_gpt_reasoning": "",
754
+ "groq_gpt_niche": False,
755
  "verification_done": False,
756
+ "verification_done_ollama": False,
757
+ "verification_done_gpt": False,
758
+ "verification_note": (
759
+ "Run VERIFY in Phase 2 to compare with Groq-Ollama and Groq-GPT labels."
760
+ ),
761
  }
762
 
763
  labelled = list(map(_label_one, selected))
 
772
  "confidence": r.get("confidence"),
773
  "mistral_label": r.get("mistral_label", ""),
774
  "groq_label": r.get("groq_label", ""),
775
+ "groq_ollama_label": r.get("groq_ollama_label", r.get("groq_label", "")),
776
+ "groq_gpt_label": r.get("groq_gpt_label", ""),
777
  "size": r.get("size"),
778
  "niche": r.get("niche", False),
779
  },
 
787
  "total_clusters": len(summaries),
788
  "selected_clusters": len(selected),
789
  "skipped_clusters": max(0, len(summaries) - len(selected)),
790
+ "groq_enabled": _groq_ollama_enabled() and _groq_gpt_enabled(),
791
+ "mode_note": "Single-model labeling complete (Mistral). Send VERIFY in Phase 2 to run Groq-Ollama and Groq-GPT verification.",
792
  "labels_preview": preview,
793
  }
794
 
 
797
  def verify_topic_labels_with_groq(run_key: str) -> dict:
798
  """
799
  Run Groq topic labeling for already-labeled topics and append comparison fields
800
+ into labels.json so UI review table can show Mistral vs Groq-Ollama vs Groq-GPT labels.
801
 
802
  Parameters
803
  ----------
 
812
  labels_path = rdir / "labels.json"
813
  summaries_path = rdir / "summaries.json"
814
 
815
+ if not _groq_ollama_enabled() or not _groq_gpt_enabled():
816
  return {
817
  "run_key": run_key,
818
  "labels_path": str(labels_path),
819
  "verified_count": 0,
820
  "labels_preview": [],
821
  "error": (
822
+ "GROQ_API_KEY or Groq model config is missing, or langchain-groq is unavailable. "
823
+ "Set GROQ_API_KEY and GROQ_GPT_MODEL_NAME (and optionally GROQ_OLLAMA_MODEL_NAME) "
824
+ "and install requirements to use VERIFY."
825
  ),
826
  }
827
 
 
861
  labels_data,
862
  ))
863
 
864
+ chain_groq_ollama = _LABEL_PROMPT | _llm_groq(GROQ_OLLAMA_MODEL_NAME) | JsonOutputParser()
865
+ chain_groq_gpt = _LABEL_PROMPT | _llm_groq(GROQ_GPT_MODEL_NAME) | JsonOutputParser()
866
 
867
  def _evidence_block(summary: dict) -> str:
868
  return "\n".join(
 
870
  for i, s in enumerate(summary.get("evidence", []))
871
  )
872
 
873
+ def _label_with_groq(row: dict) -> tuple[int, dict, dict]:
874
  cid = int(row.get("cluster_id", -1))
875
  summary = summary_by_id[cid]
876
+ payload = {
877
  "cluster_id": summary["cluster_id"],
878
  "size": summary["size"],
879
  "evidence": _evidence_block(summary),
880
+ }
881
+ groq_ollama = _invoke_with_retries(lambda: chain_groq_ollama.invoke(payload))
882
+ groq_gpt = _invoke_with_retries(lambda: chain_groq_gpt.invoke(payload))
883
+ return cid, groq_ollama, groq_gpt
884
 
885
  groq_pairs = list(map(_label_with_groq, target_rows))
886
+ groq_ollama_by_id = {cid: data for cid, data, _ in groq_pairs}
887
+ groq_gpt_by_id = {cid: data for cid, _, data in groq_pairs}
888
 
889
  def _merge_row(row: dict) -> dict:
890
  cid = int(row.get("cluster_id", -1))
891
+ groq_ollama = groq_ollama_by_id.get(cid, {})
892
+ groq_gpt = groq_gpt_by_id.get(cid, {})
893
+ has_groq_ollama = bool(groq_ollama)
894
+ has_groq_gpt = bool(groq_gpt)
895
  mistral_label = str(row.get("mistral_label") or row.get("label", "")).strip()
896
+ groq_ollama_label = str(groq_ollama.get("label", "")).strip()
897
+ groq_gpt_label = str(groq_gpt.get("label", "")).strip()
898
  is_agreement = (
899
+ all([mistral_label, groq_ollama_label, groq_gpt_label])
900
+ and mistral_label.lower() == groq_ollama_label.lower()
901
+ and mistral_label.lower() == groq_gpt_label.lower()
902
  )
903
 
904
  return {
 
911
  ),
912
  "mistral_reasoning": row.get("mistral_reasoning") or row.get("reasoning", ""),
913
  "mistral_niche": bool(row.get("mistral_niche", row.get("niche", False))),
914
+ "groq_label": groq_ollama_label,
915
+ "groq_category": groq_ollama.get("category", ""),
916
+ "groq_confidence": _to_float(groq_ollama.get("confidence"), 0.0),
917
+ "groq_reasoning": groq_ollama.get("reasoning", ""),
918
+ "groq_niche": bool(groq_ollama.get("niche", False)),
919
+ "groq_ollama_label": groq_ollama_label,
920
+ "groq_ollama_category": groq_ollama.get("category", ""),
921
+ "groq_ollama_confidence": _to_float(groq_ollama.get("confidence"), 0.0),
922
+ "groq_ollama_reasoning": groq_ollama.get("reasoning", ""),
923
+ "groq_ollama_niche": bool(groq_ollama.get("niche", False)),
924
+ "groq_gpt_label": groq_gpt_label,
925
+ "groq_gpt_category": groq_gpt.get("category", ""),
926
+ "groq_gpt_confidence": _to_float(groq_gpt.get("confidence"), 0.0),
927
+ "groq_gpt_reasoning": groq_gpt.get("reasoning", ""),
928
+ "groq_gpt_niche": bool(groq_gpt.get("niche", False)),
929
+ "verification_done": has_groq_ollama and has_groq_gpt,
930
+ "verification_done_ollama": has_groq_ollama,
931
+ "verification_done_gpt": has_groq_gpt,
932
  "verification_note": (
933
+ "Mistral, Groq-Ollama, and Groq-GPT labels match."
934
  if is_agreement
935
+ else "Model labels differ. Review before approval."
936
  )
937
+ if has_groq_ollama and has_groq_gpt
938
  else "Groq labeling unavailable for this topic.",
939
  }
940
 
 
947
  lambda r: {
948
  "cluster_id": r.get("cluster_id"),
949
  "mistral_label": r.get("mistral_label", ""),
950
+ "groq_ollama_label": r.get("groq_ollama_label", r.get("groq_label", "")),
951
+ "groq_gpt_label": r.get("groq_gpt_label", ""),
952
  "verification_note": r.get("verification_note", ""),
953
  },
954
  verified_rows[:MAX_TOOL_RETURN_PREVIEW],
955
  ))
956
 
957
+ verified_count = sum(
958
+ 1
959
+ for row in verified_rows
960
+ if row.get("groq_ollama_label") and row.get("groq_gpt_label")
961
+ )
962
 
963
  return {
964
  "run_key": run_key,
 
1164
  run_key, taxonomy_path, verification_path,
1165
  verified_count, mapping_preview
1166
  """
1167
+ if not _groq_ollama_enabled():
1168
  return {
1169
  "run_key": run_key,
1170
  "taxonomy_path": str(_run_dir(run_key) / "taxonomy_map.json"),
 
1208
  taxonomy_map = _load_json(taxonomy_path)
1209
  taxonomy_str = "\n".join(f" - {cat}" for cat in PAJAIS_TAXONOMY)
1210
 
1211
+ chain_groq = _TAXONOMY_PROMPT | _llm_groq(GROQ_OLLAMA_MODEL_NAME) | JsonOutputParser()
1212
 
1213
  def _map_theme_with_groq(theme: dict) -> dict:
1214
  return _invoke_with_retries(lambda: chain_groq.invoke({