BHAVIKBANKER committed on
Commit
64049b0
·
verified ·
1 Parent(s): 853e1a5

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +611 -158
agent.py CHANGED
@@ -1,7 +1,12 @@
1
  """
2
  agent.py — LangGraph-based topic analysis agent (§11).
3
- 3-LLM Council for topic modelling + methodology extraction.
4
- Methodology pipeline: Regex pre-scan → Groq → Mistral → Gemini → consolidation.
 
 
 
 
 
5
  """
6
  from __future__ import annotations
7
  import json, logging, os, re, time
@@ -17,9 +22,9 @@ logger = logging.getLogger(__name__)
17
  GROQ_MODEL = "llama-3.1-8b-instant"
18
  MISTRAL_MODEL = "mistral-small-latest"
19
 
20
- # ---------------------------------------------------------------------------
21
- # Regex pattern banks — transparent extraction shown in UI
22
- # ---------------------------------------------------------------------------
23
  METHODOLOGY_PATTERNS = {
24
  "Survey / Systematic Review": re.compile(
25
  r"\b(survey|systematic\s+review|literature\s+review|bibliometric|scoping\s+review|meta.?analysis)\b", re.I),
@@ -66,6 +71,12 @@ TECHNIQUE_PATTERNS = {
66
  r"\b(reinforcement\s+learning|Q.learning|policy\s+gradient|reward\s+function|Markov\s+decision)\b", re.I),
67
  "Cloud / Big Data": re.compile(
68
  r"\b(cloud\s+computing|Hadoop|Spark|MapReduce|big\s+data|distributed\s+computing|edge\s+computing)\b", re.I),
 
 
 
 
 
 
69
  }
70
 
71
  ORIENTATION_PATTERNS = {
@@ -74,13 +85,30 @@ ORIENTATION_PATTERNS = {
74
  "mixed": re.compile(r"\b(mixed\s+method|qualitative.+quantitative|both|triangulat)\b", re.I),
75
  }
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
 
 
 
78
  def _regex_scan(docs: list[str]) -> dict:
79
- """
80
- Run all pattern banks against a list of documents.
81
- Returns hits per category with matched spans so the UI can show exactly
82
- which text triggered each label.
83
- """
84
  method_hits = defaultdict(list)
85
  technique_hits = defaultdict(list)
86
  orientation_counts = {"empirical": 0, "theoretical": 0, "mixed": 0}
@@ -88,18 +116,19 @@ def _regex_scan(docs: list[str]) -> dict:
88
  for doc_idx, doc in enumerate(docs):
89
  for label, pat in METHODOLOGY_PATTERNS.items():
90
  for m in pat.finditer(doc):
91
- method_hits[label].append({"doc": doc_idx + 1, "match": m.group(0),
92
- "span": [m.start(), m.end()]})
 
93
  for label, pat in TECHNIQUE_PATTERNS.items():
94
  for m in pat.finditer(doc):
95
- technique_hits[label].append({"doc": doc_idx + 1, "match": m.group(0),
96
- "span": [m.start(), m.end()]})
 
97
  for orient, pat in ORIENTATION_PATTERNS.items():
98
  if pat.search(doc):
99
  orientation_counts[orient] += 1
100
 
101
  total_orient = sum(orientation_counts.values()) or 1
102
-
103
  return {
104
  "methods": {k: v for k, v in method_hits.items() if v},
105
  "techniques": {k: v for k, v in technique_hits.items() if v},
@@ -111,33 +140,42 @@ def _regex_scan(docs: list[str]) -> dict:
111
  "patterns_applied": {
112
  "methodology": list(METHODOLOGY_PATTERNS.keys()),
113
  "technique": list(TECHNIQUE_PATTERNS.keys()),
114
- "orientation": list(ORIENTATION_PATTERNS.keys()),
115
  },
116
  }
117
 
118
 
119
  def _regex_summary(scan: dict) -> str:
120
- """Human-readable summary of regex hits — injected into LLM prompt as evidence."""
121
  lines = []
122
  if scan["methods"]:
123
  lines.append("REGEX-DETECTED METHODOLOGIES:")
124
  for k, hits in scan["methods"].items():
125
- unique_matches = list(dict.fromkeys(h["match"] for h in hits))[:3]
126
  papers = sorted({h["doc"] for h in hits})
127
- lines.append(f" β€’ {k} β€” matched: {unique_matches} (papers: {papers})")
128
  if scan["techniques"]:
129
  lines.append("REGEX-DETECTED TECHNIQUES:")
130
  for k, hits in scan["techniques"].items():
131
- unique_matches = list(dict.fromkeys(h["match"] for h in hits))[:3]
132
  papers = sorted({h["doc"] for h in hits})
133
- lines.append(f" β€’ {k} β€” matched: {unique_matches} (papers: {papers})")
134
- return "\n".join(lines) or "No regex hits found β€” rely on abstracts alone."
135
 
136
 
137
- # ---------------------------------------------------------------------------
138
- # LangGraph state
139
- # ---------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
140
  class PipelineState(TypedDict, total=False):
 
141
  filepath: str
142
  groq_key: str
143
  mistral_key: str
@@ -154,11 +192,18 @@ class PipelineState(TypedDict, total=False):
154
  refinement_log: list
155
  json_path: str
156
  error: str
157
-
158
-
159
- # ---------------------------------------------------------------------------
160
- # API helpers
161
- # ---------------------------------------------------------------------------
 
 
 
 
 
 
 
162
  def _parse(raw: str) -> dict:
163
  raw = raw.strip().replace("```json","").replace("```","").strip()
164
  s, e = raw.find("{"), raw.rfind("}")+1
@@ -169,7 +214,7 @@ def _parse(raw: str) -> dict:
169
  def _groq(client, prompt):
170
  try:
171
  r = client.chat.completions.create(model=GROQ_MODEL,
172
- messages=[{"role":"user","content":prompt}], temperature=0.2, timeout=15)
173
  return _parse(r.choices[0].message.content)
174
  except Exception as e: logger.warning("Groq: %s", e); return {}
175
 
@@ -179,7 +224,7 @@ def _mistral(prompt, key):
179
  r = requests.post("https://api.mistral.ai/v1/chat/completions",
180
  headers={"Authorization":f"Bearer {key}","Content-Type":"application/json"},
181
  json={"model":MISTRAL_MODEL,"messages":[{"role":"user","content":prompt}],
182
- "temperature":0.2}, timeout=15)
183
  return _parse(r.json()["choices"][0]["message"]["content"])
184
  except Exception as e: logger.warning("Mistral: %s", e); return {}
185
 
@@ -200,7 +245,8 @@ def _gemini(prompt, key):
200
  msg = err.get("message","") if isinstance(err,dict) else str(err)
201
  if "quota" in msg.lower() or "rate" in msg.lower():
202
  wait = min(40, 10*(attempt+1))
203
- logger.warning("Gemini rate-limited, waiting %ds…", wait); time.sleep(wait); continue
 
204
  logger.warning("Gemini attempt %d: %s", attempt+1, msg); return {}
205
  return _parse(d["candidates"][0]["content"]["parts"][0]["text"])
206
  except Exception as e:
@@ -208,9 +254,9 @@ def _gemini(prompt, key):
208
  return {}
209
 
210
 
211
- # ---------------------------------------------------------------------------
212
- # Topic labelling prompts
213
- # ---------------------------------------------------------------------------
214
  def _label_prompt(keyphrases, rep_docs):
215
  kp = ", ".join(k[0] if isinstance(k,tuple) else k for k in keyphrases[:5])
216
  ab = " | ".join(a[:250] for a in rep_docs[:3])
@@ -235,10 +281,6 @@ Votes:\n{v_str}
235
  Pick the best label or synthesise a better one.
236
  Return ONLY JSON: {{"label":"...","description":"...","pacis_match":"...","confidence":0.0}}"""
237
 
238
-
239
- # ---------------------------------------------------------------------------
240
- # Methodology prompt — seeded with regex evidence
241
- # ---------------------------------------------------------------------------
242
  def _methodology_prompt(label: str, rep_docs: list[str], regex_summary: str) -> str:
243
  ab = "\n\n".join(f"Paper {i+1}: {d[:500]}" for i,d in enumerate(rep_docs[:3]))
244
  return f"""You are a research methodology auditor for the cluster: "{label}".
@@ -274,42 +316,131 @@ Return ONLY valid JSON:
274
  "regex_rejected": ["<label2>"]
275
  }}"""
276
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
 
278
- # ---------------------------------------------------------------------------
279
- # Consolidate 3-LLM methodology responses
280
- # ---------------------------------------------------------------------------
281
- def _consolidate_methodology(r1: dict, r2: dict, r3: dict, regex_scan: dict) -> dict:
 
282
  """
283
- Merge Groq + Mistral + Gemini methodology responses.
284
- Rule: a method/technique survives only when ≥2 LLMs named it.
285
- Percentage = average across agreeing LLMs.
286
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  def _name_map(r, key):
288
  return {item["name"].strip().lower(): item for item in r.get(key, [])}
289
 
290
  def _merge_items(key):
291
- maps = [_name_map(r, key) for r in [r1, r2, r3]]
292
  all_keys = set().union(*[m.keys() for m in maps])
293
  accepted, rejected = [], []
294
  for k in all_keys:
295
- voters = [m[k] for m in maps if k in m]
296
- n_votes = len(voters)
297
- avg_pct = round(sum(v.get("pct",0) for v in voters) / n_votes)
298
- papers = sorted({p for v in voters for p in v.get("papers",[])})
299
- evidence = next((v.get("evidence","") for v in voters if v.get("evidence")), "")
300
- row = {
301
- "name": voters[0]["name"],
302
- "pct": avg_pct,
303
- "papers": papers,
304
- "evidence": evidence,
305
- "llm_votes": n_votes,
306
- "agreement": "Triple" if n_votes==3 else "Two" if n_votes==2 else "Single",
307
- }
308
  (accepted if n_votes >= 2 else rejected).append(row)
309
  return (sorted(accepted, key=lambda x: -x["pct"]),
310
  sorted(rejected, key=lambda x: -x["pct"]))
311
 
312
- methods_acc, methods_rej = _merge_items("methodologies")
313
  techniques_acc, techniques_rej = _merge_items("techniques")
314
 
315
  emp_avg = round(sum(r.get("empirical_pct", 0) for r in [r1,r2,r3]) / 3)
@@ -318,20 +449,15 @@ def _consolidate_methodology(r1: dict, r2: dict, r3: dict, regex_scan: dict) ->
318
 
319
  confirmed_votes = Counter(item for r in [r1,r2,r3] for item in r.get("regex_confirmed",[]))
320
  rejected_votes = Counter(item for r in [r1,r2,r3] for item in r.get("regex_rejected",[]))
321
-
322
  dom_m = Counter(r.get("dominant_method","") for r in [r1,r2,r3] if r).most_common(1)
323
  dom_t = Counter(r.get("dominant_technique","") for r in [r1,r2,r3] if r).most_common(1)
324
 
325
  return {
326
- "methodologies": methods_acc,
327
- "techniques": techniques_acc,
328
- "rejected_methods": methods_rej,
329
- "rejected_techniques":techniques_rej,
330
  "dominant_method": dom_m[0][0] if dom_m else "β€”",
331
  "dominant_technique": dom_t[0][0] if dom_t else "β€”",
332
- "empirical_pct": emp_avg,
333
- "theoretical_pct": theo_avg,
334
- "mixed_pct": mix_avg,
335
  "regex_confirmed_consensus": [k for k,v in confirmed_votes.items() if v>=2],
336
  "regex_rejected_consensus": [k for k,v in rejected_votes.items() if v>=2],
337
  "llm_raw": {"groq": r1, "mistral": r2, "gemini": r3},
@@ -339,32 +465,70 @@ def _consolidate_methodology(r1: dict, r2: dict, r3: dict, regex_scan: dict) ->
339
  }
340
 
341
 
342
- # ---------------------------------------------------------------------------
343
- # Critic prompt for optimization loop
344
- # ---------------------------------------------------------------------------
345
- def _critic_prompt(label, description, keyphrases, rep_docs):
346
- kp = ", ".join(k[0] if isinstance(k,tuple) else k for k in keyphrases[:5])
347
- ab = " | ".join(d[:300] for d in rep_docs[:3])
348
- return f"""You are a strict quality auditor for research topic labels.
349
- CURRENT LABEL: "{label}"
350
- CURRENT DESCRIPTION: "{description}"
351
- KEYPHRASES: {kp}
352
- REPRESENTATIVE ABSTRACTS: {ab}
353
- Audit for: hallucination, vagueness, keyphrase alignment, specificity.
354
- Return ONLY valid JSON:
355
- {{
356
- "refined_label": "<improved 5-8 word label>",
357
- "refined_description": "<one sentence>",
358
- "hallucination_detected": true/false,
359
- "issues": ["<issue1>"],
360
- "improvement_score": <0.0-1.0>,
361
- "confidence": <0.0-1.0>
362
- }}"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
 
 
 
 
 
 
 
 
 
 
364
 
365
- # ---------------------------------------------------------------------------
366
- # Grounding check
367
- # ---------------------------------------------------------------------------
 
368
  def _grounding(label, keyphrases):
369
  if not label or not keyphrases: return {"verdict":"FAIL","score":0}
370
  lt = set(re.findall(r"\b[a-z]{3,}\b", label.lower()))
@@ -382,9 +546,9 @@ def _clean(s):
382
  return s[:60].rsplit(" ",1)[0] if len(s)>60 else s
383
 
384
 
385
- # ---------------------------------------------------------------------------
386
- # Node: embed + cluster
387
- # ---------------------------------------------------------------------------
388
  def embed_and_cluster(state: PipelineState) -> dict:
389
  from tools import run_topic_modeling
390
  try:
@@ -394,9 +558,6 @@ def embed_and_cluster(state: PipelineState) -> dict:
394
  return {"error": str(e)}
395
 
396
 
397
- # ---------------------------------------------------------------------------
398
- # Node: LLM Council (labels)
399
- # ---------------------------------------------------------------------------
400
  def llm_council(state: PipelineState) -> dict:
401
  td = state["topic_data"]
402
  if not td: return {"error": "No topic data"}
@@ -454,7 +615,6 @@ def llm_council(state: PipelineState) -> dict:
454
  "description":best.get("description",""),
455
  "pacis_match":best.get("pacis_match",""),
456
  "keyphrases":[k[0] if isinstance(k,tuple) else k for k in kps[:5]]}
457
-
458
  logger.info("Cluster %d β†’ %s [%s]", cid, label, agreement)
459
 
460
  total = len(sheets[4]) or 1
@@ -471,14 +631,10 @@ def llm_council(state: PipelineState) -> dict:
471
  pd.DataFrame(sheets[sn]).to_csv(path, index=False)
472
  sheet_paths[sn] = path
473
  with open("topics.json","w") as f: json.dump(sheets[4], f, indent=2)
474
-
475
  return {"interpretations":interps,"sheets":sheets,
476
  "agreement_rates":rates,"sheet_paths":sheet_paths,"json_path":"topics.json"}
477
 
478
 
479
- # ---------------------------------------------------------------------------
480
- # Node: optimization / refinement loop
481
- # ---------------------------------------------------------------------------
482
  def optimization_loop(state: PipelineState) -> dict:
483
  n_opt = state.get("n_optimize", 1)
484
  if n_opt <= 1:
@@ -493,7 +649,6 @@ def optimization_loop(state: PipelineState) -> dict:
493
  for iteration in range(n_opt - 1):
494
  iter_num = iteration + 2
495
  logger.info("Optimization iteration %d / %d", iter_num, n_opt)
496
-
497
  for cid in sorted(interps.keys()):
498
  kps = td["keyphrases"].get(cid, [])
499
  rds = td["representative_docs"].get(cid, [])
@@ -527,29 +682,11 @@ def optimization_loop(state: PipelineState) -> dict:
527
  for cid, interp in interps.items():
528
  if cid in label_map:
529
  label_map[cid]["label"] = interp["label"]
530
-
531
  return {"interpretations":interps,"sheets":sheets,"refinement_log":refinement_log}
532
 
533
 
534
- # ---------------------------------------------------------------------------
535
- # Node: 3-LLM methodology council + regex pre-scan
536
- # ---------------------------------------------------------------------------
537
  def extract_methodology(state: PipelineState) -> dict:
538
- """
539
- Per cluster:
540
- 1. Run METHODOLOGY_PATTERNS + TECHNIQUE_PATTERNS regex banks against
541
- representative abstracts — produces ground-truth evidence with exact
542
- match spans that are surfaced in the UI.
543
- 2. Build a human-readable regex summary and inject it into the LLM prompt
544
- as grounding evidence.
545
- 3. Call Groq, Mistral, and Gemini with the same prompt — each LLM must
546
- confirm or reject the regex hits, and may add anything it finds in
547
- the full abstract text.
548
- 4. Consolidate: only methods/techniques agreed by ≥2 LLMs survive.
549
- Percentages are averaged across agreeing LLMs.
550
- 5. Store full trace: regex_scan, per-LLM raw responses, consolidation
551
- result — all exposed in the UI extraction pipeline tab.
552
- """
553
  td = state["topic_data"]
554
  interps = state.get("interpretations", {})
555
  client = Groq(api_key=state["groq_key"], max_retries=0)
@@ -559,41 +696,25 @@ def extract_methodology(state: PipelineState) -> dict:
559
  for cid in sorted(td["keyphrases"].keys()):
560
  rds = td["representative_docs"].get(cid, [])
561
  label = interps.get(cid, {}).get("label", f"Cluster {cid}")
562
-
563
- # Step 1 — regex pre-scan
564
  scan = _regex_scan(rds)
565
  regex_hint = _regex_summary(scan)
566
  logger.info("Cluster %d regex: %d method hits, %d technique hits",
567
  cid, len(scan["methods"]), len(scan["techniques"]))
568
-
569
- # Step 2 — all 3 LLMs with regex evidence in prompt
570
  prompt = _methodology_prompt(label, rds, regex_hint)
571
  r1 = _groq(client, prompt); time.sleep(1)
572
  r2 = _mistral(prompt, mk); time.sleep(1)
573
  r3 = _gemini(prompt, gk); time.sleep(4)
574
-
575
- logger.info("Cluster %d methodology votes β€” Groq:%s Mistral:%s Gemini:%s",
576
- cid, bool(r1.get("methodologies")),
577
- bool(r2.get("methodologies")), bool(r3.get("methodologies")))
578
-
579
- # Step 3 — consolidate with ≥2 LLM agreement rule
580
  consolidated = _consolidate_methodology(r1, r2, r3, scan)
581
  methodology_data[cid] = consolidated
582
-
583
  logger.info("Cluster %d β†’ dom_method: %s | dom_tech: %s",
584
  cid, consolidated["dominant_method"], consolidated["dominant_technique"])
585
-
586
  return {"methodology_data": methodology_data}
587
 
588
 
589
- # ---------------------------------------------------------------------------
590
- # Node: top-3 papers
591
- # ---------------------------------------------------------------------------
592
  def collect_top_papers(state: PipelineState) -> dict:
593
  td = state["topic_data"]
594
  interps = state.get("interpretations", {})
595
  top_papers = {}
596
-
597
  for cid in sorted(interps.keys()):
598
  rds = td["representative_docs"].get(cid, [])
599
  label = interps.get(cid, {}).get("label", f"Cluster {cid}")
@@ -601,17 +722,12 @@ def collect_top_papers(state: PipelineState) -> dict:
601
  for rank, doc in enumerate(rds[:3], start=1):
602
  title_part = doc.split(". ")[0][:120] if ". " in doc else doc[:120]
603
  abstract_part = doc[len(title_part):].strip(". ")[:400]
604
- papers.append({"rank":rank,"title":title_part,
605
- "abstract_snippet":abstract_part,
606
  "cluster":cid,"cluster_label":label})
607
  top_papers[cid] = papers
608
-
609
  return {"top_papers": top_papers}
610
 
611
 
612
- # ---------------------------------------------------------------------------
613
- # Node: mismatch table
614
- # ---------------------------------------------------------------------------
615
  def build_mismatch(state: PipelineState) -> dict:
616
  from tools import build_mismatch_table
617
  td = state["topic_data"]
@@ -620,35 +736,372 @@ def build_mismatch(state: PipelineState) -> dict:
620
  return {"mismatch_table": build_mismatch_table(td["keyphrases"], labels_map)}
621
 
622
 
623
- # ---------------------------------------------------------------------------
624
- # Build LangGraph
625
- # ---------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
626
  def build_graph() -> StateGraph:
627
  g = StateGraph(PipelineState)
 
 
628
  g.add_node("embed_and_cluster", embed_and_cluster)
629
  g.add_node("llm_council", llm_council)
630
  g.add_node("optimization_loop", optimization_loop)
631
  g.add_node("extract_methodology", extract_methodology)
632
  g.add_node("collect_top_papers", collect_top_papers)
633
  g.add_node("build_mismatch", build_mismatch)
 
 
 
 
 
 
 
 
 
634
  g.set_entry_point("embed_and_cluster")
635
  g.add_edge("embed_and_cluster", "llm_council")
636
  g.add_edge("llm_council", "optimization_loop")
637
  g.add_edge("optimization_loop", "extract_methodology")
638
  g.add_edge("extract_methodology", "collect_top_papers")
639
  g.add_edge("collect_top_papers", "build_mismatch")
640
- g.add_edge("build_mismatch", END)
 
 
 
 
 
 
 
 
641
  return g.compile()
642
 
 
643
  pipeline_graph = build_graph()
644
 
 
645
  def run_pipeline(filepath, groq_key, mistral_key, gemini_key,
646
- n_trials=50, n_optimize=1):
 
647
  return pipeline_graph.invoke({
648
- "filepath": filepath,
649
- "groq_key": groq_key,
650
- "mistral_key": mistral_key,
651
- "gemini_key": gemini_key,
652
- "n_trials": n_trials,
653
- "n_optimize": n_optimize,
 
654
  })
 
1
  """
2
  agent.py — LangGraph-based topic analysis agent (§11).
3
+ Original 3-LLM Council for topic modelling is UNCHANGED.
4
+ NEW nodes appended:
5
+ - load_methodology_corpus : load methodology CSV, detect journal per paper
6
+ - embed_methodology_vectors : SPECTER-2 embed methodology text (separate vector space)
7
+ - extract_comp_techniques : 3-LLM council (regex → Groq → Mistral → Gemini → consolidate)
8
+ - build_journal_crosstab : technique × journal cross-tabulation with percentages
9
+ - optimize_technique_labels : improvement / hallucination critique on consolidated techniques
10
  """
11
  from __future__ import annotations
12
  import json, logging, os, re, time
 
22
  GROQ_MODEL = "llama-3.1-8b-instant"
23
  MISTRAL_MODEL = "mistral-small-latest"
24
 
25
+ # ============================================================================
26
+ # REGEX BANKS (used in both cluster methodology AND methodology-CSV pipeline)
27
+ # ============================================================================
28
  METHODOLOGY_PATTERNS = {
29
  "Survey / Systematic Review": re.compile(
30
  r"\b(survey|systematic\s+review|literature\s+review|bibliometric|scoping\s+review|meta.?analysis)\b", re.I),
 
71
  r"\b(reinforcement\s+learning|Q.learning|policy\s+gradient|reward\s+function|Markov\s+decision)\b", re.I),
72
  "Cloud / Big Data": re.compile(
73
  r"\b(cloud\s+computing|Hadoop|Spark|MapReduce|big\s+data|distributed\s+computing|edge\s+computing)\b", re.I),
74
+ "Structural Equation Modelling": re.compile(
75
+ r"\b(structural\s+equation|SEM|PLS.SEM|covariance.based|CB.SEM|partial\s+least\s+squares)\b", re.I),
76
+ "Time Series / VAR": re.compile(
77
+ r"\b(time\s+series|VAR\b|vector\s+auto.?regression|VARX|ARIMA|impulse\s+response|Granger)\b", re.I),
78
+ "Content Analysis / Coding": re.compile(
79
+ r"\b(content\s+analysis|coding\s+scheme|thematic\s+analys|grounded\s+theory|open\s+coding|axial\s+coding)\b", re.I),
80
  }
81
 
82
  ORIENTATION_PATTERNS = {
 
85
  "mixed": re.compile(r"\b(mixed\s+method|qualitative.+quantitative|both|triangulat)\b", re.I),
86
  }
87
 
88
+ # Journal detection patterns applied to DOI + title
89
+ JOURNAL_PATTERNS = {
90
+ "MISQ": re.compile(
91
+ r"(misq|mis\s*quarterly|10\.25300|10\.2307/[0-9]{8}|MIS\s+Quarterly)", re.I),
92
+ "JAIS": re.compile(
93
+ r"(jais|10\.17705/1jais|journal.*association.*information\s+systems)", re.I),
94
+ "ISR": re.compile(
95
+ r"(10\.1287/isre|\bisr\b|information\s+systems\s+research)", re.I),
96
+ "JMIS": re.compile(
97
+ r"(10\.1080/07421222|jmis|journal.*management.*information\s+systems)", re.I),
98
+ "PAJAIS": re.compile(
99
+ r"(pajais|pacific.*asia.*information|10\.17705/2asfp)", re.I),
100
+ "ECIS": re.compile(
101
+ r"(ecis|european.*conference.*information\s+systems)", re.I),
102
+ "ICIS": re.compile(
103
+ r"(icis|international.*conference.*information\s+systems)", re.I),
104
+ }
105
+
106
 
107
+ # ============================================================================
108
+ # SHARED REGEX HELPERS
109
+ # ============================================================================
110
  def _regex_scan(docs: list[str]) -> dict:
111
+ """Run pattern banks against docs. Returns hit dicts with exact match spans."""
 
 
 
 
112
  method_hits = defaultdict(list)
113
  technique_hits = defaultdict(list)
114
  orientation_counts = {"empirical": 0, "theoretical": 0, "mixed": 0}
 
116
  for doc_idx, doc in enumerate(docs):
117
  for label, pat in METHODOLOGY_PATTERNS.items():
118
  for m in pat.finditer(doc):
119
+ method_hits[label].append({
120
+ "doc": doc_idx + 1, "match": m.group(0),
121
+ "span": [m.start(), m.end()]})
122
  for label, pat in TECHNIQUE_PATTERNS.items():
123
  for m in pat.finditer(doc):
124
+ technique_hits[label].append({
125
+ "doc": doc_idx + 1, "match": m.group(0),
126
+ "span": [m.start(), m.end()]})
127
  for orient, pat in ORIENTATION_PATTERNS.items():
128
  if pat.search(doc):
129
  orientation_counts[orient] += 1
130
 
131
  total_orient = sum(orientation_counts.values()) or 1
 
132
  return {
133
  "methods": {k: v for k, v in method_hits.items() if v},
134
  "techniques": {k: v for k, v in technique_hits.items() if v},
 
140
  "patterns_applied": {
141
  "methodology": list(METHODOLOGY_PATTERNS.keys()),
142
  "technique": list(TECHNIQUE_PATTERNS.keys()),
 
143
  },
144
  }
145
 
146
 
147
  def _regex_summary(scan: dict) -> str:
148
+ """Human-readable regex evidence injected into LLM prompts."""
149
  lines = []
150
  if scan["methods"]:
151
  lines.append("REGEX-DETECTED METHODOLOGIES:")
152
  for k, hits in scan["methods"].items():
153
+ unique = list(dict.fromkeys(h["match"] for h in hits))[:3]
154
  papers = sorted({h["doc"] for h in hits})
155
+ lines.append(f" β€’ {k} β€” matched: {unique} (papers: {papers})")
156
  if scan["techniques"]:
157
  lines.append("REGEX-DETECTED TECHNIQUES:")
158
  for k, hits in scan["techniques"].items():
159
+ unique = list(dict.fromkeys(h["match"] for h in hits))[:3]
160
  papers = sorted({h["doc"] for h in hits})
161
+ lines.append(f" β€’ {k} β€” matched: {unique} (papers: {papers})")
162
+ return "\n".join(lines) or "No regex hits found β€” rely on methodology text alone."
163
 
164
 
165
+ def _detect_journal(doi: str, title: str) -> str:
166
+ """Detect journal from DOI + title using JOURNAL_PATTERNS. Returns 'MISQ' (the methodology-CSV default) if no pattern matches."""
167
+ text = f"{doi or ''} {title or ''}"
168
+ for journal, pat in JOURNAL_PATTERNS.items():
169
+ if pat.search(text):
170
+ return journal
171
+ return "MISQ" # methodology CSV default — override downstream if needed
172
+
173
+
174
+ # ============================================================================
175
+ # LANGGRAPH STATE
176
+ # ============================================================================
177
  class PipelineState(TypedDict, total=False):
178
+ # ── original fields (DO NOT CHANGE) ──────────────────────────────────────
179
  filepath: str
180
  groq_key: str
181
  mistral_key: str
 
192
  refinement_log: list
193
  json_path: str
194
  error: str
195
+ # ── new fields for methodology-CSV pipeline ───────────────────────────────
196
+ methodology_filepath: str # uploaded methodology CSV path
197
+ methodology_papers: list # [{title, doi, methodology, journal, paper_idx}]
198
+ methodology_embeddings: list # SPECTER-2 embeddings (separate vector space)
199
+ comp_technique_sheets: dict # {1:Groq, 2:Mistral, 3:Gemini, 4:Consolidated}
200
+ journal_crosstab: dict # {journal: {technique: pct}}
201
+ technique_opt_log: list # improvement suggestions from optimizer
202
+
203
+
204
+ # ============================================================================
205
+ # API HELPERS (unchanged)
206
+ # ============================================================================
207
  def _parse(raw: str) -> dict:
208
  raw = raw.strip().replace("```json","").replace("```","").strip()
209
  s, e = raw.find("{"), raw.rfind("}")+1
 
214
  def _groq(client, prompt):
215
  try:
216
  r = client.chat.completions.create(model=GROQ_MODEL,
217
+ messages=[{"role":"user","content":prompt}], temperature=0.2, timeout=30)
218
  return _parse(r.choices[0].message.content)
219
  except Exception as e: logger.warning("Groq: %s", e); return {}
220
 
 
224
  r = requests.post("https://api.mistral.ai/v1/chat/completions",
225
  headers={"Authorization":f"Bearer {key}","Content-Type":"application/json"},
226
  json={"model":MISTRAL_MODEL,"messages":[{"role":"user","content":prompt}],
227
+ "temperature":0.2}, timeout=30)
228
  return _parse(r.json()["choices"][0]["message"]["content"])
229
  except Exception as e: logger.warning("Mistral: %s", e); return {}
230
 
 
245
  msg = err.get("message","") if isinstance(err,dict) else str(err)
246
  if "quota" in msg.lower() or "rate" in msg.lower():
247
  wait = min(40, 10*(attempt+1))
248
+ logger.warning("Gemini rate-limited, waiting %ds…", wait)
249
+ time.sleep(wait); continue
250
  logger.warning("Gemini attempt %d: %s", attempt+1, msg); return {}
251
  return _parse(d["candidates"][0]["content"]["parts"][0]["text"])
252
  except Exception as e:
 
254
  return {}
255
 
256
 
257
+ # ============================================================================
258
+ # ORIGINAL PROMPTS (unchanged)
259
+ # ============================================================================
260
  def _label_prompt(keyphrases, rep_docs):
261
  kp = ", ".join(k[0] if isinstance(k,tuple) else k for k in keyphrases[:5])
262
  ab = " | ".join(a[:250] for a in rep_docs[:3])
 
281
  Pick the best label or synthesise a better one.
282
  Return ONLY JSON: {{"label":"...","description":"...","pacis_match":"...","confidence":0.0}}"""
283
 
 
 
 
 
284
  def _methodology_prompt(label: str, rep_docs: list[str], regex_summary: str) -> str:
285
  ab = "\n\n".join(f"Paper {i+1}: {d[:500]}" for i,d in enumerate(rep_docs[:3]))
286
  return f"""You are a research methodology auditor for the cluster: "{label}".
 
316
  "regex_rejected": ["<label2>"]
317
  }}"""
318
 
319
+ def _critic_prompt(label, description, keyphrases, rep_docs):
320
+ kp = ", ".join(k[0] if isinstance(k,tuple) else k for k in keyphrases[:5])
321
+ ab = " | ".join(d[:300] for d in rep_docs[:3])
322
+ return f"""You are a strict quality auditor for research topic labels.
323
+ CURRENT LABEL: "{label}"
324
+ CURRENT DESCRIPTION: "{description}"
325
+ KEYPHRASES: {kp}
326
+ REPRESENTATIVE ABSTRACTS: {ab}
327
+ Audit for: hallucination, vagueness, keyphrase alignment, specificity.
328
+ Return ONLY valid JSON:
329
+ {{
330
+ "refined_label": "<improved 5-8 word label>",
331
+ "refined_description": "<one sentence>",
332
+ "hallucination_detected": true/false,
333
+ "issues": ["<issue1>"],
334
+ "improvement_score": <0.0-1.0>,
335
+ "confidence": <0.0-1.0>
336
+ }}"""
337
 
338
+
339
+ # ============================================================================
340
+ # NEW: COMPUTATIONAL TECHNIQUE PROMPTS
341
+ # ============================================================================
342
+ def _comp_technique_batch_prompt(papers: list[dict], regex_hint: str) -> str:
343
  """
344
+ Prompt fed to each LLM for a batch of methodology-CSV papers.
345
+ Papers have keys: paper_idx, title, journal, methodology (text).
346
+ regex_hint is the pre-scanned regex evidence for this batch.
347
  """
348
+ batch_text = "\n\n".join(
349
+ f"PAPER {p['paper_idx']} [{p['journal']}] β€” {p['title'][:100]}\n"
350
+ f"METHODOLOGY TEXT: {p['methodology'][:800]}"
351
+ for p in papers
352
+ )
353
+ paper_ids = [p['paper_idx'] for p in papers]
354
+ return f"""You are a computational technique extractor for IS research papers.
355
+
356
+ REGEX PRE-SCAN (ground truth hints from pattern matching):
357
+ {regex_hint}
358
+
359
+ PAPERS:
360
+ {batch_text}
361
+
362
+ For EACH paper listed above ({paper_ids}), identify the computational techniques used.
363
+ A computational technique must be explicitly mentioned or clearly implied in the text.
364
+ Do NOT hallucinate β€” if a paper uses no computational technique, return empty list.
365
+
366
+ Also for each technique found across ALL papers, compute what percentage of papers in this
367
+ batch use that technique.
368
+
369
+ Return ONLY valid JSON:
370
+ {{
371
+ "per_paper": {{
372
+ "<paper_idx>": {{
373
+ "techniques": ["<technique1>", "<technique2>"],
374
+ "evidence": ["<≀12 word quote1>", "<≀12 word quote2>"],
375
+ "confidence": <0.0-1.0>
376
+ }}
377
+ }},
378
+ "batch_technique_pct": {{
379
+ "<technique_name>": <percentage_of_papers_in_batch_0-100>
380
+ }},
381
+ "dominant_technique": "<most common technique in batch>",
382
+ "no_technique_papers": [<paper_idxs with no clear computational technique>]
383
+ }}"""
384
+
385
+
386
+ def _technique_critique_prompt(technique: str, journal: str, pct_groq: float,
387
+ pct_mistral: float, pct_gemini: float,
388
+ evidence_samples: list[str]) -> str:
389
+ """Optimization critic for a single consolidated technique label."""
390
+ ev = " | ".join(evidence_samples[:3])
391
+ return f"""You are a research technique label auditor.
392
+
393
+ TECHNIQUE: "{technique}"
394
+ JOURNAL: {journal}
395
+ GROQ extracted it in {pct_groq:.0f}% of papers
396
+ MISTRAL extracted it in {pct_mistral:.0f}% of papers
397
+ GEMINI extracted it in {pct_gemini:.0f}% of papers
398
+ EVIDENCE QUOTES: {ev}
399
+
400
+ Audit:
401
+ 1. Is the technique name precise and not hallucinated?
402
+ 2. Is there inter-LLM disagreement (>15% gap) suggesting ambiguity?
403
+ 3. Should this be split into sub-techniques or merged with another?
404
+ 4. Suggest a refined canonical name if needed.
405
+
406
+ Return ONLY valid JSON:
407
+ {{
408
+ "refined_name": "<canonical technique name or same if fine>",
409
+ "is_hallucination": true/false,
410
+ "high_variance_across_llms": true/false,
411
+ "suggestion": "<one sentence improvement recommendation>",
412
+ "split_into": ["<sub-tech1>", "<sub-tech2>"],
413
+ "merge_with": "<other technique name or null>",
414
+ "confidence": <0.0-1.0>
415
+ }}"""
416
+
417
+
418
+ # ============================================================================
419
+ # CONSOLIDATION HELPERS (original + new)
420
+ # ============================================================================
421
+ def _consolidate_methodology(r1: dict, r2: dict, r3: dict, regex_scan: dict) -> dict:
422
+ """Merge Groq + Mistral + Gemini methodology responses. β‰₯2 LLM gate."""
423
  def _name_map(r, key):
424
  return {item["name"].strip().lower(): item for item in r.get(key, [])}
425
 
426
  def _merge_items(key):
427
+ maps = [_name_map(r, key) for r in [r1, r2, r3]]
428
  all_keys = set().union(*[m.keys() for m in maps])
429
  accepted, rejected = [], []
430
  for k in all_keys:
431
+ voters = [m[k] for m in maps if k in m]
432
+ n_votes = len(voters)
433
+ avg_pct = round(sum(v.get("pct",0) for v in voters) / n_votes)
434
+ papers = sorted({p for v in voters for p in v.get("papers",[])})
435
+ evidence= next((v.get("evidence","") for v in voters if v.get("evidence")), "")
436
+ row = {"name": voters[0]["name"], "pct": avg_pct, "papers": papers,
437
+ "evidence": evidence, "llm_votes": n_votes,
438
+ "agreement": "Triple" if n_votes==3 else "Two" if n_votes==2 else "Single"}
 
 
 
 
 
439
  (accepted if n_votes >= 2 else rejected).append(row)
440
  return (sorted(accepted, key=lambda x: -x["pct"]),
441
  sorted(rejected, key=lambda x: -x["pct"]))
442
 
443
+ methods_acc, methods_rej = _merge_items("methodologies")
444
  techniques_acc, techniques_rej = _merge_items("techniques")
445
 
446
  emp_avg = round(sum(r.get("empirical_pct", 0) for r in [r1,r2,r3]) / 3)
 
449
 
450
  confirmed_votes = Counter(item for r in [r1,r2,r3] for item in r.get("regex_confirmed",[]))
451
  rejected_votes = Counter(item for r in [r1,r2,r3] for item in r.get("regex_rejected",[]))
 
452
  dom_m = Counter(r.get("dominant_method","") for r in [r1,r2,r3] if r).most_common(1)
453
  dom_t = Counter(r.get("dominant_technique","") for r in [r1,r2,r3] if r).most_common(1)
454
 
455
  return {
456
+ "methodologies": methods_acc, "techniques": techniques_acc,
457
+ "rejected_methods": methods_rej, "rejected_techniques": techniques_rej,
 
 
458
  "dominant_method": dom_m[0][0] if dom_m else "β€”",
459
  "dominant_technique": dom_t[0][0] if dom_t else "β€”",
460
+ "empirical_pct": emp_avg, "theoretical_pct": theo_avg, "mixed_pct": mix_avg,
 
 
461
  "regex_confirmed_consensus": [k for k,v in confirmed_votes.items() if v>=2],
462
  "regex_rejected_consensus": [k for k,v in rejected_votes.items() if v>=2],
463
  "llm_raw": {"groq": r1, "mistral": r2, "gemini": r3},
 
465
  }
466
 
467
 
468
+ def _consolidate_comp_techniques(r1: dict, r2: dict, r3: dict,
469
+ papers: list[dict]) -> dict:
470
+ """
471
+ Consolidate per-paper technique extraction from 3 LLMs.
472
+ Rule: a technique is accepted for a paper when β‰₯2 LLMs named it.
473
+ Builds per-LLM technique % and consolidated %.
474
+ """
475
+ all_paper_ids = [str(p["paper_idx"]) for p in papers]
476
+
477
+ def _get_per_paper(resp):
478
+ return resp.get("per_paper", {})
479
+
480
+ def _get_batch_pct(resp):
481
+ return resp.get("batch_technique_pct", {})
482
+
483
+ # Per-LLM batch percentages (for LLM sheets)
484
+ pct_groq = {k.lower(): v for k,v in _get_batch_pct(r1).items()}
485
+ pct_mistral = {k.lower(): v for k,v in _get_batch_pct(r2).items()}
486
+ pct_gemini = {k.lower(): v for k,v in _get_batch_pct(r3).items()}
487
+
488
+ all_tech_keys = set(pct_groq) | set(pct_mistral) | set(pct_gemini)
489
+
490
+ # β‰₯2 LLM gate for consolidated batch %
491
+ consolidated_pct = {}
492
+ for tk in all_tech_keys:
493
+ vals = [d[tk] for d in [pct_groq, pct_mistral, pct_gemini] if tk in d]
494
+ if len(vals) >= 2:
495
+ consolidated_pct[tk] = round(sum(vals) / len(vals))
496
+
497
+ # Per-paper consolidated techniques (β‰₯2 LLMs must name the technique for that paper)
498
+ per_paper_groq = _get_per_paper(r1)
499
+ per_paper_mistral = _get_per_paper(r2)
500
+ per_paper_gemini = _get_per_paper(r3)
501
+
502
+ per_paper_consolidated = {}
503
+ for pid in all_paper_ids:
504
+ techs_groq = set(t.lower() for t in per_paper_groq.get(pid, {}).get("techniques", []))
505
+ techs_mistral = set(t.lower() for t in per_paper_mistral.get(pid,{}).get("techniques", []))
506
+ techs_gemini = set(t.lower() for t in per_paper_gemini.get(pid, {}).get("techniques", []))
507
+ # Union of all named techniques
508
+ all_named = techs_groq | techs_mistral | techs_gemini
509
+ accepted = [t for t in all_named
510
+ if sum([t in techs_groq, t in techs_mistral, t in techs_gemini]) >= 2]
511
+ per_paper_consolidated[pid] = accepted
512
+
513
+ dom_g = r1.get("dominant_technique","β€”")
514
+ dom_m = r2.get("dominant_technique","β€”")
515
+ dom_gem = r3.get("dominant_technique","β€”")
516
+ dominant = Counter([dom_g, dom_m, dom_gem]).most_common(1)
517
 
518
+ return {
519
+ "per_paper_consolidated": per_paper_consolidated,
520
+ "consolidated_pct": consolidated_pct,
521
+ "pct_groq": pct_groq,
522
+ "pct_mistral": pct_mistral,
523
+ "pct_gemini": pct_gemini,
524
+ "dominant_technique": dominant[0][0] if dominant else "β€”",
525
+ "raw": {"groq": r1, "mistral": r2, "gemini": r3},
526
+ }
527
 
528
+
529
+ # ============================================================================
530
+ # GROUNDING + CLEAN
531
+ # ============================================================================
532
  def _grounding(label, keyphrases):
533
  if not label or not keyphrases: return {"verdict":"FAIL","score":0}
534
  lt = set(re.findall(r"\b[a-z]{3,}\b", label.lower()))
 
546
  return s[:60].rsplit(" ",1)[0] if len(s)>60 else s
547
 
548
 
549
+ # ============================================================================
550
+ # ORIGINAL NODES (DO NOT CHANGE)
551
+ # ============================================================================
552
  def embed_and_cluster(state: PipelineState) -> dict:
553
  from tools import run_topic_modeling
554
  try:
 
558
  return {"error": str(e)}
559
 
560
 
 
 
 
561
  def llm_council(state: PipelineState) -> dict:
562
  td = state["topic_data"]
563
  if not td: return {"error": "No topic data"}
 
615
  "description":best.get("description",""),
616
  "pacis_match":best.get("pacis_match",""),
617
  "keyphrases":[k[0] if isinstance(k,tuple) else k for k in kps[:5]]}
 
618
  logger.info("Cluster %d β†’ %s [%s]", cid, label, agreement)
619
 
620
  total = len(sheets[4]) or 1
 
631
  pd.DataFrame(sheets[sn]).to_csv(path, index=False)
632
  sheet_paths[sn] = path
633
  with open("topics.json","w") as f: json.dump(sheets[4], f, indent=2)
 
634
  return {"interpretations":interps,"sheets":sheets,
635
  "agreement_rates":rates,"sheet_paths":sheet_paths,"json_path":"topics.json"}
636
 
637
 
 
 
 
638
  def optimization_loop(state: PipelineState) -> dict:
639
  n_opt = state.get("n_optimize", 1)
640
  if n_opt <= 1:
 
649
  for iteration in range(n_opt - 1):
650
  iter_num = iteration + 2
651
  logger.info("Optimization iteration %d / %d", iter_num, n_opt)
 
652
  for cid in sorted(interps.keys()):
653
  kps = td["keyphrases"].get(cid, [])
654
  rds = td["representative_docs"].get(cid, [])
 
682
  for cid, interp in interps.items():
683
  if cid in label_map:
684
  label_map[cid]["label"] = interp["label"]
 
685
  return {"interpretations":interps,"sheets":sheets,"refinement_log":refinement_log}
686
 
687
 
 
 
 
688
  def extract_methodology(state: PipelineState) -> dict:
689
+ """3-LLM council for cluster-level methodology (unchanged logic)."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
690
  td = state["topic_data"]
691
  interps = state.get("interpretations", {})
692
  client = Groq(api_key=state["groq_key"], max_retries=0)
 
696
  for cid in sorted(td["keyphrases"].keys()):
697
  rds = td["representative_docs"].get(cid, [])
698
  label = interps.get(cid, {}).get("label", f"Cluster {cid}")
 
 
699
  scan = _regex_scan(rds)
700
  regex_hint = _regex_summary(scan)
701
  logger.info("Cluster %d regex: %d method hits, %d technique hits",
702
  cid, len(scan["methods"]), len(scan["techniques"]))
 
 
703
  prompt = _methodology_prompt(label, rds, regex_hint)
704
  r1 = _groq(client, prompt); time.sleep(1)
705
  r2 = _mistral(prompt, mk); time.sleep(1)
706
  r3 = _gemini(prompt, gk); time.sleep(4)
 
 
 
 
 
 
707
  consolidated = _consolidate_methodology(r1, r2, r3, scan)
708
  methodology_data[cid] = consolidated
 
709
  logger.info("Cluster %d β†’ dom_method: %s | dom_tech: %s",
710
  cid, consolidated["dominant_method"], consolidated["dominant_technique"])
 
711
  return {"methodology_data": methodology_data}
712
 
713
 
 
 
 
714
  def collect_top_papers(state: PipelineState) -> dict:
715
  td = state["topic_data"]
716
  interps = state.get("interpretations", {})
717
  top_papers = {}
 
718
  for cid in sorted(interps.keys()):
719
  rds = td["representative_docs"].get(cid, [])
720
  label = interps.get(cid, {}).get("label", f"Cluster {cid}")
 
722
  for rank, doc in enumerate(rds[:3], start=1):
723
  title_part = doc.split(". ")[0][:120] if ". " in doc else doc[:120]
724
  abstract_part = doc[len(title_part):].strip(". ")[:400]
725
+ papers.append({"rank":rank,"title":title_part,"abstract_snippet":abstract_part,
 
726
  "cluster":cid,"cluster_label":label})
727
  top_papers[cid] = papers
 
728
  return {"top_papers": top_papers}
729
 
730
 
 
 
 
731
  def build_mismatch(state: PipelineState) -> dict:
732
  from tools import build_mismatch_table
733
  td = state["topic_data"]
 
736
  return {"mismatch_table": build_mismatch_table(td["keyphrases"], labels_map)}
737
 
738
 
739
+ # ============================================================================
740
+ # NEW NODE 1: load_methodology_corpus
741
+ # ============================================================================
742
def load_methodology_corpus(state: PipelineState) -> dict:
    """
    Load the optional methodology CSV (columns: title, methodology, optional doi).

    Column headers are matched case-insensitively and with surrounding
    whitespace ignored. Each row becomes a dict with ``paper_idx`` (1-based
    row position), ``title``, ``doi``, ``methodology`` and ``journal`` (from
    ``_detect_journal(doi, title)``).

    Returns ``{"methodology_papers": []}`` when no path is configured or the
    required columns are missing. Errors raised by ``pd.read_csv`` propagate.
    """
    fpath = state.get("methodology_filepath")
    if not fpath:
        logger.info("No methodology CSV provided β€” skipping methodology pipeline.")
        return {"methodology_papers": []}

    df = pd.read_csv(fpath)
    df.columns = df.columns.str.strip().str.lower()
    required = {"title", "methodology"}
    missing = required - set(df.columns)
    if missing:
        logger.warning("Methodology CSV missing columns: %s β€” skipping.", missing)
        return {"methodology_papers": []}

    if "doi" not in df.columns:
        df["doi"] = "N/A"

    def _cell(value, default: str = "") -> str:
        # Empty CSV cells arrive as NaN, which is truthy: `str(value or "")`
        # would yield the literal text "nan" and feed it to the LLMs/embedder.
        return default if pd.isna(value) else str(value or default)

    papers = []
    # enumerate() keeps paper_idx a 1-based row position regardless of the
    # DataFrame's index labels.
    for paper_idx, (_, row) in enumerate(df.iterrows(), start=1):
        title = _cell(row.get("title", ""))
        doi = _cell(row.get("doi", "N/A"), "N/A")
        methodology = _cell(row.get("methodology", ""))
        papers.append({
            "paper_idx": paper_idx,
            "title": title,
            "doi": doi,
            "methodology": methodology,
            "journal": _detect_journal(doi, title),
        })

    journals_found = Counter(p["journal"] for p in papers)
    logger.info("Loaded %d methodology papers. Journals: %s", len(papers), dict(journals_found))
    return {"methodology_papers": papers}
782
+
783
+
784
+ # ============================================================================
785
+ # NEW NODE 2: embed_methodology_vectors
786
+ # ============================================================================
787
def embed_methodology_vectors(state: PipelineState) -> dict:
    """
    Encode each paper's methodology text with the ``allenai/specter2_base``
    sentence-transformer, independently of the main corpus embeddings.

    Only the first 1500 characters of every methodology text are encoded.
    The result is converted with ``.tolist()`` so it is a plain nested list.
    With no methodology papers in state, an empty list is returned and the
    model is never loaded.
    """
    corpus = state.get("methodology_papers", [])
    if corpus:
        # Imported lazily so the dependency is only needed when a CSV was given.
        from sentence_transformers import SentenceTransformer

        snippets = []
        for paper in corpus:
            snippets.append(paper["methodology"][:1500])  # cap at 1500 chars
        logger.info("Embedding %d methodology texts with SPECTER-2 (separate vector space)…", len(snippets))
        encoder = SentenceTransformer("allenai/specter2_base")
        vectors = encoder.encode(snippets, show_progress_bar=True, batch_size=32)
        logger.info("Methodology embeddings: %s", vectors.shape)
        return {"methodology_embeddings": vectors.tolist()}

    return {"methodology_embeddings": []}
804
+
805
+
806
+ # ============================================================================
807
+ # NEW NODE 3: extract_comp_techniques (3-LLM Council)
808
+ # ============================================================================
809
def _llm_sheet_row(resp, paper: dict) -> dict:
    """Build one raw per-LLM sheet row for *paper*, tolerating malformed LLM output."""
    per_paper = resp.get("per_paper") if isinstance(resp, dict) else None
    entry = per_paper.get(str(paper["paper_idx"])) if isinstance(per_paper, dict) else None
    if not isinstance(entry, dict):
        entry = {}

    def _as_strings(value) -> list:
        # LLM JSON may carry null, a bare string, or non-string items here;
        # str.join would raise on any of those.
        if isinstance(value, str):
            return [value]
        if isinstance(value, (list, tuple)):
            return [str(v) for v in value if v is not None]
        return []

    return {
        "paper_idx": paper["paper_idx"],
        "title": paper["title"][:80],
        "journal": paper["journal"],
        "techniques": ", ".join(_as_strings(entry.get("techniques"))) or "β€”",
        "evidence": " | ".join(_as_strings(entry.get("evidence")))[:200] or "β€”",
        "confidence": entry.get("confidence", "β€”"),
    }


def extract_comp_techniques(state: PipelineState) -> dict:
    """
    3-LLM council that extracts computational techniques from methodology-CSV papers.

    Papers are processed in batches of ``BATCH_SIZE``:
      1. regex pre-scan of the (1500-char capped) methodology texts,
      2. one call each to Groq, Mistral and Gemini with the same prompt,
      3. consolidation via ``_consolidate_comp_techniques``.

    Four sheets are produced, one row per paper each, and written to
    ``tech_sheet1_groq.csv`` ... ``tech_sheet4_consolidated.csv`` in the
    current working directory:
      1 = Groq raw, 2 = Mistral raw, 3 = Gemini raw, 4 = consolidated.

    Returns ``comp_technique_sheets`` and an updated ``methodology_papers``
    list whose entries carry ``consolidated_techniques``. The paper dicts in
    the incoming state are copied, not mutated. With no papers, only empty
    sheets are returned and no API client is created.
    """
    papers = state.get("methodology_papers", [])
    if not papers:
        return {"comp_technique_sheets": {1: [], 2: [], 3: [], 4: []}}

    client = Groq(api_key=state["groq_key"], max_retries=0)
    mk, gk = state["mistral_key"], state["gemini_key"]
    BATCH_SIZE = 5

    sheets = {1: [], 2: [], 3: [], 4: []}

    # Consolidated per-paper techniques accumulated across batches:
    # {paper_idx: [technique_names]}
    all_consolidated = {}

    for batch_start in range(0, len(papers), BATCH_SIZE):
        batch = papers[batch_start: batch_start + BATCH_SIZE]
        batch_texts = [p["methodology"][:1500] for p in batch]

        # Step 1: regex pre-scan on the batch
        scan = _regex_scan(batch_texts)
        regex_hint = _regex_summary(scan)
        logger.info("Batch %d-%d | regex: %d tech hits",
                    batch[0]["paper_idx"], batch[-1]["paper_idx"], len(scan["techniques"]))

        # Step 2: three LLM calls, spaced out by the same pauses as before
        prompt = _comp_technique_batch_prompt(batch, regex_hint)
        r1 = _groq(client, prompt)
        time.sleep(1)
        r2 = _mistral(prompt, mk)
        time.sleep(1)
        r3 = _gemini(prompt, gk)
        time.sleep(4)

        # Step 3: consolidate
        consolidated = _consolidate_comp_techniques(r1, r2, r3, batch)

        # One row per paper on each of the four sheets
        for p in batch:
            for sheet_no, resp in ((1, r1), (2, r2), (3, r3)):
                sheets[sheet_no].append(_llm_sheet_row(resp, p))

            con_techs = consolidated["per_paper_consolidated"].get(str(p["paper_idx"]), [])
            sheets[4].append({
                "paper_idx": p["paper_idx"],
                "title": p["title"][:80],
                "journal": p["journal"],
                "techniques": ", ".join(con_techs) or "β€”",
                "n_techniques": len(con_techs),
                "dominant": consolidated["dominant_technique"],
            })
            all_consolidated[p["paper_idx"]] = con_techs

        logger.info("Batch consolidated dominant: %s", consolidated["dominant_technique"])

    # Save the four sheets as CSV
    sheet_names = {1: "tech_sheet1_groq", 2: "tech_sheet2_mistral",
                   3: "tech_sheet3_gemini", 4: "tech_sheet4_consolidated"}
    for sn, name in sheet_names.items():
        pd.DataFrame(sheets[sn]).to_csv(f"{name}.csv", index=False)

    # Attach consolidated techniques for the cross-tab node; build new dicts
    # instead of mutating the objects held by the incoming state.
    updated_papers = [
        dict(p, consolidated_techniques=all_consolidated.get(p["paper_idx"], []))
        for p in papers
    ]

    return {
        "comp_technique_sheets": sheets,
        "methodology_papers": updated_papers,
    }
906
+
907
+
908
+ # ============================================================================
909
+ # NEW NODE 4: build_journal_crosstab
910
+ # ============================================================================
911
def build_journal_crosstab(state: PipelineState) -> dict:
    """
    Build a technique x journal cross-tabulation.

    For every journal among the methodology papers, compute the rounded
    percentage of that journal's papers whose ``consolidated_techniques``
    include each technique (names are title-cased). Also computes, per LLM,
    the percentage of sheet rows mentioning each technique, parsed from the
    ", "-joined ``techniques`` column of sheets 1-3.

    A paper/row counts at most once per technique, so percentages cannot
    exceed 100 even when an LLM repeats a name or varies its casing.

    Returns ``{"journal_crosstab": {}}`` when there are no papers; otherwise
    a dict with ``consolidated``, ``journals``, ``techniques``,
    ``journal_paper_counts`` and ``per_llm_freq``.
    """
    papers = state.get("methodology_papers", [])
    if not papers:
        return {"journal_crosstab": {}}

    sheets = state.get("comp_technique_sheets", {})

    def _unique_titles(names) -> list:
        # Title-case, drop blanks, de-duplicate while keeping first-seen order.
        return list(dict.fromkeys(t.strip().title() for t in names if t and t.strip()))

    # --- Consolidated cross-tab ---
    journal_tech_counts = defaultdict(lambda: defaultdict(int))
    journal_paper_counts = defaultdict(int)

    for p in papers:
        journal = p["journal"]
        journal_paper_counts[journal] += 1
        for tech in _unique_titles(p.get("consolidated_techniques", [])):
            journal_tech_counts[journal][tech] += 1

    journals = sorted(journal_paper_counts.keys())
    all_techniques = sorted({t for j in journal_tech_counts.values() for t in j.keys()})

    crosstab = {}
    for journal in journals:
        n = journal_paper_counts[journal] or 1
        crosstab[journal] = {
            tech: round(journal_tech_counts[journal].get(tech, 0) / n * 100)
            for tech in all_techniques
        }

    # --- Per-LLM technique frequency across ALL papers ---
    def _llm_tech_freq(sheet_rows: list) -> dict:
        tech_count = defaultdict(int)
        n_papers = len(sheet_rows) or 1
        for row in sheet_rows:
            raw = row.get("techniques", "")
            if not raw or raw == "β€”":
                continue  # placeholder written for papers with no techniques
            for t in _unique_titles(raw.split(", ")):
                tech_count[t] += 1
        return {t: round(c / n_papers * 100) for t, c in tech_count.items()}

    per_llm_freq = {
        "Groq": _llm_tech_freq(sheets.get(1, [])),
        "Mistral": _llm_tech_freq(sheets.get(2, [])),
        "Gemini": _llm_tech_freq(sheets.get(3, [])),
    }

    logger.info("Journal crosstab: %d journals Γ— %d techniques",
                len(journals), len(all_techniques))
    return {
        "journal_crosstab": {
            "consolidated": crosstab,
            "journals": journals,
            "techniques": all_techniques,
            "journal_paper_counts": dict(journal_paper_counts),
            "per_llm_freq": per_llm_freq,
        }
    }
972
+
973
+
974
+ # ============================================================================
975
+ # NEW NODE 5: optimize_technique_labels
976
+ # ============================================================================
977
def optimize_technique_labels(state: PipelineState) -> dict:
    """
    Optimization pass over the consolidated computational technique labels.

    Runs only when ``n_optimize > 1`` and the cross-tab lists techniques. For
    each technique whose per-LLM frequencies differ by more than 15 points, or
    whose highest frequency is below 20, a Groq critic is asked to audit the
    label (hallucination, variance, split/merge, refined name). Each non-empty
    critique becomes one entry in ``technique_opt_log``.

    Returns ``{"technique_opt_log": [...]}`` (empty when the pass is skipped).
    """
    n_opt = state.get("n_optimize", 1)
    if n_opt <= 1:
        return {"technique_opt_log": []}

    crosstab_data = state.get("journal_crosstab", {})
    all_techniques = crosstab_data.get("techniques", [])
    if not all_techniques:
        return {"technique_opt_log": []}

    client = Groq(api_key=state["groq_key"], max_retries=0)
    per_llm = crosstab_data.get("per_llm_freq", {})
    papers = state.get("methodology_papers", [])
    opt_log = []

    def _evidence_for(technique: str) -> list[str]:
        """Up to three text snippets around regex hits that relate to *technique*."""
        tech_lower = technique.lower()
        samples = []
        for p in papers[:30]:  # cap at first 30 papers for speed
            text = p.get("methodology", "")
            for pat in TECHNIQUE_PATTERNS.values():
                for m in pat.finditer(text):
                    matched = m.group(0).lower()
                    # Keep a hit only when the matched term and the technique
                    # name overlap. The previous second test compared the
                    # technique with itself, so every hit was accepted.
                    if tech_lower in matched or matched in tech_lower:
                        snippet = text[max(0, m.start() - 40):m.end() + 40].replace("\n", " ")
                        samples.append(snippet[:120])
                        if len(samples) >= 3:
                            return samples
        return samples

    for tech in all_techniques:
        pct_g = per_llm.get("Groq", {}).get(tech, 0)
        pct_m = per_llm.get("Mistral", {}).get(tech, 0)
        pct_gem = per_llm.get("Gemini", {}).get(tech, 0)

        # Only critique when the LLMs disagree noticeably or the technique is rare.
        max_pct = max(pct_g, pct_m, pct_gem)
        min_pct = min(pct_g, pct_m, pct_gem)
        if not ((max_pct - min_pct) > 15 or max_pct < 20):
            continue

        # Evidence is gathered only for techniques that are actually critiqued.
        evidence = _evidence_for(tech)
        critique = _groq(client,
                         _technique_critique_prompt(tech, "All Journals", pct_g, pct_m, pct_gem, evidence))
        time.sleep(0.8)

        if not critique or not isinstance(critique, dict):
            continue

        split_into = critique.get("split_into") or []
        if isinstance(split_into, str):
            # Joining a bare string would interleave ", " between its characters.
            split_into = [split_into]

        opt_log.append({
            "technique": tech,
            "refined_name": critique.get("refined_name", tech),
            "is_hallucination": critique.get("is_hallucination", False),
            "high_variance": critique.get("high_variance_across_llms", False),
            "suggestion": critique.get("suggestion", "β€”"),
            "split_into": ", ".join(str(s) for s in split_into) or "β€”",
            "merge_with": critique.get("merge_with", "β€”") or "β€”",
            "pct_groq": pct_g,
            "pct_mistral": pct_m,
            "pct_gemini": pct_gem,
            "confidence": critique.get("confidence", 0),
        })
        logger.info("Technique opt: '%s' β†’ '%s'", tech, critique.get("refined_name", tech))

    return {"technique_opt_log": opt_log}
1051
+
1052
+
1053
+ # ============================================================================
1054
+ # GRAPH ASSEMBLY
1055
+ # ============================================================================
1056
def build_graph() -> StateGraph:
    """Wire all pipeline nodes into one linear LangGraph and return it compiled.

    Execution order: the six core topic-analysis nodes, followed by the five
    methodology-CSV nodes, ending at ``END``.
    """
    core_nodes = (
        ("embed_and_cluster", embed_and_cluster),
        ("llm_council", llm_council),
        ("optimization_loop", optimization_loop),
        ("extract_methodology", extract_methodology),
        ("collect_top_papers", collect_top_papers),
        ("build_mismatch", build_mismatch),
    )
    methodology_nodes = (
        ("load_methodology_corpus", load_methodology_corpus),
        ("embed_methodology_vectors", embed_methodology_vectors),
        ("extract_comp_techniques", extract_comp_techniques),
        ("build_journal_crosstab", build_journal_crosstab),
        ("optimize_technique_labels", optimize_technique_labels),
    )
    ordered = core_nodes + methodology_nodes

    graph = StateGraph(PipelineState)
    for node_name, node_fn in ordered:
        graph.add_node(node_name, node_fn)

    graph.set_entry_point(ordered[0][0])

    # Each node feeds the next one in the sequence; the last one terminates.
    chain = [node_name for node_name, _ in ordered]
    for source, target in zip(chain, chain[1:]):
        graph.add_edge(source, target)
    graph.add_edge(chain[-1], END)

    return graph.compile()
1091
 
1092
+
1093
# Module-level compiled graph: built once when this module is imported and
# reused by run_pipeline() for every invocation.
pipeline_graph = build_graph()
1094
 
1095
+
1096
def run_pipeline(filepath, groq_key, mistral_key, gemini_key,
                 n_trials=50, n_optimize=1, methodology_filepath=None):
    """Run the compiled pipeline once and return its final state.

    All arguments are placed into the initial graph state under their own
    names. ``methodology_filepath`` may be left as ``None``.
    """
    initial_state = dict(
        filepath=filepath,
        groq_key=groq_key,
        mistral_key=mistral_key,
        gemini_key=gemini_key,
        n_trials=n_trials,
        n_optimize=n_optimize,
        methodology_filepath=methodology_filepath,
    )
    return pipeline_graph.invoke(initial_state)