BERTopic_AG_final / agent.py
BHAVIKBANKER's picture
Update agent.py
80d15c2 verified
"""
agent.py — LangGraph-based topic analysis agent (§11).
Original 3-LLM Council for topic modelling is UNCHANGED.
NEW nodes appended:
- load_methodology_corpus : load methodology CSV, detect journal per paper
- embed_methodology_vectors : SPECTER-2 embed methodology text (separate vector space)
- extract_comp_techniques : 3-LLM council (regex → Groq → Mistral → Gemini → consolidate)
- build_journal_crosstab : technique × journal cross-tabulation with percentages
- optimize_technique_labels : improvement / hallucination critique on consolidated techniques
"""
from __future__ import annotations
import json, logging, os, re, time
from typing import TypedDict
from collections import Counter, defaultdict
import pandas as pd, numpy as np, requests
from groq import Groq
from langgraph.graph import StateGraph, END
# Configure logging once at import time: INFO level, "LEVEL | message" lines.
logging.basicConfig(level=logging.INFO, format="%(levelname)s | %(message)s")
# Module-level logger used by the helpers and pipeline nodes in this file.
logger = logging.getLogger(__name__)
# Model identifiers passed to the Groq client and the Mistral chat endpoint.
GROQ_MODEL = "llama-3.1-8b-instant"
MISTRAL_MODEL = "mistral-small-latest"
# ============================================================================
# REGEX BANKS (used in both cluster methodology AND methodology-CSV pipeline)
# ============================================================================
# Methodology label -> compiled, case-insensitive regex.
# Every alternative is wrapped in \b(...)\b, so a bare word stem placed right
# before the closing \b can never match an inflected form. Stems therefore
# carry an explicit suffix ("methods?", "triangulat\w*") so that
# "mixed methods", "triangulation" and "triangulated" are detected.
METHODOLOGY_PATTERNS = {
    "Survey / Systematic Review": re.compile(
        r"\b(survey|systematic\s+review|literature\s+review|bibliometric|scoping\s+review|meta.?analysis)\b", re.I),
    "Experiment / Lab Study": re.compile(
        r"\b(experiment(al)?|laboratory|lab\s+study|controlled\s+trial|randomized|RCT)\b", re.I),
    "Case Study": re.compile(
        r"\b(case\s+study|case\s+analysis|real.world\s+case|industry\s+case)\b", re.I),
    "Simulation / Agent-Based": re.compile(
        r"\b(simulat(ion|ed)|agent.based|discrete.event|monte\s+carlo|stochastic)\b", re.I),
    "Empirical / Field Study": re.compile(
        r"\b(empirical|field\s+study|observational|longitudinal|cross.sectional|cohort)\b", re.I),
    "Design Science / Prototype": re.compile(
        r"\b(design\s+science|prototype|proof.of.concept|artifact|system\s+design|framework\s+design)\b", re.I),
    "Theoretical / Conceptual": re.compile(
        r"\b(theoretical|conceptual\s+framework|analytical\s+model|formal\s+model|theorem|proposition)\b", re.I),
    "Action Research": re.compile(
        r"\b(action\s+research|participatory|co.design|practitioner)\b", re.I),
    "Mixed Methods": re.compile(
        r"\b(mixed\s+methods?|qualitative.+quantitative|quantitative.+qualitative|triangulat\w*)\b", re.I),
}
# Technique label -> compiled, case-insensitive regex.
# Word stems that sit directly before the closing \b get an explicit suffix
# (\w* or s?) so inflected forms match: "sentiment analysis",
# "topic modeling", "gradient boosting", "simulated annealing",
# "network analysis", "thematic analysis", "embeddings".
# "optimi[sz]" accepts both British and American spellings, in line with the
# existing "decentrali[sz]ed" alternative.
TECHNIQUE_PATTERNS = {
    "Machine Learning": re.compile(
        r"\b(machine\s+learning|ML|random\s+forest|gradient\s+boost\w*|XGBoost|SVM|support\s+vector|decision\s+tree|k.?nearest|naive\s+bayes)\b", re.I),
    "Deep Learning / Neural Net": re.compile(
        r"\b(deep\s+learning|neural\s+network|CNN|RNN|LSTM|GRU|transformer|attention\s+mechanism|BERT|GPT)\b", re.I),
    "NLP / Text Mining": re.compile(
        r"\b(natural\s+language\s+processing|NLP|text\s+mining|topic\s+model\w*|LDA|word2vec|embeddings?|sentiment\s+analys\w*)\b", re.I),
    "Optimisation": re.compile(
        r"\b(optimi[sz](ation|e|ing)|genetic\s+algorithm|evolutionary|particle\s+swarm|simulated\s+anneal\w*|ant\s+colony|bayesian\s+optim\w*)\b", re.I),
    "Statistical Analysis": re.compile(
        r"\b(regression|ANOVA|t.test|chi.square|correlation|factor\s+analysis|structural\s+equation|SEM|PLS|logistic\s+regression)\b", re.I),
    "Clustering / Dimensionality": re.compile(
        r"\b(cluster(ing)?|HDBSCAN|DBSCAN|k.?means|hierarchical|UMAP|t.?SNE|PCA|dimensionality\s+reduction)\b", re.I),
    "Graph / Network Analysis": re.compile(
        r"\b(graph\s+(neural|network|analysis)|GNN|knowledge\s+graph|network\s+analys\w*|social\s+network|link\s+prediction)\b", re.I),
    "Computer Vision": re.compile(
        r"\b(computer\s+vision|image\s+(processing|recognition|classification)|object\s+detection|segmentation|GAN)\b", re.I),
    "Fuzzy / Rule-Based Systems": re.compile(
        r"\b(fuzzy\s+logic|rule.based|expert\s+system|ontology|knowledge\s+base|inference\s+engine)\b", re.I),
    "Blockchain / Distributed": re.compile(
        r"\b(blockchain|distributed\s+ledger|smart\s+contract|consensus|decentrali[sz]ed)\b", re.I),
    "Reinforcement Learning": re.compile(
        r"\b(reinforcement\s+learning|Q.learning|policy\s+gradient|reward\s+function|Markov\s+decision)\b", re.I),
    "Cloud / Big Data": re.compile(
        r"\b(cloud\s+computing|Hadoop|Spark|MapReduce|big\s+data|distributed\s+computing|edge\s+computing)\b", re.I),
    "Structural Equation Modelling": re.compile(
        r"\b(structural\s+equation|SEM|PLS.SEM|covariance.based|CB.SEM|partial\s+least\s+squares)\b", re.I),
    "Time Series / VAR": re.compile(
        r"\b(time\s+series|VAR\b|vector\s+auto.?regression|VARX|ARIMA|impulse\s+response|Granger)\b", re.I),
    "Content Analysis / Coding": re.compile(
        r"\b(content\s+analysis|coding\s+scheme|thematic\s+analys\w*|grounded\s+theory|open\s+coding|axial\s+coding)\b", re.I),
}
# Research-orientation bucket -> compiled, case-insensitive regex. Each
# document counts at most once per bucket (callers use .search()).
# "methods?" and "triangulat\w*" replace bare stems that could never match
# before the closing \b ("mixed methods", "triangulation").
# NOTE(review): the bare alternative "both" matches a very common word and may
# inflate the "mixed" bucket; confirm whether that is intended.
ORIENTATION_PATTERNS = {
    "empirical": re.compile(r"\b(empirical|experiment(al)?|field\s+study|data.driven|survey|dataset)\b", re.I),
    "theoretical": re.compile(r"\b(theoretical|conceptual|formal\s+model|analytical|theorem|proposition)\b", re.I),
    "mixed": re.compile(r"\b(mixed\s+methods?|qualitative.+quantitative|both|triangulat\w*)\b", re.I),
}
# Journal detection patterns applied to DOI + title.
# Short acronyms are guarded with (?<![a-z]) ... (?![a-z]) so they only match
# as stand-alone tokens. Without the guards "ecis" matched inside "decision"
# and "precision", "icis" inside "criticism", and "jais" inside "pajais".
# Digits and punctuation are still allowed next to the acronym so identifiers
# such as "ecis2020_rp" or "1jais" keep matching.
JOURNAL_PATTERNS = {
    "MISQ": re.compile(
        r"(misq|mis\s*quarterly|10\.25300|10\.2307/[0-9]{8}|MIS\s+Quarterly)", re.I),
    "JAIS": re.compile(
        r"((?<![a-z])jais(?![a-z])|10\.17705/1jais|journal.*association.*information\s+systems)", re.I),
    "ISR": re.compile(
        r"(10\.1287/isre|\bisr\b|information\s+systems\s+research)", re.I),
    "JMIS": re.compile(
        r"(10\.1080/07421222|(?<![a-z])jmis(?![a-z])|journal.*management.*information\s+systems)", re.I),
    "PAJAIS": re.compile(
        r"(pajais|pacific.*asia.*information|10\.17705/2asfp)", re.I),
    "ECIS": re.compile(
        r"((?<![a-z])ecis(?![a-z])|european.*conference.*information\s+systems)", re.I),
    "ICIS": re.compile(
        r"((?<![a-z])icis(?![a-z])|international.*conference.*information\s+systems)", re.I),
}
# ============================================================================
# SHARED REGEX HELPERS
# ============================================================================
def _regex_scan(docs: list[str]) -> dict:
    """Scan documents with the methodology, technique and orientation banks.

    Every regex match is recorded as ``{"doc", "match", "span"}`` where
    ``doc`` is the 1-based position of the document in ``docs``. Orientation
    is counted once per document per bucket and reported as a rounded share
    of all orientation hits.

    Returns a dict with keys ``methods``, ``techniques``, ``orientation`` and
    ``patterns_applied``.
    """
    method_hits = defaultdict(list)
    technique_hits = defaultdict(list)
    buckets = ("empirical", "theoretical", "mixed")
    tally = dict.fromkeys(buckets, 0)

    def _record(store, bank, text, paper_no):
        # Append all matches of each pattern; labels without a match are never
        # touched so the key order reflects the first hit per label.
        for label, pattern in bank.items():
            found = [
                {"doc": paper_no, "match": hit.group(0),
                 "span": [hit.start(), hit.end()]}
                for hit in pattern.finditer(text)
            ]
            if found:
                store[label].extend(found)

    for paper_no, text in enumerate(docs, start=1):
        _record(method_hits, METHODOLOGY_PATTERNS, text, paper_no)
        _record(technique_hits, TECHNIQUE_PATTERNS, text, paper_no)
        for bucket, pattern in ORIENTATION_PATTERNS.items():
            if pattern.search(text):
                tally[bucket] += 1

    # Avoid division by zero when no orientation pattern matched anywhere.
    denominator = sum(tally.values()) or 1
    return {
        "methods": {label: hits for label, hits in method_hits.items() if hits},
        "techniques": {label: hits for label, hits in technique_hits.items() if hits},
        "orientation": {
            f"{bucket}_pct": round(tally[bucket] / denominator * 100)
            for bucket in buckets
        },
        "patterns_applied": {
            "methodology": list(METHODOLOGY_PATTERNS.keys()),
            "technique": list(TECHNIQUE_PATTERNS.keys()),
        },
    }
def _regex_summary(scan: dict) -> str:
"""Human-readable regex evidence injected into LLM prompts."""
lines = []
if scan["methods"]:
lines.append("REGEX-DETECTED METHODOLOGIES:")
for k, hits in scan["methods"].items():
unique = list(dict.fromkeys(h["match"] for h in hits))[:3]
papers = sorted({h["doc"] for h in hits})
lines.append(f" β€’ {k} β€” matched: {unique} (papers: {papers})")
if scan["techniques"]:
lines.append("REGEX-DETECTED TECHNIQUES:")
for k, hits in scan["techniques"].items():
unique = list(dict.fromkeys(h["match"] for h in hits))[:3]
papers = sorted({h["doc"] for h in hits})
lines.append(f" β€’ {k} β€” matched: {unique} (papers: {papers})")
return "\n".join(lines) or "No regex hits found β€” rely on methodology text alone."
def _detect_journal(doi: str, title: str) -> str:
    """Return the first JOURNAL_PATTERNS key whose regex matches DOI + title.

    Falsy ``doi`` / ``title`` values are treated as empty strings. Patterns
    are tried in the dict's iteration order. When nothing matches, the
    function falls back to ``"MISQ"`` (it does not return ``"Other"``).
    """
    haystack = "{} {}".format(doi or "", title or "")
    detected = next(
        (name for name, pattern in JOURNAL_PATTERNS.items() if pattern.search(haystack)),
        None,
    )
    if detected is None:
        # methodology CSV default; override downstream if needed
        return "MISQ"
    return detected
# ============================================================================
# LANGGRAPH STATE
# ============================================================================
class PipelineState(TypedDict, total=False):
    """State dictionary handed to every pipeline node in this file.

    Declared with ``total=False``, so every key is optional; nodes return
    partial dicts containing only the keys they produce.

    NOTE(review): ``llm_council`` also returns a ``sheet_paths`` key that is
    not declared here; confirm whether it should be a field.
    """
    # -- original fields (DO NOT CHANGE) --
    filepath: str
    groq_key: str
    mistral_key: str
    gemini_key: str
    n_trials: int
    n_optimize: int
    topic_data: dict
    interpretations: dict
    sheets: dict
    agreement_rates: dict
    mismatch_table: list
    methodology_data: dict
    top_papers: dict
    refinement_log: list
    json_path: str
    error: str
    # -- new fields for methodology-CSV pipeline --
    methodology_filepath: str  # uploaded methodology CSV path
    methodology_papers: list  # [{title, doi, methodology, journal, paper_idx}]
    methodology_embeddings: list  # SPECTER-2 embeddings (separate vector space)
    comp_technique_sheets: dict  # {1:Groq, 2:Mistral, 3:Gemini, 4:Consolidated}
    journal_crosstab: dict  # {journal: {technique: pct}}
    technique_opt_log: list  # improvement suggestions from optimizer
# ============================================================================
# API HELPERS (unchanged)
# ============================================================================
def _parse(raw: str) -> dict:
raw = raw.strip().replace("```json","").replace("```","").strip()
s, e = raw.find("{"), raw.rfind("}")+1
if s != -1 and e > 0: raw = raw[s:e]
try: return json.loads(raw)
except: return {}
def _groq(client, prompt):
    """Send ``prompt`` as a single user message to Groq and parse the reply.

    Uses ``GROQ_MODEL`` at temperature 0 with a 30-second timeout. Any
    exception is logged as a warning and ``{}`` is returned instead.
    """
    try:
        completion = client.chat.completions.create(
            model=GROQ_MODEL,
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
            timeout=30,
        )
        reply_text = completion.choices[0].message.content
        return _parse(reply_text)
    except Exception as exc:
        logger.warning("Groq: %s", exc)
        return {}
def _mistral(prompt, key):
if not key: return {}
try:
r = requests.post("https://api.mistral.ai/v1/chat/completions",
headers={"Authorization":f"Bearer {key}","Content-Type":"application/json"},
json={"model":MISTRAL_MODEL,"messages":[{"role":"user","content":prompt}],
"temperature":0}, timeout=30)
return _parse(r.json()["choices"][0]["message"]["content"])
except Exception as e: logger.warning("Mistral: %s", e); return {}
def _gemini(prompt, key):
if not key: return {}
model = "gemini-2.5-flash"
for attempt in range(3):
try:
r = requests.post(
f"https://generativelanguage.googleapis.com/v1beta/models/"
f"{model}:generateContent?key={key}",
headers={"Content-Type":"application/json"},
json={"contents":[{"parts":[{"text":prompt}]}],
"generationConfig":{"temperature":0}}, timeout=60)
d = r.json()
if "candidates" not in d:
err = d.get("error",{})
msg = err.get("message","") if isinstance(err,dict) else str(err)
if "quota" in msg.lower() or "rate" in msg.lower():
wait = min(40, 10*(attempt+1))
logger.warning("Gemini rate-limited, waiting %ds…", wait)
time.sleep(wait); continue
logger.warning("Gemini attempt %d: %s", attempt+1, msg); return {}
return _parse(d["candidates"][0]["content"]["parts"][0]["text"])
except Exception as e:
logger.warning("Gemini attempt %d: %s", attempt+1, e); time.sleep(5)
return {}
# ============================================================================
# ORIGINAL PROMPTS (unchanged)
# ============================================================================
def _label_prompt(keyphrases, rep_docs):
kp = ", ".join(k[0] if isinstance(k,tuple) else k for k in keyphrases[:5])
ab = " | ".join(a[:250] for a in rep_docs[:3])
return f"""You are a research topic classifier.
A SPECTER-2 + HDBSCAN pipeline produced a topic cluster.
KEYPHRASES: {kp}
REPRESENTATIVE ABSTRACTS: {ab}
Return ONLY valid JSON:
{{
"label": "<5-8 word topic label>",
"description": "<one sentence description>",
"pacis_match": "<closest PAJAIS 2019 category, or NOVEL if none>",
"confidence": <0.0-1.0>
}}"""
def _defence_prompt(keyphrases, rep_docs, votes):
kp = ", ".join(k[0] if isinstance(k,tuple) else k for k in keyphrases[:5])
v_str = "\n".join(f" LLM{i+1}: {v.get('label','?')}" for i,v in enumerate(votes))
return f"""Resolve this labelling disagreement.
KEYPHRASES: {kp}
Votes:\n{v_str}
Pick the best label or synthesise a better one.
Return ONLY JSON: {{"label":"...","description":"...","pacis_match":"...","confidence":0.0}}"""
def _methodology_prompt(label: str, rep_docs: list[str], regex_summary: str) -> str:
ab = "\n\n".join(f"Paper {i+1}: {d[:500]}" for i,d in enumerate(rep_docs[:3]))
return f"""You are a research methodology auditor for the cluster: "{label}".
REGEX PRE-SCAN (pattern-matched evidence β€” treat as ground truth hints):
{regex_summary}
ABSTRACTS:
{ab}
Your task:
1. Confirm or correct the regex findings using the full abstract context.
2. Assign percentage to each confirmed methodology and technique across 3 papers.
3. Only list items directly supported by the abstracts β€” do NOT hallucinate.
4. empirical_pct + theoretical_pct + mixed_pct must sum to 100.
5. For each item, include a brief evidence quote (≀15 words) from the abstract.
6. In regex_confirmed list each regex label you verified; in regex_rejected list those you disprove.
Return ONLY valid JSON:
{{
"methodologies": [
{{"name":"<methodology>","papers":[1,2,3],"pct":<0-100>,"evidence":"<≀15 word quote>"}}
],
"techniques": [
{{"name":"<technique>","papers":[1,2,3],"pct":<0-100>,"evidence":"<≀15 word quote>"}}
],
"dominant_method": "<single most common>",
"dominant_technique": "<single most common>",
"empirical_pct": <0-100>,
"theoretical_pct": <0-100>,
"mixed_pct": <0-100>,
"regex_confirmed": ["<label1>"],
"regex_rejected": ["<label2>"]
}}"""
def _critic_prompt(label, description, keyphrases, rep_docs):
kp = ", ".join(k[0] if isinstance(k,tuple) else k for k in keyphrases[:5])
ab = " | ".join(d[:300] for d in rep_docs[:3])
return f"""You are a strict quality auditor for research topic labels.
CURRENT LABEL: "{label}"
CURRENT DESCRIPTION: "{description}"
KEYPHRASES: {kp}
REPRESENTATIVE ABSTRACTS: {ab}
Audit for: hallucination, vagueness, keyphrase alignment, specificity.
Return ONLY valid JSON:
{{
"refined_label": "<improved 5-8 word label>",
"refined_description": "<one sentence>",
"hallucination_detected": true/false,
"issues": ["<issue1>"],
"improvement_score": <0.0-1.0>,
"confidence": <0.0-1.0>
}}"""
# ============================================================================
# NEW: COMPUTATIONAL TECHNIQUE PROMPTS
# ============================================================================
def _comp_technique_batch_prompt(papers: list[dict], regex_hint: str) -> str:
"""
Prompt fed to each LLM for a batch of methodology-CSV papers.
Papers have keys: paper_idx, title, journal, methodology (text).
regex_hint is the pre-scanned regex evidence for this batch.
"""
batch_text = "\n\n".join(
f"PAPER {p['paper_idx']} [{p['journal']}] β€” {p['title'][:100]}\n"
f"METHODOLOGY TEXT: {p['methodology'][:800]}"
for p in papers
)
paper_ids = [p['paper_idx'] for p in papers]
return f"""You are a computational technique extractor for IS research papers.
REGEX PRE-SCAN (ground truth hints from pattern matching):
{regex_hint}
PAPERS:
{batch_text}
For EACH paper listed above ({paper_ids}), identify the computational techniques used.
A computational technique must be explicitly mentioned or clearly implied in the text.
Do NOT hallucinate β€” if a paper uses no computational technique, return empty list.
Also for each technique found across ALL papers, compute what percentage of papers in this
batch use that technique.
Return ONLY valid JSON:
{{
"per_paper": {{
"<paper_idx>": {{
"techniques": ["<technique1>", "<technique2>"],
"evidence": ["<≀12 word quote1>", "<≀12 word quote2>"],
"confidence": <0.0-1.0>
}}
}},
"batch_technique_pct": {{
"<technique_name>": <percentage_of_papers_in_batch_0-100>
}},
"dominant_technique": "<most common technique in batch>",
"no_technique_papers": [<paper_idxs with no clear computational technique>]
}}"""
def _technique_critique_prompt(technique: str, journal: str, pct_groq: float,
pct_mistral: float, pct_gemini: float,
evidence_samples: list[str]) -> str:
"""Optimization critic for a single consolidated technique label."""
ev = " | ".join(evidence_samples[:3])
return f"""You are a research technique label auditor.
TECHNIQUE: "{technique}"
JOURNAL: {journal}
GROQ extracted it in {pct_groq:.0f}% of papers
MISTRAL extracted it in {pct_mistral:.0f}% of papers
GEMINI extracted it in {pct_gemini:.0f}% of papers
EVIDENCE QUOTES: {ev}
Audit:
1. Is the technique name precise and not hallucinated?
2. Is there inter-LLM disagreement (>15% gap) suggesting ambiguity?
3. Should this be split into sub-techniques or merged with another?
4. Suggest a refined canonical name if needed.
Return ONLY valid JSON:
{{
"refined_name": "<canonical technique name or same if fine>",
"is_hallucination": true/false,
"high_variance_across_llms": true/false,
"suggestion": "<one sentence improvement recommendation>",
"split_into": ["<sub-tech1>", "<sub-tech2>"],
"merge_with": "<other technique name or null>",
"confidence": <0.0-1.0>
}}"""
# ============================================================================
# CONSOLIDATION HELPERS (original + new)
# ============================================================================
def _consolidate_methodology(r1: dict, r2: dict, r3: dict, regex_scan: dict) -> dict:
"""Merge Groq + Mistral + Gemini methodology responses. β‰₯2 LLM gate."""
def _name_map(r, key):
return {item["name"].strip().lower(): item for item in r.get(key, [])}
def _merge_items(key):
maps = [_name_map(r, key) for r in [r1, r2, r3]]
all_keys = set().union(*[m.keys() for m in maps])
accepted, rejected = [], []
for k in all_keys:
voters = [m[k] for m in maps if k in m]
n_votes = len(voters)
avg_pct = round(sum(v.get("pct",0) for v in voters) / n_votes)
papers = sorted({p for v in voters for p in v.get("papers",[])})
evidence= next((v.get("evidence","") for v in voters if v.get("evidence")), "")
row = {"name": voters[0]["name"], "pct": avg_pct, "papers": papers,
"evidence": evidence, "llm_votes": n_votes,
"agreement": "Triple" if n_votes==3 else "Two" if n_votes==2 else "Single"}
(accepted if n_votes >= 2 else rejected).append(row)
return (sorted(accepted, key=lambda x: -x["pct"]),
sorted(rejected, key=lambda x: -x["pct"]))
methods_acc, methods_rej = _merge_items("methodologies")
techniques_acc, techniques_rej = _merge_items("techniques")
emp_avg = round(sum(r.get("empirical_pct", 0) for r in [r1,r2,r3]) / 3)
theo_avg = round(sum(r.get("theoretical_pct",0) for r in [r1,r2,r3]) / 3)
mix_avg = round(sum(r.get("mixed_pct", 0) for r in [r1,r2,r3]) / 3)
confirmed_votes = Counter(item for r in [r1,r2,r3] for item in r.get("regex_confirmed",[]))
rejected_votes = Counter(item for r in [r1,r2,r3] for item in r.get("regex_rejected",[]))
dom_m = Counter(r.get("dominant_method","") for r in [r1,r2,r3] if r).most_common(1)
dom_t = Counter(r.get("dominant_technique","") for r in [r1,r2,r3] if r).most_common(1)
return {
"methodologies": methods_acc, "techniques": techniques_acc,
"rejected_methods": methods_rej, "rejected_techniques": techniques_rej,
"dominant_method": dom_m[0][0] if dom_m else "β€”",
"dominant_technique": dom_t[0][0] if dom_t else "β€”",
"empirical_pct": emp_avg, "theoretical_pct": theo_avg, "mixed_pct": mix_avg,
"regex_confirmed_consensus": [k for k,v in confirmed_votes.items() if v>=2],
"regex_rejected_consensus": [k for k,v in rejected_votes.items() if v>=2],
"llm_raw": {"groq": r1, "mistral": r2, "gemini": r3},
"regex_scan": regex_scan,
}
def _consolidate_comp_techniques(r1: dict, r2: dict, r3: dict,
papers: list[dict]) -> dict:
"""
Consolidate per-paper technique extraction from 3 LLMs.
Rule: a technique is accepted for a paper when β‰₯2 LLMs named it.
Builds per-LLM technique % and consolidated %.
"""
all_paper_ids = [str(p["paper_idx"]) for p in papers]
def _get_per_paper(resp):
return resp.get("per_paper", {})
def _get_batch_pct(resp):
return resp.get("batch_technique_pct", {})
# Per-LLM batch percentages (for LLM sheets)
pct_groq = {k.lower(): v for k,v in _get_batch_pct(r1).items()}
pct_mistral = {k.lower(): v for k,v in _get_batch_pct(r2).items()}
pct_gemini = {k.lower(): v for k,v in _get_batch_pct(r3).items()}
all_tech_keys = set(pct_groq) | set(pct_mistral) | set(pct_gemini)
# β‰₯2 LLM gate for consolidated batch %
consolidated_pct = {}
for tk in all_tech_keys:
vals = [d[tk] for d in [pct_groq, pct_mistral, pct_gemini] if tk in d]
if len(vals) >= 2:
consolidated_pct[tk] = round(sum(vals) / len(vals))
# Per-paper consolidated techniques (β‰₯2 LLMs must name the technique for that paper)
per_paper_groq = _get_per_paper(r1)
per_paper_mistral = _get_per_paper(r2)
per_paper_gemini = _get_per_paper(r3)
per_paper_consolidated = {}
for pid in all_paper_ids:
techs_groq = set(t.lower() for t in per_paper_groq.get(pid, {}).get("techniques", []))
techs_mistral = set(t.lower() for t in per_paper_mistral.get(pid,{}).get("techniques", []))
techs_gemini = set(t.lower() for t in per_paper_gemini.get(pid, {}).get("techniques", []))
# Union of all named techniques
all_named = techs_groq | techs_mistral | techs_gemini
accepted = [t for t in all_named
if sum([t in techs_groq, t in techs_mistral, t in techs_gemini]) >= 2]
per_paper_consolidated[pid] = accepted
dom_g = r1.get("dominant_technique","β€”")
dom_m = r2.get("dominant_technique","β€”")
dom_gem = r3.get("dominant_technique","β€”")
dominant = Counter([dom_g, dom_m, dom_gem]).most_common(1)
return {
"per_paper_consolidated": per_paper_consolidated,
"consolidated_pct": consolidated_pct,
"pct_groq": pct_groq,
"pct_mistral": pct_mistral,
"pct_gemini": pct_gemini,
"dominant_technique": dominant[0][0] if dominant else "β€”",
"raw": {"groq": r1, "mistral": r2, "gemini": r3},
}
# ============================================================================
# GROUNDING + CLEAN
# ============================================================================
def _grounding(label, keyphrases):
if not label or not keyphrases: return {"verdict":"FAIL","score":0}
lt = set(re.findall(r"\b[a-z]{3,}\b", label.lower()))
kt = set()
for k in keyphrases:
kt.update(re.findall(r"\b[a-z]{3,}\b",
(k[0] if isinstance(k,tuple) else k).lower()))
noise = {"the","and","for","with","using","based","from","that","are","this"}
lt -= noise; kt -= noise
m = list(lt & kt)
return {"verdict":"PASS" if m else "FAIL", "score":len(m)/max(len(lt),1), "matched":m}
def _clean(s):
s = str(s or "").replace("\n"," ").strip()
return s[:60].rsplit(" ",1)[0] if len(s)>60 else s
# ============================================================================
# ORIGINAL NODES (DO NOT CHANGE)
# ============================================================================
def embed_and_cluster(state: PipelineState) -> dict:
    """Run topic modelling on ``state["filepath"]``.

    Calls ``tools.run_topic_modeling`` with the file path and ``n_trials``
    (default 50). Returns ``{"topic_data": ...}`` on success or
    ``{"error": "<message>"}`` when anything raises.
    """
    from tools import run_topic_modeling
    try:
        result = run_topic_modeling(state["filepath"], state.get("n_trials", 50))
    except Exception as exc:
        return {"error": str(exc)}
    return {"topic_data": result}
def llm_council(state: PipelineState) -> dict:
    """Label every topic cluster with a three-LLM vote and write the result sheets.

    For each cluster id in ``state["topic_data"]["keyphrases"]`` the same
    labelling prompt is sent to Groq, Mistral and Gemini. Labels are compared
    after ``_clean`` and lower-casing; if no two agree, a Groq tie-break
    prompt is used. The winning label is checked with ``_grounding``.

    Side effects: writes ``sheet1_groq.csv``, ``sheet2_mistral.csv``,
    ``sheet3_gemini.csv``, ``sheet4_consolidated.csv`` and ``topics.json``
    into the current working directory.

    Returns a dict with ``interpretations``, ``sheets``, ``agreement_rates``,
    ``sheet_paths`` and ``json_path``, or ``{"error": "No topic data"}`` when
    ``topic_data`` is falsy.
    """
    td = state["topic_data"]
    if not td: return {"error": "No topic data"}
    client = Groq(api_key=state["groq_key"], max_retries=0)
    mk, gk = state["mistral_key"], state["gemini_key"]
    # Sheets 1-3 hold the raw per-LLM answers, sheet 4 the consolidated rows.
    sheets = {1:[], 2:[], 3:[], 4:[]}
    interps = {}
    for cid in sorted(td["keyphrases"].keys()):
        kps = td["keyphrases"][cid]
        rds = td["representative_docs"].get(cid, [])
        sw = td["membership"].get(cid, {"strong":0,"weak":0})
        prompt = _label_prompt(kps, rds)
        # The three providers are queried one after another with a pause after each.
        s1 = _groq(client, prompt); time.sleep(1)
        s2 = _mistral(prompt, mk); time.sleep(1)
        s3 = _gemini(prompt, gk); time.sleep(4)
        votes = [s1, s2, s3]
        for sheet_n, resp in [(1,s1),(2,s2),(3,s3)]:
            # Missing fields in a raw answer are recorded with a placeholder.
            sheets[sheet_n].append({"cluster":cid,
                **{k:resp.get(k,"β€”") for k in ["label","description","pacis_match","confidence"]}})
        # Only non-empty answers that carry a "label" take part in the vote.
        valid = [v for v in votes if v and "label" in v]
        labels_l = [_clean(v.get("label","")).lower() for v in valid]
        counts = Counter(labels_l)
        if any(c>=3 for c in counts.values()):
            agreement = "Triple"
            winner = max(counts, key=counts.get)
            best = next(v for v in valid if _clean(v["label"]).lower()==winner)
        elif any(c>=2 for c in counts.values()):
            agreement = "Two"
            winner = max(counts, key=counts.get)
            best = next(v for v in valid if _clean(v["label"]).lower()==winner)
        else:
            # No two labels agree: ask Groq to arbitrate, else fall back to
            # the first valid vote (or an empty dict when there is none).
            agreement = "Single"
            d = _groq(client, _defence_prompt(kps, rds, votes))
            best = d if d and "label" in d else (valid[0] if valid else {})
        label = _clean(best.get("label",""))
        gc = _grounding(label, kps)
        if gc["verdict"]=="FAIL" and valid:
            # Ungrounded label: use the first valid vote's label instead. The
            # recorded grounding verdict below still refers to the earlier label.
            label = _clean(valid[0].get("label",""))
        cp = td.get("cluster_persistence",{}).get(cid, 0.0)
        sheets[4].append({"cluster":cid,"label":label,"agreement":agreement,
            "description":best.get("description",""),
            "pacis_match":best.get("pacis_match",""),
            "strong":sw["strong"],"weak":sw["weak"],
            "persistence":round(cp,4),"grounding":gc["verdict"]})
        interps[cid] = {"label":label,"agreement":agreement,
            "strong":sw["strong"],"weak":sw["weak"],"persistence":cp,
            "description":best.get("description",""),
            "pacis_match":best.get("pacis_match",""),
            "keyphrases":[k[0] if isinstance(k,tuple) else k for k in kps[:5]]}
        logger.info("Cluster %d β†’ %s [%s]", cid, label, agreement)
    # Agreement statistics as whole percentages of the consolidated rows.
    total = len(sheets[4]) or 1
    n_triple = sum(1 for r in sheets[4] if r.get("agreement")=="Triple")
    n_two = sum(1 for r in sheets[4] if r.get("agreement")=="Two")
    rates = {"triple": round(n_triple/total*100),
        "two_or_more": round((n_triple+n_two)/total*100),
        "single": round((total-n_triple-n_two)/total*100)}
    names = {1:"sheet1_groq",2:"sheet2_mistral",3:"sheet3_gemini",4:"sheet4_consolidated"}
    sheet_paths = {}
    for sn, name in names.items():
        path = f"{name}.csv"
        pd.DataFrame(sheets[sn]).to_csv(path, index=False)
        sheet_paths[sn] = path
    with open("topics.json","w") as f: json.dump(sheets[4], f, indent=2)
    return {"interpretations":interps,"sheets":sheets,
        "agreement_rates":rates,"sheet_paths":sheet_paths,"json_path":"topics.json"}
def optimization_loop(state: PipelineState) -> dict:
    """Refine cluster labels with repeated Groq critic passes.

    Runs ``n_optimize - 1`` iterations (none when ``n_optimize`` is 1 or
    less). In each iteration every cluster's current label is audited; a
    refined label replaces it only when it differs, the critic reports an
    improvement score above 0.15 or a hallucination, and the new label
    passes ``_grounding``.

    The ``interpretations`` dict and the sheet-4 rows taken from ``state``
    are updated in place.

    Returns ``interpretations``, ``sheets`` and ``refinement_log`` (one entry
    per accepted change), or only an empty ``refinement_log`` when no
    iteration is requested.
    """
    n_opt = state.get("n_optimize", 1)
    if n_opt <= 1:
        return {"refinement_log": []}
    client = Groq(api_key=state["groq_key"], max_retries=0)
    interps = state.get("interpretations", {})
    td = state["topic_data"]
    sheets = state.get("sheets", {})
    refinement_log = []
    for iteration in range(n_opt - 1):
        # Iterations are reported starting at 2.
        iter_num = iteration + 2
        logger.info("Optimization iteration %d / %d", iter_num, n_opt)
        for cid in sorted(interps.keys()):
            kps = td["keyphrases"].get(cid, [])
            rds = td["representative_docs"].get(cid, [])
            current_label = interps[cid]["label"]
            current_desc = interps[cid].get("description","")
            audit = _groq(client, _critic_prompt(current_label, current_desc, kps, rds))
            time.sleep(0.8)
            # An empty audit (failed call or unparsable reply) leaves the label as is.
            if not audit: continue
            improvement = audit.get("improvement_score", 0.0)
            hallucinated = audit.get("hallucination_detected", False)
            new_label = _clean(audit.get("refined_label", current_label))
            new_desc = audit.get("refined_description", current_desc)
            changed = (new_label.lower() != current_label.lower()) and (
                improvement > 0.15 or hallucinated)
            if changed and _grounding(new_label, kps)["verdict"] == "PASS":
                refinement_log.append({
                    "cluster": cid, "iteration": iter_num,
                    "old_label": current_label, "new_label": new_label,
                    "issues": audit.get("issues",[]),
                    "improvement_score": round(improvement,3),
                    "hallucination_detected": hallucinated,
                })
                interps[cid]["label"] = new_label
                interps[cid]["description"] = new_desc
                logger.info(" C%d refined: '%s' β†’ '%s'", cid, current_label, new_label)
    # Copy the final labels into the consolidated sheet rows (same dict objects).
    label_map = {v["cluster"]: v for v in sheets.get(4,[])}
    for cid, interp in interps.items():
        if cid in label_map:
            label_map[cid]["label"] = interp["label"]
    return {"interpretations":interps,"sheets":sheets,"refinement_log":refinement_log}
def extract_methodology(state: PipelineState) -> dict:
    """Run the three-LLM methodology council for every topic cluster.

    Per cluster: regex pre-scan of the representative documents, one
    methodology prompt sent to Groq, Mistral and Gemini in turn (with pauses
    of 1 s, 1 s and 4 s), then consolidation of the three answers.

    Returns ``{"methodology_data": {cluster_id: consolidated_result}}``.
    """
    topic_data = state["topic_data"]
    labelled = state.get("interpretations", {})
    groq_client = Groq(api_key=state["groq_key"], max_retries=0)
    mistral_key, gemini_key = state["mistral_key"], state["gemini_key"]
    per_cluster = {}
    cluster_ids = sorted(topic_data["keyphrases"].keys())
    for cid in cluster_ids:
        docs = topic_data["representative_docs"].get(cid, [])
        cluster_label = labelled.get(cid, {}).get("label", f"Cluster {cid}")
        scan = _regex_scan(docs)
        hint = _regex_summary(scan)
        logger.info("Cluster %d regex: %d method hits, %d technique hits",
                    cid, len(scan["methods"]), len(scan["techniques"]))
        prompt = _methodology_prompt(cluster_label, docs, hint)
        groq_reply = _groq(groq_client, prompt)
        time.sleep(1)
        mistral_reply = _mistral(prompt, mistral_key)
        time.sleep(1)
        gemini_reply = _gemini(prompt, gemini_key)
        time.sleep(4)
        merged = _consolidate_methodology(groq_reply, mistral_reply, gemini_reply, scan)
        per_cluster[cid] = merged
        logger.info("Cluster %d β†’ dom_method: %s | dom_tech: %s",
                    cid, merged["dominant_method"], merged["dominant_technique"])
    return {"methodology_data": per_cluster}
def collect_top_papers(state: PipelineState) -> dict:
    """Collect up to three representative papers for every labelled cluster.

    Each representative document is split into a title (text before the
    first ``". "``, capped at 120 characters) and an abstract snippet (the
    remainder with leading/trailing dots and spaces removed, capped at 400).

    Returns ``{"top_papers": {cluster_id: [paper_dict, ...]}}``.
    """
    topic_data = state["topic_data"]
    labelled = state.get("interpretations", {})
    collected = {}
    for cid in sorted(labelled.keys()):
        docs = topic_data["representative_docs"].get(cid, [])
        cluster_label = labelled.get(cid, {}).get("label", f"Cluster {cid}")
        entries = []
        for rank, doc in enumerate(docs[:3], start=1):
            if ". " in doc:
                title_part = doc.split(". ")[0][:120]
            else:
                title_part = doc[:120]
            remainder = doc[len(title_part):]
            entries.append({
                "rank": rank,
                "title": title_part,
                "abstract_snippet": remainder.strip(". ")[:400],
                "cluster": cid,
                "cluster_label": cluster_label,
            })
        collected[cid] = entries
    return {"top_papers": collected}
def build_mismatch(state: PipelineState) -> dict:
    """Build the keyphrase/label mismatch table via ``tools.build_mismatch_table``.

    Passes the cluster keyphrases and a ``{cluster_id: label}`` mapping taken
    from the interpretations. Returns ``{"mismatch_table": ...}``.
    """
    from tools import build_mismatch_table
    topic_data = state["topic_data"]
    labelled = state.get("interpretations", {})
    labels_by_cluster = {}
    for cid, interpretation in labelled.items():
        labels_by_cluster[cid] = interpretation["label"]
    table = build_mismatch_table(topic_data["keyphrases"], labels_by_cluster)
    return {"mismatch_table": table}
# ============================================================================
# NEW NODE 1: load_methodology_corpus
# ============================================================================
def load_methodology_corpus(state: PipelineState) -> dict:
    """Load the methodology CSV and tag each paper with a detected journal.

    The CSV must contain ``title`` and ``methodology`` columns (header names
    are stripped and lower-cased); ``doi`` is optional and defaults to
    ``"N/A"``. Papers are numbered from 1 in file order.

    Empty CSV cells are read by pandas as NaN, which is truthy, so a plain
    ``value or ""`` would turn them into the literal text "nan". Missing
    cells are therefore detected with ``pd.isna`` and replaced by the
    field's default.

    Returns ``{"methodology_papers": [...]}``; the list is empty when no
    path is configured or a required column is missing.
    """
    fpath = state.get("methodology_filepath")
    if not fpath:
        logger.info("No methodology CSV provided β€” skipping methodology pipeline.")
        return {"methodology_papers": []}
    df = pd.read_csv(fpath)
    df.columns = df.columns.str.strip().str.lower()
    required = {"title", "methodology"}
    missing = required - set(df.columns)
    if missing:
        logger.warning("Methodology CSV missing columns: %s β€” skipping.", missing)
        return {"methodology_papers": []}
    if "doi" not in df.columns:
        df["doi"] = "N/A"

    def _cell(value, default):
        # Strings are never "missing" for pd.isna purposes; only test others.
        is_missing = not isinstance(value, str) and pd.isna(value)
        return default if is_missing or not value else str(value)

    papers = []
    for paper_idx, (_, row) in enumerate(df.iterrows(), start=1):
        title = _cell(row.get("title", ""), "")
        doi = _cell(row.get("doi", "N/A"), "N/A")
        methodology = _cell(row.get("methodology", ""), "")
        papers.append({
            "paper_idx": paper_idx,
            "title": title,
            "doi": doi,
            "methodology": methodology,
            "journal": _detect_journal(doi, title),
        })
    journals_found = Counter(p["journal"] for p in papers)
    logger.info("Loaded %d methodology papers. Journals: %s", len(papers), dict(journals_found))
    return {"methodology_papers": papers}
# ============================================================================
# NEW NODE 2: embed_methodology_vectors
# ============================================================================
def embed_methodology_vectors(state: PipelineState) -> dict:
    """Embed the methodology text of every loaded paper.

    Loads the ``allenai/specter2_base`` SentenceTransformer and encodes the
    first 1500 characters of each paper's methodology text in batches of 32.
    The result is converted with ``.tolist()`` before being returned.

    Returns ``{"methodology_embeddings": [...]}``; the list is empty when no
    methodology papers are present in the state.
    """
    corpus = state.get("methodology_papers", [])
    if not corpus:
        return {"methodology_embeddings": []}
    from sentence_transformers import SentenceTransformer
    # cap at 1500 chars
    snippets = []
    for paper in corpus:
        snippets.append(paper["methodology"][:1500])
    logger.info("Embedding %d methodology texts with SPECTER-2 (separate vector space)…", len(snippets))
    encoder = SentenceTransformer("allenai/specter2_base")
    vectors = encoder.encode(snippets, show_progress_bar=True, batch_size=32)
    logger.info("Methodology embeddings: %s", vectors.shape)
    return {"methodology_embeddings": vectors.tolist()}
# ============================================================================
# NEW NODE 3: extract_comp_techniques (3-LLM Council)
# ============================================================================
def extract_comp_techniques(state: PipelineState) -> dict:
    """
    Run the three-LLM council over the methodology-CSV papers.

    Papers are processed in batches of ``BATCH_SIZE``. For each batch:
      1. ``_regex_scan`` / ``_regex_summary`` build a regex hint from the
         (1500-char truncated) methodology texts.
      2. One prompt from ``_comp_technique_batch_prompt`` is sent to Groq,
         Mistral and Gemini.
      3. ``_consolidate_comp_techniques`` merges the three answers. The
         agreement rule lives in that helper (per the original design notes,
         a technique is kept when at least two LLMs report it).

    Four sheets are produced, one row per paper each:
      1 = Groq raw, 2 = Mistral raw, 3 = Gemini raw, 4 = consolidated.

    Side effects:
      * writes ``tech_sheet1_groq.csv`` ... ``tech_sheet4_consolidated.csv``
        to the current working directory;
      * adds a ``consolidated_techniques`` list to every paper dict in place.

    LLM output is not schema-checked, so the raw-sheet rows are built
    defensively: a missing, falsy or malformed response produces placeholder
    cells for that paper instead of aborting the whole run.

    Returns:
        ``{"comp_technique_sheets": {1..4: rows}, "methodology_papers": papers}``,
        or only empty sheets when there are no papers.
    """
    papers = state.get("methodology_papers", [])
    if not papers:
        return {"comp_technique_sheets": {1: [], 2: [], 3: [], 4: []}}

    client = Groq(api_key=state["groq_key"], max_retries=0)
    mk, gk = state["mistral_key"], state["gemini_key"]
    BATCH_SIZE = 5
    sheets = {1: [], 2: [], 3: [], 4: []}
    # Consolidated techniques accumulated across batches: {paper_idx: [names]}
    all_consolidated = {}

    def _str_items(value) -> list:
        """Coerce an LLM-supplied field to a list of non-empty strings."""
        if not value:
            return []
        if isinstance(value, str):
            # A bare string must not be joined character by character.
            return [value]
        if isinstance(value, (list, tuple)):
            return [str(item) for item in value if item]
        return [str(value)]

    def _llm_row(resp, paper: dict) -> dict:
        """Build one raw-sheet row for `paper` from a single LLM response."""
        per_paper = resp.get("per_paper") if isinstance(resp, dict) else None
        entry = per_paper.get(str(paper["paper_idx"])) if isinstance(per_paper, dict) else None
        if not isinstance(entry, dict):
            entry = {}
        return {
            "paper_idx": paper["paper_idx"],
            "title": paper["title"][:80],
            "journal": paper["journal"],
            "techniques": ", ".join(_str_items(entry.get("techniques"))) or "β€”",
            "evidence": " | ".join(_str_items(entry.get("evidence")))[:200] or "β€”",
            "confidence": entry.get("confidence", "β€”"),
        }

    for batch_start in range(0, len(papers), BATCH_SIZE):
        batch = papers[batch_start: batch_start + BATCH_SIZE]
        batch_texts = [p["methodology"][:1500] for p in batch]

        # Step 1: regex pre-scan on the batch
        scan = _regex_scan(batch_texts)
        regex_hint = _regex_summary(scan)
        logger.info("Batch %d-%d | regex: %d tech hits",
                    batch[0]["paper_idx"], batch[-1]["paper_idx"], len(scan["techniques"]))

        # Step 2: three LLM calls, with a pause after each one
        # (presumably spacing for provider rate limits; confirm before tuning).
        prompt = _comp_technique_batch_prompt(batch, regex_hint)
        r1 = _groq(client, prompt)
        time.sleep(1)
        r2 = _mistral(prompt, mk)
        time.sleep(1)
        r3 = _gemini(prompt, gk)
        time.sleep(4)

        # Step 3: consolidate the three answers
        consolidated = _consolidate_comp_techniques(r1, r2, r3, batch)

        # One row per paper in each of the four sheets
        for p in batch:
            sheets[1].append(_llm_row(r1, p))
            sheets[2].append(_llm_row(r2, p))
            sheets[3].append(_llm_row(r3, p))

            # Consolidated results are keyed by the stringified paper index.
            con_techs = consolidated["per_paper_consolidated"].get(str(p["paper_idx"]), [])
            sheets[4].append({
                "paper_idx": p["paper_idx"],
                "title": p["title"][:80],
                "journal": p["journal"],
                "techniques": ", ".join(con_techs) or "β€”",
                "n_techniques": len(con_techs),
                "dominant": consolidated["dominant_technique"],
            })
            all_consolidated[p["paper_idx"]] = con_techs

        logger.info("Batch consolidated dominant: %s", consolidated["dominant_technique"])

    # Save the four sheets as CSV files in the working directory
    sheet_names = {1: "tech_sheet1_groq", 2: "tech_sheet2_mistral",
                   3: "tech_sheet3_gemini", 4: "tech_sheet4_consolidated"}
    for sn, name in sheet_names.items():
        pd.DataFrame(sheets[sn]).to_csv(f"{name}.csv", index=False)

    # Attach the consolidated techniques to each paper for the cross-tab node
    for p in papers:
        p["consolidated_techniques"] = all_consolidated.get(p["paper_idx"], [])

    return {
        "comp_technique_sheets": sheets,
        "methodology_papers": papers,  # updated with consolidated_techniques
    }
# ============================================================================
# NEW NODE 4: build_journal_crosstab
# ============================================================================
def build_journal_crosstab(state: PipelineState) -> dict:
    """
    Build a technique-by-journal cross-tabulation plus per-LLM frequencies.

    For every journal in ``state["methodology_papers"]`` this computes the
    percentage of that journal's papers whose ``consolidated_techniques``
    mention each technique. It also computes, for each raw LLM sheet in
    ``state["comp_technique_sheets"]`` (1 = Groq, 2 = Mistral, 3 = Gemini),
    the percentage of sheet rows mentioning each technique.

    Technique names are normalised with ``strip().title()``. A technique is
    counted at most once per paper (or per sheet row), so a repeated or
    differently-cased name cannot push a percentage above 100.

    Returns:
        ``{"journal_crosstab": {...}}`` with keys ``consolidated``
        (journal -> technique -> rounded percent), ``journals``,
        ``techniques``, ``journal_paper_counts`` and ``per_llm_freq``;
        the inner dict is empty when there are no papers.
    """
    papers = state.get("methodology_papers", [])
    if not papers:
        return {"journal_crosstab": {}}
    sheets = state.get("comp_technique_sheets", {})

    def _distinct(names) -> list:
        """Normalise names and drop blanks/duplicates, keeping first-seen order."""
        cleaned = (str(name).strip().title() for name in names if name)
        return list(dict.fromkeys(label for label in cleaned if label))

    # --- Consolidated cross-tab ---
    journal_tech_counts = defaultdict(Counter)
    journal_paper_counts = Counter()
    for p in papers:
        journal = p["journal"]
        journal_paper_counts[journal] += 1
        journal_tech_counts[journal].update(_distinct(p.get("consolidated_techniques") or []))

    journals = sorted(journal_paper_counts)
    all_techniques = sorted({t for counts in journal_tech_counts.values() for t in counts})
    crosstab = {}
    for journal in journals:
        n = journal_paper_counts[journal] or 1
        tech_counts = journal_tech_counts[journal]
        crosstab[journal] = {
            tech: round(tech_counts.get(tech, 0) / n * 100)
            for tech in all_techniques
        }

    # --- Per-LLM technique frequency across ALL papers ---
    def _llm_tech_freq(sheet_rows: list) -> dict:
        n_papers = len(sheet_rows) or 1
        tech_count = Counter()
        for row in sheet_rows:
            raw = row.get("techniques", "")
            # The second test skips the "no techniques" placeholder cell.
            if not raw or raw == "β€”":
                continue
            tech_count.update(_distinct(raw.split(", ")))
        return {t: round(c / n_papers * 100) for t, c in tech_count.items()}

    per_llm_freq = {
        "Groq": _llm_tech_freq(sheets.get(1, [])),
        "Mistral": _llm_tech_freq(sheets.get(2, [])),
        "Gemini": _llm_tech_freq(sheets.get(3, [])),
    }
    logger.info("Journal crosstab: %d journals Γ— %d techniques",
                len(journals), len(all_techniques))
    return {
        "journal_crosstab": {
            "consolidated": crosstab,
            "journals": journals,
            "techniques": all_techniques,
            "journal_paper_counts": dict(journal_paper_counts),
            "per_llm_freq": per_llm_freq,
        }
    }
# ============================================================================
# NEW NODE 5: optimize_technique_labels
# ============================================================================
def optimize_technique_labels(state: PipelineState) -> dict:
    """
    Critique the consolidated technique labels with a Groq call per technique.

    Runs only when ``state["n_optimize"] > 1`` and the cross-tab lists at
    least one technique. A technique is sent to the critic only if the three
    per-LLM percentages differ by more than 15 points or the highest of them
    is below 20. For those techniques, up to three short evidence snippets
    are sampled from the first 30 methodology texts and passed to
    ``_technique_critique_prompt``.

    Returns:
        ``{"technique_opt_log": [...]}``, one entry per critiqued technique
        carrying the critic's fields (refined name, hallucination flag,
        variance flag, suggestion, split/merge hints, confidence) alongside
        the three per-LLM percentages. Empty when optimisation is disabled.
    """
    n_opt = state.get("n_optimize", 1)
    if n_opt <= 1:
        return {"technique_opt_log": []}
    crosstab_data = state.get("journal_crosstab", {})
    all_techniques = crosstab_data.get("techniques", [])
    if not all_techniques:
        return {"technique_opt_log": []}

    client = Groq(api_key=state["groq_key"], max_retries=0)
    per_llm = crosstab_data.get("per_llm_freq", {})
    papers = state.get("methodology_papers", [])
    opt_log = []

    def _evidence_for(technique: str) -> list[str]:
        """Return up to 3 context snippets whose regex hit relates to `technique`."""
        tech_lower = technique.lower()
        samples = []
        for p in papers[:30]:  # cap at first 30 papers for speed
            text = p.get("methodology", "")
            for pat in TECHNIQUE_PATTERNS.values():
                for m in pat.finditer(text):
                    matched = m.group(0).lower()
                    # Keep a hit only when the matched text and the technique
                    # name overlap (either contains the other). Comparing the
                    # technique with itself would accept every hit.
                    if not matched or not (tech_lower in matched or matched in tech_lower):
                        continue
                    snippet = text[max(0, m.start() - 40):m.end() + 40].replace("\n", " ")
                    samples.append(snippet[:120])
                    if len(samples) >= 3:
                        # Enough evidence: stop scanning altogether.
                        return samples
        return samples

    for tech in all_techniques:
        pct_g = per_llm.get("Groq", {}).get(tech, 0)
        pct_m = per_llm.get("Mistral", {}).get(tech, 0)
        pct_gem = per_llm.get("Gemini", {}).get(tech, 0)

        # Only critique when the LLMs disagree noticeably or none of them
        # reports the technique often.
        max_pct = max(pct_g, pct_m, pct_gem)
        min_pct = min(pct_g, pct_m, pct_gem)
        if (max_pct - min_pct) <= 15 and max_pct >= 20:
            continue

        # Gather evidence only for techniques that will actually be critiqued.
        evidence = _evidence_for(tech)
        critique = _groq(client,
                         _technique_critique_prompt(tech, "All Journals", pct_g, pct_m, pct_gem, evidence))
        time.sleep(0.8)
        if not critique:
            continue

        split_into = critique.get("split_into") or []
        if isinstance(split_into, str):
            # A bare string must not be joined character by character.
            split_into = [split_into]
        opt_log.append({
            "technique": tech,
            "refined_name": critique.get("refined_name", tech),
            "is_hallucination": critique.get("is_hallucination", False),
            "high_variance": critique.get("high_variance_across_llms", False),
            "suggestion": critique.get("suggestion", "β€”"),
            "split_into": ", ".join(str(s) for s in split_into) or "β€”",
            "merge_with": critique.get("merge_with", "β€”") or "β€”",
            "pct_groq": pct_g,
            "pct_mistral": pct_m,
            "pct_gemini": pct_gem,
            "confidence": critique.get("confidence", 0),
        })
        logger.info("Technique opt: '%s' β†’ '%s'", tech, critique.get("refined_name", tech))
    return {"technique_opt_log": opt_log}
# ============================================================================
# GRAPH ASSEMBLY
# ============================================================================
def build_graph() -> StateGraph:
    """
    Assemble and compile the pipeline as a single linear chain of nodes.

    The six original nodes come first, followed by the five methodology-CSV
    nodes. Each node has one edge to the next; the last node is connected
    to ``END``. Returns the result of ``StateGraph.compile()``.
    """
    graph = StateGraph(PipelineState)

    # Execution order: every entry feeds the one after it.
    stages = [
        # original nodes
        ("embed_and_cluster", embed_and_cluster),
        ("llm_council", llm_council),
        ("optimization_loop", optimization_loop),
        ("extract_methodology", extract_methodology),
        ("collect_top_papers", collect_top_papers),
        ("build_mismatch", build_mismatch),
        # methodology-CSV nodes (run after the core pipeline)
        ("load_methodology_corpus", load_methodology_corpus),
        ("embed_methodology_vectors", embed_methodology_vectors),
        ("extract_comp_techniques", extract_comp_techniques),
        ("build_journal_crosstab", build_journal_crosstab),
        ("optimize_technique_labels", optimize_technique_labels),
    ]

    for node_name, node_fn in stages:
        graph.add_node(node_name, node_fn)

    order = [node_name for node_name, _ in stages]
    graph.set_entry_point(order[0])
    for source, target in zip(order, order[1:] + [END]):
        graph.add_edge(source, target)

    return graph.compile()
# Compiled once at import time; run_pipeline() invokes this shared instance.
pipeline_graph = build_graph()
def run_pipeline(filepath, groq_key, mistral_key, gemini_key,
                 n_trials=50, n_optimize=1, methodology_filepath=None):
    """
    Invoke the compiled pipeline graph with an initial state built from the arguments.

    Every argument is copied into the state under a key of the same name.
    ``methodology_filepath`` is optional and is passed through as ``None``
    when omitted. Returns whatever ``pipeline_graph.invoke`` returns.
    """
    initial_state = dict(
        filepath=filepath,
        groq_key=groq_key,
        mistral_key=mistral_key,
        gemini_key=gemini_key,
        n_trials=n_trials,
        n_optimize=n_optimize,
        methodology_filepath=methodology_filepath,
    )
    return pipeline_graph.invoke(initial_state)