""" agent.py — LangGraph-based topic analysis agent (§11). 3-LLM Council for topic modelling, 4 sheets, triple-agreement tracking. """ from __future__ import annotations import json, logging, os, re, time from dataclasses import dataclass, field, asdict from typing import TypedDict, Optional from collections import Counter import pandas as pd, numpy as np, requests from groq import Groq from langgraph.graph import StateGraph, END logging.basicConfig(level=logging.INFO, format="%(levelname)s | %(message)s") logger = logging.getLogger(__name__) GROQ_MODEL = "llama-3.1-8b-instant" MISTRAL_MODEL = "mistral-small-latest" # --------------------------------------------------------------------------- # LangGraph state # --------------------------------------------------------------------------- class PipelineState(TypedDict, total=False): filepath: str groq_key: str mistral_key: str gemini_key: str n_trials: int topic_data: dict interpretations: dict sheets: dict # {1: [...], 2: [...], 3: [...], 4: [...]} agreement_rates: dict mismatch_table: list json_path: str csv_path: str error: str # --------------------------------------------------------------------------- # API helpers # --------------------------------------------------------------------------- def _parse(raw: str) -> dict: raw = raw.strip().replace("```json","").replace("```","").strip() s, e = raw.find("{"), raw.rfind("}")+1 if s != -1 and e > 0: raw = raw[s:e] try: return json.loads(raw) except: return {} def _groq(client, prompt): try: r = client.chat.completions.create(model=GROQ_MODEL, messages=[{"role":"user","content":prompt}], temperature=0.2, timeout=15) return _parse(r.choices[0].message.content) except Exception as e: logger.warning("Groq: %s",e); return {} def _mistral(prompt, key): if not key: return {} try: r = requests.post("https://api.mistral.ai/v1/chat/completions", headers={"Authorization":f"Bearer {key}","Content-Type":"application/json"}, 
json={"model":MISTRAL_MODEL,"messages":[{"role":"user","content":prompt}], "temperature":0.2}, timeout=15) return _parse(r.json()["choices"][0]["message"]["content"]) except Exception as e: logger.warning("Mistral: %s",e); return {} def _gemini(prompt, key): if not key: return {} model = "gemini-2.5-flash" for attempt in range(3): try: r = requests.post( f"https://generativelanguage.googleapis.com/v1beta/models/" f"{model}:generateContent?key={key}", headers={"Content-Type":"application/json"}, json={"contents":[{"parts":[{"text":prompt}]}], "generationConfig":{"temperature":0.2}}, timeout=60) d = r.json() if "candidates" not in d: err = d.get("error",{}) msg = err.get("message","") if isinstance(err,dict) else str(err) if "quota" in msg.lower() or "rate" in msg.lower(): wait = min(40, 10 * (attempt + 1)) logger.warning("Gemini rate-limited, waiting %ds…", wait) time.sleep(wait); continue logger.warning("Gemini attempt %d: %s", attempt+1, msg) return {} return _parse(d["candidates"][0]["content"]["parts"][0]["text"]) except Exception as e: logger.warning("Gemini attempt %d: %s", attempt+1, e) time.sleep(5) return {} # --------------------------------------------------------------------------- # Topic labelling prompt # --------------------------------------------------------------------------- def _label_prompt(keyphrases, rep_docs): kp = ", ".join(k[0] if isinstance(k,tuple) else k for k in keyphrases[:5]) ab = " | ".join(a[:250] for a in rep_docs[:3]) return f"""You are a research topic classifier. A SPECTER-2 + HDBSCAN pipeline produced a topic cluster. 
KEYPHRASES: {kp} REPRESENTATIVE ABSTRACTS: {ab} Return ONLY valid JSON: {{ "label": "<5-8 word topic label>", "description": "", "pacis_match": "", "confidence": <0.0-1.0> }}""" # --------------------------------------------------------------------------- # Defence prompt for disagreements # --------------------------------------------------------------------------- def _defence_prompt(keyphrases, rep_docs, votes): kp = ", ".join(k[0] if isinstance(k,tuple) else k for k in keyphrases[:5]) v_str = "\n".join(f" LLM{i+1}: {v.get('label','?')}" for i,v in enumerate(votes)) return f"""Resolve this labelling disagreement. KEYPHRASES: {kp} Votes:\n{v_str} Pick the best label or synthesise a better one. Return ONLY JSON: {{"label":"...","description":"...","pacis_match":"...","confidence":0.0}}""" # --------------------------------------------------------------------------- # Grounding check # --------------------------------------------------------------------------- def _grounding(label, keyphrases): if not label or not keyphrases: return {"verdict":"FAIL","score":0} lt = set(re.findall(r"\b[a-z]{3,}\b", label.lower())) kt = set() for k in keyphrases: kt.update(re.findall(r"\b[a-z]{3,}\b", (k[0] if isinstance(k,tuple) else k).lower())) noise = {"the","and","for","with","using","based","from","that","are","this"} lt -= noise; kt -= noise m = list(lt & kt) return {"verdict":"PASS" if m else "FAIL", "score":len(m)/max(len(lt),1), "matched":m} def _clean(s): s = str(s or "").replace("\n"," ").strip() return s[:60].rsplit(" ",1)[0] if len(s)>60 else s # --------------------------------------------------------------------------- # LangGraph node: run topic modelling # --------------------------------------------------------------------------- def embed_and_cluster(state: PipelineState) -> dict: from tools import run_topic_modeling try: td = run_topic_modeling(state["filepath"], state.get("n_trials", 50)) return {"topic_data": td} except Exception as e: return {"error": str(e)} 
# ---------------------------------------------------------------------------
# LangGraph node: LLM Council — 4 sheets for topic modelling
# ---------------------------------------------------------------------------
_VOTE_FIELDS = ("label", "description", "pacis_match", "confidence")
_SHEET_NAMES = {
    1: "sheet1_groq",
    2: "sheet2_mistral",
    3: "sheet3_gemini",
    4: "sheet4_consolidated",
}


def _json_default(obj):
    """``json.dump`` fallback that unwraps NumPy scalars into plain Python numbers.

    Cluster ids, membership counts and persistence values may arrive as NumPy
    scalars from the topic-modelling step (TODO confirm). Anything else is
    rejected exactly as ``json`` itself would.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _consolidate(client, kps, rds, votes):
    """Merge the council votes into ``(agreement, best_vote, valid_votes)``.

    A vote is valid only if its cleaned label is non-empty. Labels are compared
    after ``_clean`` + lower-casing: 3 equal labels -> "Triple", 2 -> "Two".
    Otherwise a Groq "defence" call arbitrates ("Single"), falling back to the
    first valid vote (or ``{}``) if the arbiter gives no usable label.
    """
    valid = [v for v in votes if v and _clean(v.get("label"))]
    keys = [_clean(v["label"]).lower() for v in valid]
    top, top_n = Counter(keys).most_common(1)[0] if keys else ("", 0)
    if top_n >= 2:
        agreement = "Triple" if top_n >= 3 else "Two"
        return agreement, valid[keys.index(top)], valid
    d = _groq(client, _defence_prompt(kps, rds, votes))
    best = d if d and _clean(d.get("label")) else (valid[0] if valid else {})
    return "Single", best, valid


def _ground(best, valid, kps):
    """Return ``(best, label, grounding)``, preferring a keyphrase-grounded label.

    If the consolidated label fails ``_grounding``, switch to the first valid
    vote whose label passes (or ``valid[0]`` if none does). ``best`` is switched
    too, so the description always belongs to the emitted label, and the
    verdict is recomputed for that label.
    """
    label = _clean(best.get("label", ""))
    gc = _grounding(label, kps)
    if gc["verdict"] == "FAIL" and valid:
        best = next(
            (v for v in valid if _grounding(_clean(v["label"]), kps)["verdict"] == "PASS"),
            valid[0],
        )
        label = _clean(best["label"])
        gc = _grounding(label, kps)
    return best, label, gc


def llm_council(state: PipelineState) -> dict:
    """Label every topic cluster with a 3-LLM council (Groq, Mistral, Gemini).

    Sheets 1-3 hold each model's raw vote per cluster; sheet 4 holds the
    consolidated label with its agreement level and grounding verdict.
    Writes ``sheet1_groq.csv`` ... ``sheet4_consolidated.csv`` and
    ``topics.json`` (sheet 4) to the current working directory.

    Returns ``interpretations``, ``sheets``, ``agreement_rates`` (integer
    percentages), ``sheet_paths`` and ``json_path``; or an ``error`` update if
    an earlier node failed or there is no topic data.
    """
    if state.get("error"):
        # An upstream node already failed; keep its message instead of masking it.
        return {"error": state["error"]}
    td = state.get("topic_data")
    if not td:
        return {"error": "No topic data"}

    client = Groq(api_key=state.get("groq_key"), max_retries=0)
    mk, gk = state.get("mistral_key"), state.get("gemini_key")
    sheets = {1: [], 2: [], 3: [], 4: []}  # 1=Groq, 2=Mistral, 3=Gemini, 4=Consolidated
    interps = {}
    rep_docs = td.get("representative_docs", {})
    membership = td.get("membership", {})
    persistence = td.get("cluster_persistence") or {}

    for cid in sorted(td["keyphrases"]):
        kps = td["keyphrases"][cid]
        rds = rep_docs.get(cid, [])
        sw = membership.get(cid, {})
        strong, weak = sw.get("strong", 0), sw.get("weak", 0)
        prompt = _label_prompt(kps, rds)

        s1 = _groq(client, prompt)
        time.sleep(1)
        s2 = _mistral(prompt, mk)
        if mk:  # only pause when Mistral was actually called
            time.sleep(1)
        s3 = _gemini(prompt, gk)
        if gk:
            time.sleep(4)  # respect Gemini free-tier rate limit
        votes = [s1, s2, s3]

        # Sheets 1-3: raw per-model answers ("—" marks a missing field).
        for sheet_n, resp in zip((1, 2, 3), votes):
            sheets[sheet_n].append(
                {"cluster": cid, **{k: resp.get(k, "—") for k in _VOTE_FIELDS}}
            )

        # Sheet 4: consolidated, grounding-checked label.
        agreement, best, valid = _consolidate(client, kps, rds, votes)
        best, label, gc = _ground(best, valid, kps)
        cp = persistence.get(cid, 0.0)
        description = best.get("description", "")
        pacis_match = best.get("pacis_match", "")

        sheets[4].append({
            "cluster": cid,
            "label": label,
            "agreement": agreement,
            "description": description,
            "pacis_match": pacis_match,
            "strong": strong,
            "weak": weak,
            "persistence": round(cp, 4),
            "grounding": gc["verdict"],
        })
        interps[cid] = {
            "label": label,
            "agreement": agreement,
            "strong": strong,
            "weak": weak,
            "persistence": cp,
            "description": description,
            "pacis_match": pacis_match,
            "keyphrases": [k[0] if isinstance(k, tuple) else k for k in kps[:5]],
        }
        logger.info("Cluster %s → %s [%s]", cid, label, agreement)

    # Agreement rate on labels (integer percentages; "single" is the remainder).
    total = len(sheets[4]) or 1
    n_triple = sum(1 for r in sheets[4] if r.get("agreement") == "Triple")
    n_two = sum(1 for r in sheets[4] if r.get("agreement") == "Two")
    rates = {
        "triple": round(n_triple / total * 100),
        "two_or_more": round((n_triple + n_two) / total * 100),
        "single": round((total - n_triple - n_two) / total * 100),
    }

    # Save outputs — 4 separate sheet files plus the consolidated JSON.
    sheet_paths = {}
    for sn, name in _SHEET_NAMES.items():
        path = f"{name}.csv"
        pd.DataFrame(sheets[sn]).to_csv(path, index=False)
        sheet_paths[sn] = path
    with open("topics.json", "w", encoding="utf-8") as f:
        json.dump(sheets[4], f, indent=2, default=_json_default)

    # NOTE: "sheet_paths" only reaches the final graph result if PipelineState declares it.
    return {
        "interpretations": interps,
        "sheets": sheets,
        "agreement_rates": rates,
        "sheet_paths": sheet_paths,
        "json_path": "topics.json",
    }


# ---------------------------------------------------------------------------
# LangGraph node: build mismatch table
# ---------------------------------------------------------------------------
def build_mismatch(state: PipelineState) -> dict:
    """Compare each cluster's keyphrases with its council label via ``tools.build_mismatch_table``.

    Returns an empty table when an earlier node failed or no topic data exists.
    """
    from tools import build_mismatch_table

    td = state.get("topic_data")
    if state.get("error") or not td:
        return {"mismatch_table": []}
    interps = state.get("interpretations", {})
    labels_map = {cid: v["label"] for cid, v in interps.items()}
    return {"mismatch_table": build_mismatch_table(td["keyphrases"], labels_map)}
# ---------------------------------------------------------------------------
# Build the LangGraph
# ---------------------------------------------------------------------------
def _route_on_error(state: PipelineState) -> str:
    """Conditional-edge router: ``"error"`` once a node has set ``error``, else ``"ok"``."""
    return "error" if state.get("error") else "ok"


def build_graph():
    """Compile the pipeline graph: embed_and_cluster → llm_council → build_mismatch → END.

    ``embed_and_cluster`` and ``llm_council`` report failures through the
    ``error`` state key instead of raising. After each of them the graph routes
    straight to END when ``error`` is set, rather than running the next node on
    missing inputs.

    Returns the compiled graph produced by ``StateGraph.compile()``.
    """
    g = StateGraph(PipelineState)
    g.add_node("embed_and_cluster", embed_and_cluster)
    g.add_node("llm_council", llm_council)
    g.add_node("build_mismatch", build_mismatch)
    g.set_entry_point("embed_and_cluster")
    g.add_conditional_edges(
        "embed_and_cluster", _route_on_error, {"ok": "llm_council", "error": END}
    )
    g.add_conditional_edges(
        "llm_council", _route_on_error, {"ok": "build_mismatch", "error": END}
    )
    g.add_edge("build_mismatch", END)
    return g.compile()


# Compiled graph — importable
pipeline_graph = build_graph()


def run_pipeline(filepath, groq_key, mistral_key, gemini_key, n_trials=50):
    """Run the whole pipeline once on *filepath* and return the final state dict.

    Convenience wrapper around ``pipeline_graph.invoke``. If a step fails, the
    returned state contains an ``error`` key instead of the later outputs.
    """
    return pipeline_graph.invoke({
        "filepath": filepath,
        "groq_key": groq_key,
        "mistral_key": mistral_key,
        "gemini_key": gemini_key,
        "n_trials": n_trials,
    })