| """ |
agent.py — LangGraph-based topic-analysis agent (§11).
A three-LLM council (Groq, Mistral, Gemini) labels each topic cluster; results go
to four sheets (one per model plus a consolidated sheet) with agreement tracking.
| """ |
| from __future__ import annotations |
| import json, logging, os, re, time |
| from dataclasses import dataclass, field, asdict |
| from typing import TypedDict, Optional |
| from collections import Counter |
| import pandas as pd, numpy as np, requests |
| from groq import Groq |
| from langgraph.graph import StateGraph, END |
|
|
# Model identifiers for the Groq- and Mistral-hosted council members
# (the Gemini model name is set inside _gemini).
GROQ_MODEL = "llama-3.1-8b-instant"
MISTRAL_MODEL = "mistral-small-latest"

# Process-wide logging at INFO with "LEVEL | message" lines, plus this module's logger.
logging.basicConfig(format="%(levelname)s | %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
| |
| |
| |
class PipelineState(TypedDict, total=False):
    """Shared LangGraph state for the topic-analysis pipeline.

    ``total=False``: every key is optional, since each node returns only the
    keys it produces. Every key a node returns must be declared here, or
    LangGraph will not carry it into the final state.
    """

    # --- inputs (set by run_pipeline) ---
    filepath: str          # input file handed to tools.run_topic_modeling
    groq_key: str
    mistral_key: str       # empty/None -> the Mistral member is skipped
    gemini_key: str        # empty/None -> the Gemini member is skipped
    n_trials: int          # forwarded to run_topic_modeling (default 50)
    # --- produced by the nodes ---
    topic_data: dict       # run_topic_modeling output; llm_council reads its
                           # keyphrases / representative_docs / membership /
                           # cluster_persistence entries
    interpretations: dict  # cluster id -> consolidated label info
    sheets: dict           # sheet number (1-4) -> list of row dicts
    agreement_rates: dict  # percentages under "triple" / "two_or_more" / "single"
    sheet_paths: dict      # sheet number -> CSV path written by llm_council
    mismatch_table: list   # output of tools.build_mismatch_table
    json_path: str         # path of the consolidated topics.json
    csv_path: str          # not set by any node in this module; kept for callers
    error: str             # set by a node that could not complete
|
|
| |
| |
| |
def _parse(raw: str) -> dict:
    """Extract a JSON object from an LLM reply.

    Strips Markdown code fences. Because models often wrap the JSON in prose,
    it keeps only the span from the first ``{`` to the last ``}`` and then
    decodes that span.

    Returns the decoded object, or ``{}`` when the reply is empty/None, is not
    valid JSON, or decodes to something other than a JSON object. Callers call
    ``.get`` on the result, so it must always be a dict.
    """
    if not raw:
        return {}
    text = str(raw).strip().replace("```json", "").replace("```", "").strip()
    start, end = text.find("{"), text.rfind("}") + 1
    if start != -1 and end > 0:
        text = text[start:end]
    try:
        parsed = json.loads(text)
    except (ValueError, RecursionError):  # JSONDecodeError subclasses ValueError
        return {}
    return parsed if isinstance(parsed, dict) else {}
|
|
def _groq(client, prompt):
    """Ask the Groq council member for a JSON answer.

    Sends *prompt* as a single user message to ``GROQ_MODEL`` (temperature
    0.2, 15 s timeout) and returns the reply run through ``_parse``. Any
    failure, including a parse error, is logged as a warning and reported
    as ``{}``.
    """
    messages = [{"role": "user", "content": prompt}]
    try:
        completion = client.chat.completions.create(
            model=GROQ_MODEL,
            messages=messages,
            temperature=0.2,
            timeout=15,
        )
        reply = completion.choices[0].message.content
        return _parse(reply)
    except Exception as exc:  # best-effort: one failing member must not stop the council
        logger.warning("Groq: %s", exc)
        return {}
|
|
def _mistral(prompt, key):
    """Ask the Mistral council member for a JSON answer.

    Returns ``{}`` immediately when no API key is configured. Otherwise it
    POSTs *prompt* as a single user message to the chat-completions endpoint
    (``MISTRAL_MODEL``, temperature 0.2, 15 s timeout) and returns the
    reply run through ``_parse``. Any failure is logged as a warning and
    reported as ``{}``.
    """
    if not key:
        return {}
    try:
        resp = requests.post(
            "https://api.mistral.ai/v1/chat/completions",
            headers={"Authorization": f"Bearer {key}", "Content-Type": "application/json"},
            json={
                "model": MISTRAL_MODEL,
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.2,
            },
            timeout=15,
        )
        # Surface HTTP failures (429 rate limit, 401 bad key, ...) as a
        # readable message. Without this, the error body only yields an
        # opaque KeyError('choices') below.
        resp.raise_for_status()
        content = resp.json()["choices"][0]["message"]["content"]
        return _parse(content)
    except Exception as exc:  # best-effort: one failing member must not stop the council
        logger.warning("Mistral: %s", exc)
        return {}
|
|
def _gemini(prompt, key):
    """Ask the Gemini council member (gemini-2.5-flash) for a JSON answer.

    Makes up to three attempts:

    * Quota/rate-limit errors back off 10 s, then 20 s, before retrying.
      These are detected as HTTP 429, status RESOURCE_EXHAUSTED, or a
      message mentioning "quota" or "rate limit".
    * Any other API error is logged and gives up immediately.
    * Exceptions (network errors, non-JSON bodies, unexpected response
      shape) are logged and retried after 5 s.

    No sleep is taken after the final attempt. Returns the reply run through
    ``_parse``, or ``{}`` when no key is configured or every attempt fails.
    """
    if not key:
        return {}
    model = "gemini-2.5-flash"
    url = (
        "https://generativelanguage.googleapis.com/v1beta/models/"
        f"{model}:generateContent"
    )
    # Send the key in a header, not a ?key= query parameter. requests
    # exceptions quote the request URL, and those exceptions are logged below.
    headers = {"Content-Type": "application/json", "x-goog-api-key": key}
    payload = {
        "contents": [{"parts": [{"text": prompt}]}],
        "generationConfig": {"temperature": 0.2},
    }
    attempts = 3
    for attempt in range(attempts):
        has_next = attempt + 1 < attempts
        try:
            r = requests.post(url, headers=headers, json=payload, timeout=60)
            d = r.json()
            if "candidates" not in d:
                err = d.get("error", {})
                if isinstance(err, dict):
                    msg, status = err.get("message", ""), err.get("status", "")
                else:
                    msg, status = str(err), ""
                low = msg.lower()
                # Match "rate limit"/"rate-limit" specifically. A bare "rate"
                # substring would also match "generate", "moderate", etc.
                rate_limited = (
                    r.status_code == 429
                    or status == "RESOURCE_EXHAUSTED"
                    or "quota" in low
                    or re.search(r"\brate[\s_-]?limit", low) is not None
                )
                if rate_limited:
                    if not has_next:
                        logger.warning("Gemini rate-limited on final attempt: %s", msg)
                        return {}
                    wait = min(40, 10 * (attempt + 1))
                    logger.warning("Gemini rate-limited, waiting %ds…", wait)
                    time.sleep(wait)
                    continue
                logger.warning("Gemini attempt %d: %s", attempt + 1, msg)
                return {}
            return _parse(d["candidates"][0]["content"]["parts"][0]["text"])
        except Exception as e:  # best-effort: retry transient failures
            logger.warning("Gemini attempt %d: %s", attempt + 1, e)
            if has_next:
                time.sleep(5)
    return {}
|
|
| |
| |
| |
def _label_prompt(keyphrases, rep_docs):
    """Build the labelling prompt sent to every council member.

    Uses the first five keyphrases and the first 250 characters of up to
    three representative documents, joined with " | ". A tuple keyphrase
    contributes its first element (presumably the phrase in a
    (phrase, score) pair; confirm against tools).
    """
    phrases = []
    for item in keyphrases[:5]:
        phrases.append(item[0] if isinstance(item, tuple) else item)
    phrase_text = ", ".join(phrases)

    snippets = [doc[:250] for doc in rep_docs[:3]]
    abstract_text = " | ".join(snippets)

    return f"""You are a research topic classifier.
A SPECTER-2 + HDBSCAN pipeline produced a topic cluster.

KEYPHRASES: {phrase_text}
REPRESENTATIVE ABSTRACTS: {abstract_text}

Return ONLY valid JSON:
{{
"label": "<5-8 word topic label>",
"description": "<one sentence description>",
"pacis_match": "<closest PAJAIS 2019 category, or NOVEL if none>",
"confidence": <0.0-1.0>
}}"""
|
|
| |
| |
| |
def _defence_prompt(keyphrases, rep_docs, votes):
    """Build the tie-break prompt used when no two council labels agree.

    Shows the top five keyphrases, up to three representative abstracts
    (first 250 characters each, as in the labelling prompt), and one line per
    council vote. A vote with no usable label is shown as "(no response)"
    instead of a bare "?", so the model does not treat it as a candidate
    label.

    ``rep_docs`` was previously accepted but ignored. It is now included so
    the tie-break has the same evidence the original voters saw.
    """
    kp = ", ".join(k[0] if isinstance(k, tuple) else k for k in keyphrases[:5])
    ab = " | ".join(str(a)[:250] for a in rep_docs[:3])
    vote_lines = []
    for i, v in enumerate(votes, start=1):
        label = v.get("label") if isinstance(v, dict) else None
        vote_lines.append(f" LLM{i}: {label or '(no response)'}")
    v_str = "\n".join(vote_lines)
    return f"""Resolve this labelling disagreement.
KEYPHRASES: {kp}
REPRESENTATIVE ABSTRACTS: {ab}
Votes:
{v_str}
Pick the best label or synthesise a better one.
Return ONLY JSON: {{"label":"...","description":"...","pacis_match":"...","confidence":0.0}}"""
|
|
| |
| |
| |
def _grounding(label, keyphrases):
    """Check that a label is lexically grounded in its cluster's keyphrases.

    Both sides are lower-cased and tokenised into alphabetic words of three
    or more letters (``[a-z]{3,}``). A small stop-word list is then removed.
    Keyphrase entries may be plain strings or tuples whose first element is
    the phrase.

    Returns a dict with the same keys on every path:
      verdict -- "PASS" if at least one label token occurs in the keyphrases,
                 else "FAIL"
      score   -- matched tokens / label tokens (0.0 when nothing matches)
      matched -- the overlapping tokens, sorted so output is deterministic
    """
    if not label or not keyphrases:
        return {"verdict": "FAIL", "score": 0.0, "matched": []}
    word = re.compile(r"\b[a-z]{3,}\b")
    noise = {"the", "and", "for", "with", "using", "based", "from", "that", "are", "this"}
    label_tokens = set(word.findall(str(label).lower())) - noise
    kp_tokens = set()
    for k in keyphrases:
        phrase = k[0] if isinstance(k, tuple) else k
        kp_tokens.update(word.findall(str(phrase).lower()))
    kp_tokens -= noise
    matched = sorted(label_tokens & kp_tokens)
    return {
        "verdict": "PASS" if matched else "FAIL",
        "score": len(matched) / max(len(label_tokens), 1),
        "matched": matched,
    }
|
|
def _clean(s, max_len=60):
    """Normalise an LLM label for display and for agreement comparison.

    Collapses every whitespace run (newlines, tabs, repeated spaces) to a
    single space and trims the ends. Text longer than *max_len* (default 60)
    is cut at the last word boundary that fits, and a single word longer
    than *max_len* is hard-truncated. ``None`` or empty input yields ``""``.
    """
    text = " ".join(str(s or "").split())
    if len(text) <= max_len:
        return text
    # If the character just past the cut is a space, the cut already ends
    # on a whole word, so keep that word instead of dropping it.
    if text[max_len] == " ":
        return text[:max_len]
    return text[:max_len].rsplit(" ", 1)[0]
|
|
| |
| |
| |
def embed_and_cluster(state: PipelineState) -> dict:
    """LangGraph node: run topic modelling on the input file.

    Delegates to ``tools.run_topic_modeling(filepath, n_trials)``, with
    ``n_trials`` defaulting to 50. Returns ``{"topic_data": ...}`` on
    success. On any failure it returns ``{"error": "<ExcType>: <message>"}``
    and logs the traceback, so failures are reported in state instead of
    raised. This includes a missing or broken ``tools`` module.
    """
    try:
        # Imported inside the node, presumably to defer heavy modelling
        # dependencies until the pipeline actually runs (confirm).
        from tools import run_topic_modeling
        td = run_topic_modeling(state["filepath"], state.get("n_trials", 50))
    except Exception as e:
        logger.exception("Topic modelling failed")
        # Include the exception type: str(KeyError('x')) alone is just "'x'".
        return {"error": f"{type(e).__name__}: {e}"}
    return {"topic_data": td}
|
|
| |
| |
| |
# Output CSV base names, keyed by sheet number (1-3 one per council member, 4 consolidated).
_SHEET_NAMES = {1: "sheet1_groq", 2: "sheet2_mistral", 3: "sheet3_gemini", 4: "sheet4_consolidated"}
# Fields copied from each member's raw answer into its own sheet.
_VOTE_FIELDS = ("label", "description", "pacis_match", "confidence")


def _vote_label(vote):
    """Return the cleaned, lower-cased label of a council vote ('' if it has none)."""
    return _clean(vote.get("label", "")).lower() if isinstance(vote, dict) else ""


def _consensus(client, kps, rds, votes):
    """Decide the council outcome for one cluster.

    Returns ``(agreement, best, valid)``:
      agreement -- "Triple"/"Two" when at least three/two votes share a
                   normalised label; otherwise "Single", and a Groq
                   tie-break call picks or synthesises the label
      best      -- the chosen answer dict (``{}`` if nothing usable)
      valid     -- votes carrying a non-empty label, in council order
    """
    # Empty labels are excluded so two failed answers cannot "agree".
    valid = [v for v in votes if _vote_label(v)]
    counts = Counter(_vote_label(v) for v in valid)
    if counts:
        winner, n = counts.most_common(1)[0]
        if n >= 2:
            best = next(v for v in valid if _vote_label(v) == winner)
            return ("Triple" if n >= 3 else "Two"), best, valid
    defence = _groq(client, _defence_prompt(kps, rds, votes))
    if isinstance(defence, dict) and "label" in defence:
        best = defence
    else:
        best = valid[0] if valid else {}
    return "Single", best, valid


def _grounded_choice(best, valid, kps):
    """Return ``(label, vote, grounding)``, preferring a keyphrase-grounded label.

    If *best*'s label fails the grounding check, the first council vote whose
    label passes is used instead. Its description/pacis_match then travel
    with it, and the returned grounding verdict describes the label actually
    kept. If no vote passes, *best* is kept as-is.
    """
    label = _clean(best.get("label", ""))
    gc = _grounding(label, kps)
    if gc["verdict"] == "FAIL":
        for vote in valid:
            alt = _clean(vote.get("label", ""))
            alt_gc = _grounding(alt, kps)
            if alt_gc["verdict"] == "PASS":
                return alt, vote, alt_gc
    return label, best, gc


def _agreement_rates(rows):
    """Return integer percentages of Triple / Two-or-more / Single rows."""
    total = len(rows) or 1
    n_triple = sum(1 for r in rows if r.get("agreement") == "Triple")
    n_two = sum(1 for r in rows if r.get("agreement") == "Two")
    return {
        "triple": round(n_triple / total * 100),
        "two_or_more": round((n_triple + n_two) / total * 100),
        "single": round((total - n_triple - n_two) / total * 100),
    }


def _json_default(obj):
    """``json.dump`` fallback: unwrap NumPy scalars/arrays and reject anything else.

    Cluster ids and strong/weak counts come from the topic model. If they are
    NumPy integers (not visible here; TODO confirm), plain ``json.dump``
    would raise TypeError.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _write_outputs(sheets):
    """Write sheets 1-4 as CSVs and sheet 4 as topics.json in the CWD.

    Returns ``(sheet_paths, json_path)``.
    """
    sheet_paths = {}
    for sn, name in _SHEET_NAMES.items():
        path = f"{name}.csv"
        pd.DataFrame(sheets[sn]).to_csv(path, index=False)
        sheet_paths[sn] = path
    json_path = "topics.json"
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(sheets[4], f, indent=2, default=_json_default)
    return sheet_paths, json_path


def llm_council(state: PipelineState) -> dict:
    """LangGraph node: label every topic cluster with a three-LLM council.

    For each cluster (in sorted id order), the same labelling prompt goes to
    Groq, Mistral and Gemini, with 1 s / 1 s / 4 s pauses after the
    respective calls (presumably to stay under provider rate limits).
    Each member's raw answer is recorded in its own sheet (1-3). The
    consolidated, grounding-checked result goes to sheet 4 and to
    ``interpretations``.

    Side effects: writes sheet1_groq.csv … sheet4_consolidated.csv and
    topics.json to the current working directory.

    Returns ``interpretations``, ``sheets``, ``agreement_rates``,
    ``sheet_paths`` and ``json_path``. Returns ``{"error": ...}`` instead
    when there is no topic data (e.g. an upstream error) or the Groq client
    cannot be created.
    """
    td = state.get("topic_data")
    if not td:
        return {"error": state.get("error") or "No topic data"}
    try:
        # max_retries=0: failures are handled (and logged) by _groq itself.
        client = Groq(api_key=state.get("groq_key"), max_retries=0)
    except Exception as e:
        return {"error": f"Groq client: {e}"}
    mk, gk = state.get("mistral_key"), state.get("gemini_key")

    sheets = {1: [], 2: [], 3: [], 4: []}
    interps = {}

    for cid in sorted(td["keyphrases"]):
        kps = td["keyphrases"][cid]
        rds = td["representative_docs"].get(cid, [])
        sw = td["membership"].get(cid, {"strong": 0, "weak": 0})
        prompt = _label_prompt(kps, rds)

        s1 = _groq(client, prompt); time.sleep(1)
        s2 = _mistral(prompt, mk); time.sleep(1)
        s3 = _gemini(prompt, gk); time.sleep(4)
        votes = [s1, s2, s3]

        # Raw per-member sheets; "—" marks a field the member did not return.
        for sheet_n, resp in zip((1, 2, 3), votes):
            resp = resp if isinstance(resp, dict) else {}
            sheets[sheet_n].append(
                {"cluster": cid, **{k: resp.get(k, "—") for k in _VOTE_FIELDS}})

        agreement, best, valid = _consensus(client, kps, rds, votes)
        label, chosen, gc = _grounded_choice(best, valid, kps)

        cp = float((td.get("cluster_persistence") or {}).get(cid, 0.0) or 0.0)
        strong, weak = sw.get("strong", 0), sw.get("weak", 0)
        description = chosen.get("description", "")
        pacis_match = chosen.get("pacis_match", "")

        sheets[4].append({"cluster": cid, "label": label, "agreement": agreement,
                          "description": description, "pacis_match": pacis_match,
                          "strong": strong, "weak": weak,
                          "persistence": round(cp, 4), "grounding": gc["verdict"]})

        interps[cid] = {"label": label, "agreement": agreement,
                        "strong": strong, "weak": weak,
                        "persistence": cp, "description": description,
                        "pacis_match": pacis_match,
                        "keyphrases": [k[0] if isinstance(k, tuple) else k for k in kps[:5]]}

        # %s, not %d: cluster ids are not guaranteed to be ints.
        logger.info("Cluster %s → %s [%s]", cid, label, agreement)

    rates = _agreement_rates(sheets[4])
    sheet_paths, json_path = _write_outputs(sheets)

    return {"interpretations": interps, "sheets": sheets,
            "agreement_rates": rates, "sheet_paths": sheet_paths,
            "json_path": json_path}
|
|
| |
| |
| |
def build_mismatch(state: "PipelineState") -> dict:
    """LangGraph node: build the keyphrase/label mismatch table.

    Delegates to ``tools.build_mismatch_table(keyphrases, {cluster_id: label})``.
    Its exact semantics live in ``tools``; presumably it compares each
    cluster's keyphrases with its final label (confirm).

    Returns ``{"mismatch_table": []}`` when there is no topic data (e.g. an
    earlier node reported an error), instead of raising KeyError.
    """
    td = state.get("topic_data")
    if not td:
        return {"mismatch_table": []}
    from tools import build_mismatch_table
    interps = state.get("interpretations") or {}
    labels_map = {cid: v.get("label", "") for cid, v in interps.items()}
    return {"mismatch_table": build_mismatch_table(td["keyphrases"], labels_map)}
|
|
| |
| |
| |
def _stop_on_error(next_node: str):
    """Return a router that ends the run when the state holds an "error"."""
    def _route(state) -> str:
        return END if state.get("error") else next_node
    return _route


def build_graph():
    """Build and compile the pipeline graph.

    Flow: embed_and_cluster → llm_council → build_mismatch → END.
    After each of the first two nodes, the run stops early if the state
    carries an ``"error"``, so downstream nodes never run without topic
    data. Returns the compiled, invokable graph (the result of
    ``StateGraph.compile()``, not the builder itself).
    """
    g = StateGraph(PipelineState)
    g.add_node("embed_and_cluster", embed_and_cluster)
    g.add_node("llm_council", llm_council)
    g.add_node("build_mismatch", build_mismatch)
    g.set_entry_point("embed_and_cluster")
    g.add_conditional_edges(
        "embed_and_cluster",
        _stop_on_error("llm_council"),
        {"llm_council": "llm_council", END: END},
    )
    g.add_conditional_edges(
        "llm_council",
        _stop_on_error("build_mismatch"),
        {"build_mismatch": "build_mismatch", END: END},
    )
    g.add_edge("build_mismatch", END)
    return g.compile()
|
|
| |
# Compiled once at import time; every run_pipeline call reuses it.
pipeline_graph = build_graph()


def run_pipeline(filepath, groq_key, mistral_key, gemini_key, n_trials=50):
    """Run the whole topic-analysis graph on *filepath*.

    Seeds the graph state with the file path, the three API keys and
    ``n_trials``, then returns the final state dict produced by
    ``pipeline_graph.invoke``.
    """
    initial_state = dict(
        filepath=filepath,
        groq_key=groq_key,
        mistral_key=mistral_key,
        gemini_key=gemini_key,
        n_trials=n_trials,
    )
    return pipeline_graph.invoke(initial_state)
|
|