anujjuna's picture
Update agent.py
12cbe9d verified
"""
agent.py — LangGraph-based topic analysis agent (§11).
3-LLM Council for topic modelling, 4 sheets, triple-agreement tracking.
"""
from __future__ import annotations
import json, logging, os, re, time
from dataclasses import dataclass, field, asdict
from typing import TypedDict, Optional
from collections import Counter
import pandas as pd, numpy as np, requests
from groq import Groq
from langgraph.graph import StateGraph, END
# Model identifiers used by the Groq and Mistral council members.
GROQ_MODEL = "llama-3.1-8b-instant"
MISTRAL_MODEL = "mistral-small-latest"

# Root logging setup at import time; basicConfig is a no-op if the root
# logger already has handlers.
logging.basicConfig(format="%(levelname)s | %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# LangGraph state
# ---------------------------------------------------------------------------
class PipelineState(TypedDict, total=False):
    """Shared LangGraph state passed between the pipeline nodes.

    All keys are optional (``total=False``); each node returns only the keys
    it produces. LangGraph only stores keys declared here, so every key a
    node returns must be listed.
    """

    # Inputs supplied by run_pipeline.
    filepath: str
    groq_key: str
    mistral_key: str
    gemini_key: str
    n_trials: int
    # Written by embed_and_cluster.
    topic_data: dict
    # Written by llm_council.
    interpretations: dict
    sheets: dict  # {1: [...], 2: [...], 3: [...], 4: [...]}
    agreement_rates: dict
    sheet_paths: dict  # {sheet_number: csv_path}; previously dropped (undeclared)
    json_path: str
    # Written by build_mismatch.
    mismatch_table: list
    # NOTE(review): not set by any node visible in this file.
    csv_path: str
    # Set by a node that failed.
    error: str
# ---------------------------------------------------------------------------
# API helpers
# ---------------------------------------------------------------------------
def _parse(raw: str) -> dict:
raw = raw.strip().replace("```json","").replace("```","").strip()
s, e = raw.find("{"), raw.rfind("}")+1
if s != -1 and e > 0: raw = raw[s:e]
try: return json.loads(raw)
except: return {}
def _groq(client, prompt):
    """Send ``prompt`` to the Groq chat model and return the parsed JSON reply.

    Any failure (network, API, malformed response) is logged as a warning and
    yields ``{}``.
    """
    messages = [{"role": "user", "content": prompt}]
    try:
        completion = client.chat.completions.create(
            model=GROQ_MODEL,
            messages=messages,
            temperature=0.2,
            timeout=15,
        )
        reply = completion.choices[0].message.content
        return _parse(reply)
    except Exception as exc:
        logger.warning("Groq: %s", exc)
        return {}
def _mistral(prompt, key):
if not key: return {}
try:
r = requests.post("https://api.mistral.ai/v1/chat/completions",
headers={"Authorization":f"Bearer {key}","Content-Type":"application/json"},
json={"model":MISTRAL_MODEL,"messages":[{"role":"user","content":prompt}],
"temperature":0.2}, timeout=15)
return _parse(r.json()["choices"][0]["message"]["content"])
except Exception as e: logger.warning("Mistral: %s",e); return {}
def _gemini(prompt, key):
if not key: return {}
model = "gemini-2.5-flash"
for attempt in range(3):
try:
r = requests.post(
f"https://generativelanguage.googleapis.com/v1beta/models/"
f"{model}:generateContent?key={key}",
headers={"Content-Type":"application/json"},
json={"contents":[{"parts":[{"text":prompt}]}],
"generationConfig":{"temperature":0.2}}, timeout=60)
d = r.json()
if "candidates" not in d:
err = d.get("error",{})
msg = err.get("message","") if isinstance(err,dict) else str(err)
if "quota" in msg.lower() or "rate" in msg.lower():
wait = min(40, 10 * (attempt + 1))
logger.warning("Gemini rate-limited, waiting %ds…", wait)
time.sleep(wait); continue
logger.warning("Gemini attempt %d: %s", attempt+1, msg)
return {}
return _parse(d["candidates"][0]["content"]["parts"][0]["text"])
except Exception as e:
logger.warning("Gemini attempt %d: %s", attempt+1, e)
time.sleep(5)
return {}
# ---------------------------------------------------------------------------
# Topic labelling prompt
# ---------------------------------------------------------------------------
def _label_prompt(keyphrases, rep_docs):
kp = ", ".join(k[0] if isinstance(k,tuple) else k for k in keyphrases[:5])
ab = " | ".join(a[:250] for a in rep_docs[:3])
return f"""You are a research topic classifier.
A SPECTER-2 + HDBSCAN pipeline produced a topic cluster.
KEYPHRASES: {kp}
REPRESENTATIVE ABSTRACTS: {ab}
Return ONLY valid JSON:
{{
"label": "<5-8 word topic label>",
"description": "<one sentence description>",
"pacis_match": "<closest PAJAIS 2019 category, or NOVEL if none>",
"confidence": <0.0-1.0>
}}"""
# ---------------------------------------------------------------------------
# Defence prompt for disagreements
# ---------------------------------------------------------------------------
def _defence_prompt(keyphrases, rep_docs, votes):
kp = ", ".join(k[0] if isinstance(k,tuple) else k for k in keyphrases[:5])
v_str = "\n".join(f" LLM{i+1}: {v.get('label','?')}" for i,v in enumerate(votes))
return f"""Resolve this labelling disagreement.
KEYPHRASES: {kp}
Votes:\n{v_str}
Pick the best label or synthesise a better one.
Return ONLY JSON: {{"label":"...","description":"...","pacis_match":"...","confidence":0.0}}"""
# ---------------------------------------------------------------------------
# Grounding check
# ---------------------------------------------------------------------------
def _grounding(label, keyphrases):
if not label or not keyphrases: return {"verdict":"FAIL","score":0}
lt = set(re.findall(r"\b[a-z]{3,}\b", label.lower()))
kt = set()
for k in keyphrases:
kt.update(re.findall(r"\b[a-z]{3,}\b", (k[0] if isinstance(k,tuple) else k).lower()))
noise = {"the","and","for","with","using","based","from","that","are","this"}
lt -= noise; kt -= noise
m = list(lt & kt)
return {"verdict":"PASS" if m else "FAIL", "score":len(m)/max(len(lt),1), "matched":m}
def _clean(s):
s = str(s or "").replace("\n"," ").strip()
return s[:60].rsplit(" ",1)[0] if len(s)>60 else s
# ---------------------------------------------------------------------------
# LangGraph node: run topic modelling
# ---------------------------------------------------------------------------
def embed_and_cluster(state: PipelineState) -> dict:
    """LangGraph node: run topic modelling on ``state["filepath"]``.

    Calls ``tools.run_topic_modeling(filepath, n_trials)``; ``n_trials``
    defaults to 50.

    Returns:
        ``{"topic_data": ...}`` on success, or ``{"error": message}`` if the
        modelling call raises. The traceback is logged rather than silently
        discarded. The message is never empty (falls back to ``repr``), so
        downstream ``state.get("error")`` checks see the failure.
    """
    from tools import run_topic_modeling
    try:
        td = run_topic_modeling(state["filepath"], state.get("n_trials", 50))
    except Exception as e:
        logger.exception("Topic modelling failed for %s", state.get("filepath"))
        return {"error": str(e) or repr(e)}
    return {"topic_data": td}
# ---------------------------------------------------------------------------
# LangGraph node: LLM Council — 4 sheets for topic modelling
# ---------------------------------------------------------------------------
# Per-model sheet columns (sheets 1-3) and output file stems for all sheets.
_COUNCIL_FIELDS = ("label", "description", "pacis_match", "confidence")
_SHEET_NAMES = {1: "sheet1_groq", 2: "sheet2_mistral", 3: "sheet3_gemini", 4: "sheet4_consolidated"}


def _consolidate_votes(client, votes, kps, rds):
    """Merge council votes into ``(best_vote, agreement, label, grounding)``.

    Only votes whose label is non-empty after ``_clean`` take part, so three
    failed/empty replies no longer count as a "Triple" agreement on "".
    Labels are compared case-insensitively. Without a majority, a Groq
    "defence" prompt arbitrates.

    If the chosen label fails grounding, the first valid vote is used instead,
    as before. Its description/pacis_match and its own grounding verdict now
    replace the rejected vote's, so the consolidated row is self-consistent.
    """
    valid = [v for v in votes if isinstance(v, dict) and _clean(v.get("label"))]
    counts = Counter(_clean(v["label"]).lower() for v in valid)
    top = counts.most_common(1)
    if top and top[0][1] >= 2:
        winner, n_votes = top[0]
        agreement = "Triple" if n_votes >= 3 else "Two"
        best = next(v for v in valid if _clean(v["label"]).lower() == winner)
    else:
        agreement = "Single"
        d = _groq(client, _defence_prompt(kps, rds, votes))
        best = d if isinstance(d, dict) and "label" in d else (valid[0] if valid else {})
    label = _clean(best.get("label", ""))
    gc = _grounding(label, kps)
    if gc["verdict"] == "FAIL" and valid and best is not valid[0]:
        best = valid[0]
        label = _clean(best.get("label", ""))
        gc = _grounding(label, kps)
    return best, agreement, label, gc


def _agreement_rates(rows):
    """Rounded percentages of clusters with Triple / Two-or-more / Single agreement.

    With no rows every rate is 0 (previously "single" reported 100).
    """
    n_rows = len(rows)
    total = n_rows or 1
    n_triple = sum(1 for r in rows if r.get("agreement") == "Triple")
    n_two = sum(1 for r in rows if r.get("agreement") == "Two")
    n_single = n_rows - n_triple - n_two
    return {
        "triple": round(n_triple / total * 100),
        "two_or_more": round((n_triple + n_two) / total * 100),
        "single": round(n_single / total * 100),
    }


def _json_default(obj):
    """json.dump fallback: unwrap NumPy scalars; reject anything else.

    NOTE(review): cluster ids / counts may be NumPy integers depending on
    what tools.run_topic_modeling returns -- confirm.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _save_sheets(sheets):
    """Write each sheet to ``<name>.csv`` and sheet 4 to ``topics.json`` (CWD).

    Returns ``{sheet_number: csv_path}``.
    """
    paths = {}
    for sn, name in _SHEET_NAMES.items():
        path = f"{name}.csv"
        pd.DataFrame(sheets[sn]).to_csv(path, index=False)
        paths[sn] = path
    with open("topics.json", "w", encoding="utf-8") as f:
        json.dump(sheets[4], f, indent=2, default=_json_default)
    return paths


def llm_council(state: PipelineState) -> dict:
    """LangGraph node: label every topic cluster with the 3-LLM council.

    For each cluster in ``topic_data["keyphrases"]`` (sorted by id), the same
    prompt goes to Groq, Mistral and Gemini (sheets 1-3). The replies are
    then consolidated into sheet 4 with an agreement level
    (Triple/Two/Single) and a keyphrase-grounding verdict. The fixed sleeps
    between calls throttle the free-tier APIs.

    Returns:
        ``interpretations``, ``sheets``, ``agreement_rates``, ``sheet_paths``
        and ``json_path``. If there is no topic data, returns ``{"error": ...}``
        and preserves any upstream error message instead of raising KeyError.

    Side effects:
        Writes four CSV files and ``topics.json`` to the working directory.
    """
    td = state.get("topic_data")
    if not td:
        return {"error": state.get("error") or "No topic data"}
    client = Groq(api_key=state.get("groq_key"), max_retries=0)
    mk, gk = state.get("mistral_key"), state.get("gemini_key")
    keyphrases = td.get("keyphrases", {})
    rep_docs = td.get("representative_docs", {})
    membership = td.get("membership", {})
    persistence = td.get("cluster_persistence", {})
    sheets = {1: [], 2: [], 3: [], 4: []}  # 1=Groq, 2=Mistral, 3=Gemini, 4=Consolidated
    interps = {}
    for cid in sorted(keyphrases):
        kps = keyphrases[cid]
        rds = rep_docs.get(cid, [])
        sw = membership.get(cid, {})
        strong, weak = sw.get("strong", 0), sw.get("weak", 0)
        prompt = _label_prompt(kps, rds)
        s1 = _groq(client, prompt)
        time.sleep(1)
        s2 = _mistral(prompt, mk)
        time.sleep(1)
        s3 = _gemini(prompt, gk)
        time.sleep(4)  # respect Gemini free-tier rate limit
        votes = [s1, s2, s3]
        # Sheets 1-3: raw per-model answers ("—" marks a missing field).
        for sheet_n, resp in ((1, s1), (2, s2), (3, s3)):
            resp = resp if isinstance(resp, dict) else {}
            sheets[sheet_n].append(
                {"cluster": cid, **{k: resp.get(k, "—") for k in _COUNCIL_FIELDS}})
        # Sheet 4: consolidated answer.
        best, agreement, label, gc = _consolidate_votes(client, votes, kps, rds)
        cp = persistence.get(cid, 0.0)
        sheets[4].append({"cluster": cid, "label": label, "agreement": agreement,
                          "description": best.get("description", ""),
                          "pacis_match": best.get("pacis_match", ""),
                          "strong": strong, "weak": weak,
                          "persistence": round(cp, 4), "grounding": gc["verdict"]})
        interps[cid] = {"label": label, "agreement": agreement,
                        "strong": strong, "weak": weak,
                        "persistence": cp, "description": best.get("description", ""),
                        "pacis_match": best.get("pacis_match", ""),
                        "keyphrases": [k[0] if isinstance(k, tuple) else k for k in kps[:5]]}
        logger.info("Cluster %s → %s [%s]", cid, label, agreement)
    rates = _agreement_rates(sheets[4])
    sheet_paths = _save_sheets(sheets)
    return {"interpretations": interps, "sheets": sheets,
            "agreement_rates": rates, "sheet_paths": sheet_paths,
            "json_path": "topics.json"}
# ---------------------------------------------------------------------------
# LangGraph node: build mismatch table
# ---------------------------------------------------------------------------
def build_mismatch(state: PipelineState) -> dict:
    """LangGraph node: compare cluster keyphrases against consolidated labels.

    Passes ``topic_data["keyphrases"]`` and a ``{cluster_id: label}`` map to
    ``tools.build_mismatch_table``. If there is no topic data (e.g. an
    upstream node failed), returns an empty table instead of raising KeyError.
    """
    td = state.get("topic_data")
    if not td or "keyphrases" not in td:
        return {"mismatch_table": []}
    from tools import build_mismatch_table
    interps = state.get("interpretations") or {}
    labels_map = {cid: v.get("label", "") for cid, v in interps.items()}
    return {"mismatch_table": build_mismatch_table(td["keyphrases"], labels_map)}
# ---------------------------------------------------------------------------
# Build the LangGraph
# ---------------------------------------------------------------------------
def _stop_on_error(next_node):
    """Return a router that goes to ``next_node``, or to END if state has an error."""
    def _route(state):
        return END if state.get("error") else next_node
    return _route


def build_graph() -> StateGraph:
    """Assemble and compile the topic-analysis graph.

    Flow: embed_and_cluster -> llm_council -> build_mismatch -> END. If a
    node sets ``state["error"]``, the run now ends there. Previously the next
    node ran anyway and crashed on the missing ``topic_data`` with an
    unrelated KeyError.

    Returns the result of ``StateGraph.compile()``.
    """
    g = StateGraph(PipelineState)
    g.add_node("embed_and_cluster", embed_and_cluster)
    g.add_node("llm_council", llm_council)
    g.add_node("build_mismatch", build_mismatch)
    g.set_entry_point("embed_and_cluster")
    g.add_conditional_edges("embed_and_cluster", _stop_on_error("llm_council"),
                            {"llm_council": "llm_council", END: END})
    g.add_conditional_edges("llm_council", _stop_on_error("build_mismatch"),
                            {"build_mismatch": "build_mismatch", END: END})
    g.add_edge("build_mismatch", END)
    return g.compile()
# Graph is compiled once at import time so callers can import and reuse it.
pipeline_graph = build_graph()


def run_pipeline(filepath, groq_key, mistral_key, gemini_key, n_trials=50):
    """Run the whole topic-analysis graph on one input file.

    Builds the initial state from the arguments, invokes ``pipeline_graph``
    and returns the final state dict as produced by ``invoke``.
    """
    initial_state = dict(
        filepath=filepath,
        groq_key=groq_key,
        mistral_key=mistral_key,
        gemini_key=gemini_key,
        n_trials=n_trials,
    )
    return pipeline_graph.invoke(initial_state)