anujjuna committed on
Commit
34345fd
·
verified ·
1 Parent(s): ee50027

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +218 -371
agent.py CHANGED
@@ -1,414 +1,261 @@
1
  """
2
- agent.py
3
- --------
4
- LLM Council labelling module (§3.5).
5
-
6
- Three independent LLMs label each cluster, producing Sheets 1–3.
7
- Sheet 4 consolidates with Triple / Two / Single agreement tags.
8
- Disagreement clusters get a fourth-round defence prompt.
9
- Labels not grounded in keyphrases are rejected.
10
  """
11
-
12
  from __future__ import annotations
13
- import json
14
- import logging
15
- import os
16
- import re
17
- import time
18
  from dataclasses import dataclass, field, asdict
19
- from typing import Optional
20
- import pandas as pd
21
- import numpy as np
22
- import requests
23
  from groq import Groq
 
24
 
25
- # ---------------------------------------------------------------------------
26
- # Logging
27
- # ---------------------------------------------------------------------------
28
  logging.basicConfig(level=logging.INFO, format="%(levelname)s | %(message)s")
29
  logger = logging.getLogger(__name__)
30
 
31
- # ---------------------------------------------------------------------------
32
- # Constants
33
- # ---------------------------------------------------------------------------
34
# Model identifiers passed as the "model" field of the Groq and Mistral
# chat-completion requests.
GROQ_MODEL = "llama-3.1-8b-instant"
MISTRAL_MODEL = "mistral-small-latest"

# Default list of topic categories. NOTE(review): not referenced by any
# function visible in this file — confirm whether external callers use it.
DEFAULT_TAXONOMY = [
    "Artificial Intelligence", "Machine Learning",
    "Natural Language Processing", "Computer Vision",
    "Information Systems", "Healthcare & Bioinformatics",
    "Finance & Economics", "Cybersecurity",
    "Human-Computer Interaction", "Robotics & Automation",
    "Education Technology", "Environmental Science",
    "Social Sciences", "Data Engineering", "Other",
]
46
-
47
  # ---------------------------------------------------------------------------
48
- # Data classes
49
  # ---------------------------------------------------------------------------
50
@dataclass
class LLMVote:
    """One LLM's response for one cluster."""

    llm_name: str                # identifier of the model that produced the vote
    label: str = ""              # proposed topic label
    description: str = ""        # one-sentence topic description
    pacis_match: str = ""        # category named by the model, per the prompt
    confidence: float = 0.0      # model-reported confidence
    raw: dict = field(default_factory=dict)  # unprocessed parsed response
59
-
60
-
61
@dataclass
class ClusterInterpretation:
    """Consolidated interpretation for a single cluster."""

    cluster_id: int
    final_label: str = ""
    final_description: str = ""
    final_pacis_match: str = ""
    final_confidence: float = 0.0
    agreement: str = ""  # Triple / Two / Single
    # Parsed responses of the three council members (Sheets 1-3).
    sheet1: dict = field(default_factory=dict)
    sheet2: dict = field(default_factory=dict)
    sheet3: dict = field(default_factory=dict)
    defence: dict = field(default_factory=dict)  # 4th-round response, if needed
    keyphrases: list = field(default_factory=list)
    strong_count: int = 0
    weak_count: int = 0
    paper_count: int = 0
    grounding_check: dict = field(default_factory=dict)
79
-
80
 
81
  # ---------------------------------------------------------------------------
82
- # API Clients
83
  # ---------------------------------------------------------------------------
84
def build_groq_client(api_key: Optional[str] = None):
    """Create a Groq client with library-level retries disabled.

    The key is taken from ``api_key`` when it is truthy, otherwise from the
    ``GROQ_API_KEY`` environment variable.

    Raises:
        ValueError: if neither source yields a key.
    """
    resolved = api_key if api_key else os.getenv("GROQ_API_KEY")
    if resolved:
        return Groq(api_key=resolved, max_retries=0)
    raise ValueError("No Groq API key.")
89
-
90
-
91
def _call_groq(client, prompt: str) -> dict:
    """Send ``prompt`` to the Groq chat endpoint and return the parsed JSON reply.

    Any exception (network, API or parsing) is logged as a warning and
    reported to the caller as an empty dict.
    """
    try:
        request = {
            "model": GROQ_MODEL,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.2,
            "timeout": 15,
        }
        response = client.chat.completions.create(**request)
        reply_text = response.choices[0].message.content
        return _parse_json(reply_text)
    except Exception as exc:
        logger.warning("Groq failed: %s", exc)
        return {}
102
-
103
-
104
def _call_mistral(prompt: str, api_key: str) -> dict:
    """POST ``prompt`` to the Mistral chat-completions API and parse the reply.

    Returns an empty dict when no key is supplied or when the request or
    parsing fails (the failure is logged as a warning).
    """
    if not api_key:
        return {}
    try:
        auth_headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": MISTRAL_MODEL,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.2,
        }
        response = requests.post(
            "https://api.mistral.ai/v1/chat/completions",
            headers=auth_headers,
            json=payload,
            timeout=15,
        )
        reply_text = response.json()["choices"][0]["message"]["content"]
        return _parse_json(reply_text)
    except Exception as exc:
        logger.warning("Mistral failed: %s", exc)
        return {}
121
 
122
-
123
def _call_gemini(prompt: str, api_key: str) -> dict:
    """POST ``prompt`` to Gemini ``generateContent`` and parse the reply.

    Returns an empty dict when no key is supplied, when the response has no
    ``candidates`` field, or when the request or parsing fails.
    """
    if not api_key:
        return {}
    url = ("https://generativelanguage.googleapis.com/v1beta/models/"
           "gemini-2.5-flash:generateContent")
    # The key travels in a header rather than the query string: request
    # exceptions include the URL in their message, and that message is logged
    # below, so a key in the URL would end up in the logs.
    headers = {"Content-Type": "application/json", "x-goog-api-key": api_key}
    try:
        r = requests.post(
            url,
            headers=headers,
            json={"contents": [{"parts": [{"text": prompt}]}],
                  "generationConfig": {"temperature": 0.2}},
            timeout=15,
        )
        data = r.json()
        if "candidates" not in data:
            logger.warning("Gemini returned no candidates (HTTP %s)",
                           r.status_code)
            return {}
        raw = data["candidates"][0]["content"]["parts"][0]["text"]
        return _parse_json(raw)
    except Exception as e:
        logger.warning("Gemini failed: %s", e)
        return {}
142
-
143
-
144
def _parse_json(raw: str) -> dict:
    """Extract a JSON object from an LLM reply.

    Markdown code fences are removed and the text between the first ``{``
    and the last ``}`` is decoded.

    Returns:
        The decoded object, or ``{}`` when ``raw`` is not a string, cannot be
        decoded, or decodes to something other than a dict (callers use
        ``.get`` on the result).
    """
    if not isinstance(raw, str):
        return {}
    text = raw.strip().replace("```json", "").replace("```", "").strip()
    start, end = text.find("{"), text.rfind("}") + 1
    if start != -1 and end > 0:
        text = text[start:end]
    try:
        parsed = json.loads(text)
    except (ValueError, RecursionError):
        return {}
    return parsed if isinstance(parsed, dict) else {}
153
-
 
 
 
 
 
 
154
 
155
  # ---------------------------------------------------------------------------
156
- # Prompt builders
157
  # ---------------------------------------------------------------------------
158
def _build_label_prompt(keyphrases: list, rep_abstracts: list) -> str:
    """Build the labelling prompt sent to every council member.

    Args:
        keyphrases: strings or sequences whose first element is the phrase;
            only the first five are used.
        rep_abstracts: representative abstracts; the first three string
            entries are used, each cut to 300 characters.

    Returns:
        The prompt text asking for a JSON object with ``label``,
        ``description``, ``pacis_match`` and ``confidence``.
    """
    kp_str = ", ".join(k if isinstance(k, str) else k[0]
                       for k in keyphrases[:5])
    # Non-string entries (e.g. None) cannot be sliced, so they are skipped
    # instead of raising.
    abstracts = [a for a in rep_abstracts if isinstance(a, str)][:3]
    abs_str = " | ".join(a[:300] for a in abstracts)
    return f"""You are a research-topic classifier.
A SPECTER-2 + HDBSCAN pipeline produced a topic cluster.

KEYPHRASES: {kp_str}
REPRESENTATIVE ABSTRACTS (truncated): {abs_str}

Return ONLY valid JSON (no markdown, no other text):
{{
"label": "<concise 5-8 word topic label>",
"description": "<one-sentence description of the topic>",
"pacis_match": "<closest PAJAIS 2019 category, or NOVEL if none>",
"confidence": <0.0-1.0 float>
}}"""
175
-
176
-
177
def _build_defence_prompt(
    keyphrases: list,
    rep_abstracts: list,
    votes: list[dict],
) -> str:
    """Build the fourth-round adjudication prompt for a disagreeing council.

    Args:
        keyphrases: strings or sequences whose first element is the phrase;
            only the first five are used.
        rep_abstracts: representative abstracts; the first three string
            entries are used, each cut to 300 characters.
        votes: the council responses; a missing field is shown as ``?``.

    Returns:
        The prompt text asking for a JSON object with ``label``,
        ``description``, ``pacis_match``, ``confidence`` and ``reasoning``.
    """
    kp_str = ", ".join(k if isinstance(k, str) else k[0]
                       for k in keyphrases[:5])
    # Skip non-string abstracts rather than failing on the slice.
    abstracts = [a for a in rep_abstracts if isinstance(a, str)][:3]
    abs_str = " | ".join(a[:300] for a in abstracts)
    vote_lines = []
    for i, v in enumerate(votes):
        vote = v if isinstance(v, dict) else {}
        vote_lines.append(' LLM {}: label="{}", pacis="{}"'.format(
            i + 1, vote.get("label", "?"), vote.get("pacis_match", "?")))
    v_str = "\n".join(vote_lines)
    return f"""You are a research-topic adjudicator resolving a labelling disagreement.

KEYPHRASES: {kp_str}
REPRESENTATIVE ABSTRACTS: {abs_str}

Three LLMs proposed different labels:
{v_str}

Your task: pick the single best label from the three, or synthesise a
better one. Justify your choice in one sentence.

Return ONLY valid JSON:
{{
"label": "<best 5-8 word label>",
"description": "<one sentence>",
"pacis_match": "<PAJAIS category or NOVEL>",
"confidence": <0.0-1.0>,
"reasoning": "<one sentence justification>"
}}"""
209
 
210
-
211
  # ---------------------------------------------------------------------------
212
- # Grounding check reject labels not supported by keyphrases (§3.5)
213
  # ---------------------------------------------------------------------------
214
def grounding_check(label: str, keyphrases: list) -> dict:
    """Non-LLM regex check: label tokens must overlap keyphrases.

    Tokens are lowercase alphabetic words of three or more letters with a
    small stop-word list removed. A label passes when it shares a token with
    the keyphrases or when a label token of four or more letters shares a
    prefix with a keyphrase token.

    Returns:
        Dict with ``verdict`` ("PASS"/"FAIL"), ``score`` (capped at 1.0),
        ``matched`` (shared tokens) and ``stems`` (up to five prefix
        matches). All four keys are always present, and both lists are
        sorted so the result does not depend on set iteration order.
    """
    if not label or not keyphrases:
        return {"verdict": "FAIL", "score": 0, "matched": [], "stems": []}

    token_re = r"\b[a-z]{3,}\b"
    noise = {"the", "and", "for", "with", "using", "based", "from", "that",
             "are", "this", "into", "its"}

    label_toks = set(re.findall(token_re, label.lower())) - noise
    kp_toks = set()
    for kp in keyphrases:
        phrase = kp if isinstance(kp, str) else kp[0]
        kp_toks.update(re.findall(token_re, phrase.lower()))
    kp_toks -= noise

    matched = sorted(label_toks & kp_toks)
    # Stem-level: compare on the first four characters.
    stems = [
        f"{lt}≈{kt}"
        for lt in sorted(label_toks)
        for kt in sorted(kp_toks)
        if len(lt) >= 4 and (kt.startswith(lt[:4]) or lt.startswith(kt[:4]))
    ]
    score = min(1.0, len(matched) / max(len(label_toks), 1)
                + 0.15 * len(stems))
    verdict = "PASS" if (matched or stems) else "FAIL"
    return {"verdict": verdict, "score": round(score, 3),
            "matched": matched, "stems": stems[:5]}
239
-
240
 
241
  # ---------------------------------------------------------------------------
242
- # Core — interpret one cluster via 3-LLM council (§3.5)
243
  # ---------------------------------------------------------------------------
244
def interpret_cluster(
    cluster_id: int,
    keyphrases: list,
    rep_docs: list,
    strong: int,
    weak: int,
    groq_client,
    mistral_key: str,
    gemini_key: str,
) -> ClusterInterpretation:
    """Label one cluster with the three-LLM council and consolidate the votes.

    The same prompt goes to Groq, Mistral and Gemini, with a one-second
    pause between calls. Votes are compared on their cleaned, lower-cased
    labels: three equal labels give "Triple", two give "Two", otherwise the
    agreement is "Single" and a defence prompt is sent to Groq. A label that
    fails the grounding check is replaced by the vote whose label shares the
    most tokens with the keyphrases.

    Returns:
        A ``ClusterInterpretation`` holding the final label, the three raw
        responses, the defence response (if any) and the grounding result.
    """
    prompt = _build_label_prompt(keyphrases, rep_docs)

    # Sheet 1 — Groq / LLaMA-3.1
    s1 = _call_groq(groq_client, prompt)
    time.sleep(1)
    # Sheet 2 — Mistral
    s2 = _call_mistral(prompt, mistral_key)
    time.sleep(1)
    # Sheet 3 — Gemini
    s3 = _call_gemini(prompt, gemini_key)

    votes = [s1, s2, s3]
    valid = [v for v in votes if v and "label" in v]

    # --- Sheet 4: consolidate agreement ---
    counts: dict = {}
    for v in valid:
        key = _clean(v.get("label", "")).lower()
        counts[key] = counts.get(key, 0) + 1
    top = max(counts.values(), default=0)

    best_label = ""
    defence: dict = {}
    if top >= 2:
        agreement = "Triple" if top >= 3 else "Two"
        winner = max(counts, key=counts.get)
        best_label = next(v["label"] for v in valid
                          if _clean(v["label"]).lower() == winner)
    else:
        agreement = "Single"
        # Fourth-round defence prompt (§3.5)
        defence_prompt = _build_defence_prompt(keyphrases, rep_docs, votes)
        defence = _call_groq(groq_client, defence_prompt)
        if defence and "label" in defence:
            best_label = defence["label"]
        elif valid:
            best_label = valid[0]["label"]

    best_label = _clean(best_label)

    # Grounding check — reject if not supported by keyphrases
    gc = grounding_check(best_label, keyphrases)
    if gc["verdict"] == "FAIL" and valid:
        token_re = r"\b[a-z]{3,}\b"
        kp_tokens = set(re.findall(
            token_re,
            " ".join(k if isinstance(k, str) else k[0]
                     for k in keyphrases).lower()))

        def _overlap(vote: dict) -> int:
            label_tokens = set(re.findall(
                token_re, str(vote.get("label", "")).lower()))
            return len(label_tokens & kp_tokens)

        # max() keeps the earliest vote on ties, like the stable sort it replaces.
        fallback = max(valid, key=_overlap)
        best_label = _clean(fallback["label"])
        gc = grounding_check(best_label, keyphrases)
        logger.info("Cluster %d: label rejected by grounding, "
                    "fell back to '%s'", cluster_id, best_label)

    # Best metadata: prefer the council vote carrying the final label, then
    # the defence response (so a defence-chosen label keeps its own
    # description), then the first valid vote.
    candidates = list(valid)
    if defence and "label" in defence:
        candidates.append(defence)
    best_v = next((v for v in candidates
                   if _clean(v.get("label", "")).lower()
                   == best_label.lower()),
                  valid[0] if valid else {})

    return ClusterInterpretation(
        cluster_id=cluster_id,
        final_label=best_label,
        final_description=best_v.get("description", ""),
        final_pacis_match=best_v.get("pacis_match", ""),
        final_confidence=best_v.get("confidence", 0.0),
        agreement=agreement,
        sheet1=s1, sheet2=s2, sheet3=s3,
        defence=defence,
        keyphrases=[k if isinstance(k, str) else k[0]
                    for k in keyphrases[:5]],
        strong_count=strong,
        weak_count=weak,
        paper_count=strong + weak,
        grounding_check=gc,
    )
338
-
339
-
340
def _clean(s: str) -> str:
    """Normalise a label: collapse whitespace, cap at 60 characters, trim.

    A label longer than 60 characters is cut back to the last complete word
    within the first 60 characters. When character 61 is a space the cut
    already falls on a word boundary, so no word is dropped. Trailing spaces
    and full stops are removed. ``None`` and other falsy inputs give "".
    """
    text = " ".join(str(s or "").split())
    if len(text) > 60:
        head = text[:60]
        if text[60] != " " and " " in head:
            head = head.rsplit(" ", 1)[0]
        text = head
    return text.rstrip(" .")
346
 
 
 
 
 
 
 
 
 
 
 
347
 
348
  # ---------------------------------------------------------------------------
349
- # Numpy-safe serialisation
350
  # ---------------------------------------------------------------------------
351
def _convert(obj):
    """Recursively replace NumPy values with built-in Python equivalents.

    Handles dicts (keys and values), lists, tuples, ``np.ndarray`` (via
    ``tolist``), and NumPy bool, integer and floating scalars. Anything else
    is returned unchanged.
    """
    if isinstance(obj, dict):
        # Keys are converted too: json.dump rejects NumPy integer keys.
        return {_convert(k): _convert(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_convert(v) for v in obj]
    if isinstance(obj, tuple):
        return tuple(_convert(v) for v in obj)
    if isinstance(obj, np.ndarray):
        return _convert(obj.tolist())
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    return obj
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
 
 
 
 
 
 
 
 
 
 
 
362
 
363
  # ---------------------------------------------------------------------------
364
- # Run agent — orchestrate all clusters
365
  # ---------------------------------------------------------------------------
366
def run_agent(
    topic_results: dict,
    groq_key: str,
    mistral_key: str,
    gemini_key: str,
    output_json: str = "topics.json",
    output_csv: str = "topics.csv",
) -> dict:
    """Interpret every cluster in ``topic_results`` and write JSON and CSV files.

    ``topic_results`` must provide the keys ``labels``, ``keyphrases``,
    ``representative_docs`` and ``membership``. Clusters are processed in
    sorted order of the ``keyphrases`` keys.

    Returns:
        Dict with ``interpretations`` (cluster id -> ClusterInterpretation),
        ``json_path`` and ``csv_path``.
    """
    groq = build_groq_client(groq_key)

    labels_list = topic_results["labels"]
    phrases_by_cluster = topic_results["keyphrases"]
    docs_by_cluster = topic_results["representative_docs"]
    membership = topic_results["membership"]

    interpretations = {}
    for cid in sorted(phrases_by_cluster.keys()):
        sizes = membership.get(cid, {"strong": 0, "weak": 0})
        result = interpret_cluster(
            cluster_id=cid,
            keyphrases=phrases_by_cluster.get(cid, []),
            rep_docs=docs_by_cluster.get(cid, []),
            strong=sizes["strong"],
            weak=sizes["weak"],
            groq_client=groq,
            mistral_key=mistral_key,
            gemini_key=gemini_key,
        )
        interpretations[cid] = result
        logger.info("Cluster %d → %s [%s] (%d strong, %d weak)",
                    cid, result.final_label, result.agreement,
                    result.strong_count, result.weak_count)

    # Serialise
    records = []
    for item in interpretations.values():
        records.append(_convert(asdict(item)))
    with open(output_json, "w") as f:
        json.dump(records, f, indent=2)

    df = pd.DataFrame(records)
    if not df.empty:
        # Nested values are stored in the CSV as their str() form.
        nested_cols = ["sheet1", "sheet2", "sheet3", "defence",
                       "keyphrases", "grounding_check"]
        for col in nested_cols:
            if col in df.columns:
                df[col] = df[col].apply(str)
    df.to_csv(output_csv, index=False)

    return dict(interpretations=interpretations,
                json_path=output_json, csv_path=output_csv)
 
1
  """
2
+ agent.py — LangGraph-based topic analysis agent (§11).
3
+ 3-LLM Council for topic modelling, 4 sheets, triple-agreement tracking.
 
 
 
 
 
 
4
  """
 
5
  from __future__ import annotations
6
+ import json, logging, os, re, time
 
 
 
 
7
  from dataclasses import dataclass, field, asdict
8
+ from typing import TypedDict, Optional
9
+ from collections import Counter
10
+ import pandas as pd, numpy as np, requests
 
11
  from groq import Groq
12
+ from langgraph.graph import StateGraph, END
13
 
 
 
 
14
  logging.basicConfig(level=logging.INFO, format="%(levelname)s | %(message)s")
15
  logger = logging.getLogger(__name__)
16
 
17
# Model identifiers passed as the "model" field of the Groq and Mistral
# chat-completion requests.
GROQ_MODEL = "llama-3.1-8b-instant"
MISTRAL_MODEL = "mistral-small-latest"
19
 
 
 
 
 
 
 
 
 
 
 
20
  # ---------------------------------------------------------------------------
21
+ # LangGraph state
22
  # ---------------------------------------------------------------------------
23
class PipelineState(TypedDict, total=False):
    """State dict passed between the LangGraph nodes; every key is optional."""

    # Inputs supplied when the graph is invoked.
    filepath: str
    groq_key: str
    mistral_key: str
    gemini_key: str
    n_trials: int
    # Outputs written by the nodes.
    topic_data: dict
    interpretations: dict
    sheets: dict  # {1: [...], 2: [...], 3: [...], 4: [...]}
    agreement_rates: dict
    mismatch_table: list
    json_path: str
    csv_path: str
    error: str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  # ---------------------------------------------------------------------------
39
+ # API helpers
40
  # ---------------------------------------------------------------------------
41
def _parse(raw: str) -> dict:
    """Extract a JSON object from an LLM reply.

    Markdown code fences are removed and the text between the first ``{``
    and the last ``}`` is decoded.

    Returns:
        The decoded object, or ``{}`` when ``raw`` is not a string, cannot be
        decoded, or decodes to something other than a dict (callers use
        ``.get`` on the result).
    """
    if not isinstance(raw, str):
        return {}
    text = raw.strip().replace("```json", "").replace("```", "").strip()
    start, end = text.find("{"), text.rfind("}") + 1
    if start != -1 and end > 0:
        text = text[start:end]
    try:
        parsed = json.loads(text)
    except (ValueError, RecursionError):
        # Narrow catch instead of a bare except, which would also swallow
        # KeyboardInterrupt and SystemExit.
        return {}
    return parsed if isinstance(parsed, dict) else {}
47
+
48
def _groq(client, prompt):
    """Send ``prompt`` to the Groq chat endpoint and return the parsed JSON reply.

    Any exception is logged as a warning and reported as an empty dict.
    """
    try:
        request = {
            "model": GROQ_MODEL,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.2,
            "timeout": 15,
        }
        response = client.chat.completions.create(**request)
        reply_text = response.choices[0].message.content
        return _parse(reply_text)
    except Exception as exc:
        logger.warning("Groq: %s", exc)
        return {}
 
 
 
 
 
 
 
 
 
54
 
55
def _mistral(prompt, key):
    """POST ``prompt`` to the Mistral chat-completions API and parse the reply.

    Returns an empty dict when ``key`` is falsy or when the request or
    parsing fails (the failure is logged as a warning).
    """
    if not key:
        return {}
    try:
        auth_headers = {
            "Authorization": f"Bearer {key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": MISTRAL_MODEL,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.2,
        }
        response = requests.post(
            "https://api.mistral.ai/v1/chat/completions",
            headers=auth_headers,
            json=payload,
            timeout=15,
        )
        reply_text = response.json()["choices"][0]["message"]["content"]
        return _parse(reply_text)
    except Exception as exc:
        logger.warning("Mistral: %s", exc)
        return {}
64
+
65
def _gemini(prompt, key):
    """POST ``prompt`` to Gemini ``generateContent`` and parse the reply.

    Returns an empty dict when ``key`` is falsy, when the response has no
    ``candidates`` field, or when the request or parsing fails.
    """
    if not key:
        return {}
    url = ("https://generativelanguage.googleapis.com/v1beta/models/"
           "gemini-2.5-flash:generateContent")
    # The key travels in a header rather than the query string: request
    # exceptions include the URL in their message, and that message is logged
    # below, so a key in the URL would end up in the logs.
    headers = {"Content-Type": "application/json", "x-goog-api-key": key}
    try:
        r = requests.post(
            url,
            headers=headers,
            json={"contents": [{"parts": [{"text": prompt}]}],
                  "generationConfig": {"temperature": 0.2}},
            timeout=15,
        )
        d = r.json()
        if "candidates" not in d:
            logger.warning("Gemini: no candidates in response (HTTP %s)",
                           r.status_code)
            return {}
        return _parse(d["candidates"][0]["content"]["parts"][0]["text"])
    except Exception as e:
        logger.warning("Gemini: %s", e)
        return {}
78
 
79
  # ---------------------------------------------------------------------------
80
+ # Topic labelling prompt
81
  # ---------------------------------------------------------------------------
82
def _label_prompt(keyphrases, rep_docs):
    """Build the labelling prompt sent to every council member.

    Args:
        keyphrases: strings or sequences (tuple or list) whose first element
            is the phrase; only the first five are used.
        rep_docs: representative abstracts; the first three are used, each
            cut to 250 characters.

    Returns:
        The prompt text asking for a JSON object with ``label``,
        ``description``, ``pacis_match`` and ``confidence``.
    """
    # Test for str rather than tuple so list-shaped [phrase, score] entries
    # are unpacked too instead of breaking the join.
    kp = ", ".join(k if isinstance(k, str) else k[0] for k in keyphrases[:5])
    ab = " | ".join(a[:250] for a in rep_docs[:3])
    return f"""You are a research topic classifier.
A SPECTER-2 + HDBSCAN pipeline produced a topic cluster.

KEYPHRASES: {kp}
REPRESENTATIVE ABSTRACTS: {ab}

Return ONLY valid JSON:
{{
"label": "<5-8 word topic label>",
"description": "<one sentence description>",
"pacis_match": "<closest PAJAIS 2019 category, or NOVEL if none>",
"confidence": <0.0-1.0>
}}"""
98
 
 
99
  # ---------------------------------------------------------------------------
100
+ # Defence prompt for disagreements
101
  # ---------------------------------------------------------------------------
102
def _defence_prompt(keyphrases, rep_docs, votes):
    """Build the adjudication prompt used when the council disagrees.

    Args:
        keyphrases: strings or sequences (tuple or list) whose first element
            is the phrase; only the first five are used.
        rep_docs: accepted for signature compatibility; not used in the
            prompt text.
        votes: the council responses; a missing label is shown as ``?``.

    Returns:
        The prompt text asking for a JSON object with ``label``,
        ``description``, ``pacis_match`` and ``confidence``.
    """
    # Test for str rather than tuple so list-shaped entries are unpacked too.
    kp = ", ".join(k if isinstance(k, str) else k[0] for k in keyphrases[:5])
    v_str = "\n".join(
        f" LLM{i+1}: {v.get('label','?')}" for i, v in enumerate(votes))
    return f"""Resolve this labelling disagreement.
KEYPHRASES: {kp}
Votes:\n{v_str}
Pick the best label or synthesise a better one.
Return ONLY JSON: {{"label":"...","description":"...","pacis_match":"...","confidence":0.0}}"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  # ---------------------------------------------------------------------------
112
+ # Grounding check
113
  # ---------------------------------------------------------------------------
114
def _grounding(label, keyphrases):
    """Check that a label shares at least one token with the keyphrases.

    Tokens are lowercase alphabetic words of three or more letters with a
    small stop-word list removed.

    Returns:
        Dict with ``verdict`` ("PASS"/"FAIL"), ``score`` (fraction of label
        tokens found in the keyphrases) and ``matched`` (sorted shared
        tokens). All three keys are always present.
    """
    if not label or not keyphrases:
        return {"verdict": "FAIL", "score": 0, "matched": []}
    token_re = r"\b[a-z]{3,}\b"
    noise = {"the", "and", "for", "with", "using", "based", "from", "that",
             "are", "this"}
    lt = set(re.findall(token_re, label.lower())) - noise
    kt = set()
    for k in keyphrases:
        # Test for str rather than tuple so list-shaped entries work too.
        phrase = k if isinstance(k, str) else k[0]
        kt.update(re.findall(token_re, phrase.lower()))
    kt -= noise
    # Sorted so the result does not depend on set iteration order.
    m = sorted(lt & kt)
    return {"verdict": "PASS" if m else "FAIL",
            "score": len(m) / max(len(lt), 1),
            "matched": m}
124
+
125
def _clean(s):
    """Normalise a label: collapse whitespace, cap at 60 characters, trim.

    Labels are compared for agreement after cleaning, so runs of whitespace
    are collapsed and trailing spaces and full stops are removed; otherwise
    near-identical labels would count as disagreement. A label longer than
    60 characters is cut back to the last complete word within the first 60
    characters. ``None`` and other falsy inputs give "".
    """
    text = " ".join(str(s or "").split())
    if len(text) > 60:
        head = text[:60]
        # When character 61 is a space the cut is already on a word boundary.
        if text[60] != " " and " " in head:
            head = head.rsplit(" ", 1)[0]
        text = head
    return text.rstrip(" .")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
+ # ---------------------------------------------------------------------------
130
+ # LangGraph node: run topic modelling
131
+ # ---------------------------------------------------------------------------
132
def embed_and_cluster(state: PipelineState) -> dict:
    """Graph node: run topic modelling on ``state["filepath"]``.

    Calls ``tools.run_topic_modeling`` with the file path and ``n_trials``
    (default 50).

    Returns:
        ``{"topic_data": ...}`` on success, or ``{"error": <message>}`` when
        the call raises.
    """
    from tools import run_topic_modeling

    try:
        topic_data = run_topic_modeling(
            state["filepath"], state.get("n_trials", 50))
    except Exception as exc:
        return {"error": str(exc)}
    return {"topic_data": topic_data}
139
 
140
  # ---------------------------------------------------------------------------
141
+ # LangGraph node: LLM Council — 4 sheets for topic modelling
142
  # ---------------------------------------------------------------------------
143
def llm_council(state: PipelineState) -> dict:
    """Graph node: label every cluster with the three-LLM council.

    For each cluster the same prompt goes to Groq, Mistral and Gemini, with
    a one-second pause between calls. Sheets 1-3 record the raw answers and
    sheet 4 the consolidated label with its agreement tag ("Triple", "Two"
    or "Single"; "Single" triggers a defence prompt sent to Groq). A label
    that fails the grounding check is replaced by the best-grounded vote.
    Sheet 4 is written to ``topics.json`` and ``topics.csv``.

    Returns:
        ``interpretations``, ``sheets``, ``agreement_rates``, ``json_path``
        and ``csv_path``; or ``{"error": ...}`` when no topic data is
        available, preserving an error already present in the state.
    """
    td = state.get("topic_data")
    if not td:
        # An upstream failure leaves topic_data unset; keep its message.
        return {"error": state.get("error") or "No topic data"}

    def _native(value):
        # NumPy scalars are not JSON-serialisable; unwrap via .item().
        return value.item() if hasattr(value, "item") else value

    client = Groq(api_key=state["groq_key"], max_retries=0)
    mk, gk = state["mistral_key"], state["gemini_key"]
    fields = ("label", "description", "pacis_match", "confidence")

    sheets = {1: [], 2: [], 3: [], 4: []}  # 1=Groq, 2=Mistral, 3=Gemini, 4=Consolidated
    interps = {}

    for cid in sorted(td["keyphrases"].keys()):
        kps = td["keyphrases"][cid]
        rds = td["representative_docs"].get(cid, [])
        sw = td["membership"].get(cid, {"strong": 0, "weak": 0})
        strong, weak = _native(sw["strong"]), _native(sw["weak"])
        cluster = _native(cid)
        prompt = _label_prompt(kps, rds)

        s1 = _groq(client, prompt)
        time.sleep(1)
        s2 = _mistral(prompt, mk)
        time.sleep(1)
        s3 = _gemini(prompt, gk)
        votes = [s1, s2, s3]

        # Sheets 1-3
        for sheet_n, resp in zip((1, 2, 3), votes):
            sheets[sheet_n].append(
                {"cluster": cluster, **{k: resp.get(k, "—") for k in fields}})

        # Sheet 4: consolidate
        valid = [v for v in votes if v and "label" in v]
        counts = Counter(_clean(v.get("label", "")).lower() for v in valid)
        top = max(counts.values(), default=0)

        if top >= 2:
            agreement = "Triple" if top >= 3 else "Two"
            winner = max(counts, key=counts.get)
            best = next(v for v in valid
                        if _clean(v["label"]).lower() == winner)
        else:
            agreement = "Single"
            d = _groq(client, _defence_prompt(kps, rds, votes))
            best = d if d and "label" in d else (valid[0] if valid else {})

        label = _clean(best.get("label", ""))
        gc = _grounding(label, kps)
        if gc["verdict"] == "FAIL" and valid:
            # Fall back to the best-grounded vote (the first valid vote when
            # none is grounded) and re-run the check so the reported verdict
            # and metadata describe the label actually kept.
            best = max(
                valid,
                key=lambda v: _grounding(_clean(v.get("label", "")), kps)["score"])
            label = _clean(best.get("label", ""))
            gc = _grounding(label, kps)

        cp = float(td.get("cluster_persistence", {}).get(cid, 0.0))
        sheets[4].append({"cluster": cluster, "label": label,
                          "agreement": agreement,
                          "description": best.get("description", ""),
                          "pacis_match": best.get("pacis_match", ""),
                          "strong": strong, "weak": weak,
                          "persistence": round(cp, 4),
                          "grounding": gc["verdict"]})

        interps[cid] = {"label": label, "agreement": agreement,
                        "strong": strong, "weak": weak,
                        "persistence": cp,
                        "description": best.get("description", ""),
                        "pacis_match": best.get("pacis_match", ""),
                        "keyphrases": [k if isinstance(k, str) else k[0]
                                       for k in kps[:5]]}

        logger.info("Cluster %d → %s [%s]", cid, label, agreement)

    # Agreement rate on labels
    total = len(sheets[4]) or 1
    n_triple = sum(1 for r in sheets[4] if r.get("agreement") == "Triple")
    n_two = sum(1 for r in sheets[4] if r.get("agreement") == "Two")
    rates = {
        "triple": round(n_triple / total * 100),
        "two_or_more": round((n_triple + n_two) / total * 100),
        "single": round((total - n_triple - n_two) / total * 100),
    }

    # Save outputs
    records = sheets[4]
    with open("topics.json", "w", encoding="utf-8") as f:
        json.dump(records, f, indent=2)
    pd.DataFrame(records).to_csv("topics.csv", index=False)

    return {"interpretations": interps, "sheets": sheets,
            "agreement_rates": rates,
            "json_path": "topics.json", "csv_path": "topics.csv"}
223
 
224
+ # ---------------------------------------------------------------------------
225
+ # LangGraph node: build mismatch table
226
+ # ---------------------------------------------------------------------------
227
def build_mismatch(state: PipelineState) -> dict:
    """Graph node: build the mismatch table from keyphrases and final labels.

    Passes ``topic_data["keyphrases"]`` and a cluster-id -> label map taken
    from ``interpretations`` to ``tools.build_mismatch_table``.

    Returns:
        ``{"mismatch_table": ...}``; the table is an empty list when the
        state holds no topic data (for example after an upstream error).
    """
    td = state.get("topic_data")
    if not td:
        # Nothing to compare; leave any recorded "error" in the state as is.
        return {"mismatch_table": []}

    from tools import build_mismatch_table

    interps = state.get("interpretations") or {}
    labels_map = {cid: v["label"] for cid, v in interps.items()}
    mt = build_mismatch_table(td["keyphrases"], labels_map)
    return {"mismatch_table": mt}
234
 
235
  # ---------------------------------------------------------------------------
236
+ # Build the LangGraph
237
  # ---------------------------------------------------------------------------
238
def build_graph() -> StateGraph:
    """Assemble and compile the three-node pipeline graph.

    Nodes run in a fixed linear order: ``embed_and_cluster`` ->
    ``llm_council`` -> ``build_mismatch`` -> END. The value returned is the
    result of ``compile()``.
    """
    stages = (
        ("embed_and_cluster", embed_and_cluster),
        ("llm_council", llm_council),
        ("build_mismatch", build_mismatch),
    )
    graph = StateGraph(PipelineState)
    for node_name, node_fn in stages:
        graph.add_node(node_name, node_fn)
    graph.set_entry_point(stages[0][0])

    # Chain each stage to the next; the last stage leads to END.
    order = [node_name for node_name, _ in stages] + [END]
    for src, dst in zip(order, order[1:]):
        graph.add_edge(src, dst)
    return graph.compile()


# Compiled graph — importable
pipeline_graph = build_graph()
251
+
252
def run_pipeline(filepath, groq_key, mistral_key, gemini_key, n_trials=50):
    """Convenience wrapper: invoke the compiled graph and return its final state."""
    initial_state = dict(
        filepath=filepath,
        groq_key=groq_key,
        mistral_key=mistral_key,
        gemini_key=gemini_key,
        n_trials=n_trials,
    )
    return pipeline_graph.invoke(initial_state)