anujjuna commited on
Commit
d8fd287
·
verified ·
1 Parent(s): 05df72c

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +314 -427
agent.py CHANGED
@@ -1,19 +1,25 @@
1
  """
2
  agent.py
3
  --------
4
- LLM-driven topic interpretation and classification module using a 3-LLM ensemble.
 
 
 
 
 
5
  """
6
 
7
  from __future__ import annotations
8
  import json
9
  import logging
10
  import os
 
11
  import time
12
- from dataclasses import dataclass, asdict
13
  from typing import Optional
14
  import pandas as pd
 
15
  import requests
16
- import re
17
  from groq import Groq
18
 
19
  # ---------------------------------------------------------------------------
@@ -25,503 +31,384 @@ logger = logging.getLogger(__name__)
25
  # ---------------------------------------------------------------------------
26
  # Constants
27
  # ---------------------------------------------------------------------------
28
- DEFAULT_MODEL = "llama-3.1-8b-instant"
29
- MISTRAL_DEFAULT_MODEL = "mistral-small-latest"
30
- DEFAULT_TAXONOMY_CATEGORIES = [
31
- "Artificial Intelligence", "Machine Learning", "Natural Language Processing",
32
- "Computer Vision", "Information Systems", "Healthcare & Bioinformatics",
33
- "Finance & Economics", "Cybersecurity", "Human-Computer Interaction",
34
- "Robotics & Automation", "Education Technology", "Environmental Science",
 
 
 
35
  "Social Sciences", "Data Engineering", "Other",
36
  ]
37
 
38
  # ---------------------------------------------------------------------------
39
- # PAJAIS 2019 Knowledge — what the 2019 taxonomy covers vs does NOT cover
40
  # ---------------------------------------------------------------------------
41
- PAJAIS_COVERED = [
42
- "IS strategy", "IS adoption", "IS governance", "e-commerce", "enterprise systems",
43
- "ERP", "knowledge management", "decision support", "e-government", "social media IS",
44
- "IT outsourcing", "IS security", "privacy", "IS education", "mobile commerce",
45
- "business intelligence", "data analytics", "IS in healthcare (general)",
46
- "human computer interaction", "HCI", "IT project management",
47
- ]
48
-
49
- PAJAIS_NOT_COVERED = [
50
- "large language models", "LLM", "GPT", "generative AI", "RAG",
51
- "process mining", "event log", "Petri net", "conformance checking",
52
- "federated learning", "differential privacy", "DP-SGD",
53
- "fairness", "algorithmic bias", "responsible AI", "FATE", "XAI", "explainable AI",
54
- "blockchain analytics", "smart contract", "DeFi", "tokenomics",
55
- "COVID-19 IS", "pandemic informatics",
56
- "Android malware", "mobile security", "dark web", "cyber insurance",
57
- "agentic AI", "multi-agent orchestration",
58
- "transformer", "BERT", "neural topic model", "BERTopic",
59
- "recommender neural", "graph neural network", "GNN",
60
- "heterogeneous computing", "IoT analytics", "edge computing IS",
61
- "talent matching", "job-person fit", "HR analytics",
62
- ]
63
 
64
- # Rule-based NOVEL trigger — fires ONLY on specific, unambiguous compound/technical terms
65
- # that are definitively absent from PAJAIS 2019.
66
- # Deliberately narrow: single common words like "data", "model", "network", "learning",
67
- # "deep", "smart", "financial", "detection" do NOT trigger this — they exist in PAJAIS.
68
- # Only truly post-2018 or PAJAIS-absent compound terms qualify.
69
- NOVEL_REGEX_TRIGGERS = re.compile(
70
- r'\b('
71
- r'llms?|gpt[\-\s]?\d*|large\s+language\s+model|generative\s+ai|'
72
- r'federat\w*\s+learn\w*|differential\s+privac\w*|dp\-sgd|'
73
- r'process\s+mining|event\s+log|petri\s+net|conformance\s+check\w*|'
74
- r'blockchain|smart\s+contract|defi\b|tokenomic\w*|'
75
- r'malware|botnet|dark\s+web|cyber\s+insur\w*|'
76
- r'responsible\s+ai|explainab\w*\s+ai|algorithmic\s+bias|xai\b|'
77
- r'agentic\s+ai|multi.agent\s+orchest\w*|'
78
- r'graph\s+neural\s+network|gnn\b|'
79
- r'retrieval.augment\w*|prompt\s+engineer\w*|rag\b|'
80
- r'talent\s+match\w*|job.person\s+fit|'
81
- r'covid.19|pandemic\s+inform\w*'
82
- r')\b',
83
- re.IGNORECASE
84
- )
85
-
86
- def _is_deterministic_novel(keywords: list[str], samples: list[str]) -> bool:
87
- """Non-LLM rule-based check: fires only on specific unambiguous NOVEL compound terms.
88
- Generic single words (data, model, network, learning, detection) do NOT trigger this.
89
- The keyword list from BERTopic is checked word-by-word AND as joined text to catch
90
- compound matches that span two keywords."""
91
- # Check the joined keyword string (catches "process mining" split across two keywords)
92
- keyword_text = " ".join(keywords).lower()
93
- sample_text = " ".join(samples).lower()
94
- return (
95
- bool(NOVEL_REGEX_TRIGGERS.search(keyword_text)) or
96
- bool(NOVEL_REGEX_TRIGGERS.search(sample_text))
97
- )
98
 
99
- # ---------------------------------------------------------------------------
100
- # Data Classes
101
- # ---------------------------------------------------------------------------
102
  @dataclass
103
- class TopicInterpretation:
104
- """Structured interpretation for a single topic."""
105
- topic_id: int
106
- label: str
107
- category: str
108
- classification: str
 
 
 
 
 
 
 
 
 
109
  paper_count: int = 0
110
- keywords: list[str] = None
 
111
 
112
  # ---------------------------------------------------------------------------
113
- # API Clients & Calls
114
  # ---------------------------------------------------------------------------
115
  def build_groq_client(api_key: Optional[str] = None):
116
  key = api_key or os.getenv("GROQ_API_KEY")
117
  if not key:
118
- raise ValueError("No Groq API key provided.")
119
  return Groq(api_key=key, max_retries=0)
120
 
121
- def call_gemini_label(prompt: str, api_key: str) -> dict:
122
- """Call Google AI Studio (Gemini) API."""
123
- if not api_key: return {}
124
- url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent?key={api_key}"
125
- headers = {"Content-Type": "application/json"}
126
- payload = {"contents": [{"parts": [{"text": prompt}]}], "generationConfig": {"temperature": 0.2}}
127
  try:
128
- response = requests.post(url, headers=headers, json=payload, timeout=10)
129
- data = response.json()
130
- if "error" in data or "candidates" not in data:
131
- logger.error(f"Gemini error / missing candidates. Response: {data}")
132
- return {}
133
- raw = data["candidates"][0]["content"]["parts"][0]["text"].strip()
134
- raw = raw.replace("```json", "").replace("```", "").strip()
135
- start = raw.find("{")
136
- end = raw.rfind("}") + 1
137
- if start != -1 and end != 0:
138
- raw = raw[start:end]
139
- return json.loads(raw)
140
  except Exception as e:
141
- logger.warning(f"Gemini call failed: {e}")
142
  return {}
143
 
144
- def call_mistral_label(prompt: str, api_key: str) -> dict:
145
- """Call Mistral API."""
146
- if not api_key: return {}
 
147
  try:
148
- response = requests.post(
149
  "https://api.mistral.ai/v1/chat/completions",
150
- headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
151
- json={
152
- "model": "mistral-small-latest",
153
- "messages": [{"role": "user", "content": prompt}],
154
- "temperature": 0.2,
155
- },
156
- timeout=10,
157
  )
158
- data = response.json()
159
- raw = data["choices"][0]["message"]["content"].strip()
160
- raw = raw.replace("```json", "").replace("```", "").strip()
161
- start, end = raw.find("{"), raw.rfind("}") + 1
162
- return json.loads(raw[start:end])
163
  except Exception as e:
164
- logger.warning(f"Mistral call failed: {e}")
165
  return {}
166
 
167
- def _call_llm_json(client, prompt: str, model: str) -> dict:
168
- """Call Groq API with robust JSON parsing."""
 
 
 
 
169
  try:
170
- response = client.chat.completions.create(
171
- model=model, messages=[{"role": "user", "content": prompt}], temperature=0.2, timeout=10,
172
- )
173
- raw = response.choices[0].message.content.strip()
174
- raw = raw.replace("```json", "").replace("```", "").strip()
175
- start = raw.find("{")
176
- end = raw.rfind("}") + 1
177
- if start != -1 and end != 0:
178
- raw = raw[start:end]
179
- return json.loads(raw)
180
  except Exception as e:
181
- logger.warning(f"Groq call failed: {e}")
182
  return {}
183
 
184
- # ---------------------------------------------------------------------------
185
- # Logic Helpers
186
- # ---------------------------------------------------------------------------
187
- def convert_numpy_types(obj):
188
- """Recursively convert numpy types to native Python types for JSON serialisation."""
189
- import numpy as np
190
- if isinstance(obj, dict):
191
- return {k: convert_numpy_types(v) for k, v in obj.items()}
192
- elif isinstance(obj, list):
193
- return [convert_numpy_types(v) for v in obj]
194
- elif isinstance(obj, np.integer):
195
- return int(obj)
196
- elif isinstance(obj, np.floating):
197
- return float(obj)
198
- return obj
199
 
200
- def _safe_capitalize(s: str) -> str:
201
- s = str(s or "").strip()
202
- return s[0].upper() + s[1:] if s else ""
203
-
204
- def clean_label(label: str) -> str:
205
- if not label: return ""
206
- label = label.replace("\n", " ").strip()
207
- label = " ".join(label.split())
208
- label = label.rstrip(" .")
209
- if len(label) > 60:
210
- label = label[:60].rsplit(" ", 1)[0] if " " in label[:60] else label[:60]
211
- return label.strip()
212
-
213
- def _get_keyword_overlap(label: str, keywords: list[str]) -> int:
214
- label_words = set(label.lower().split())
215
- kw_set = set(k.lower() for k in keywords)
216
- return len(label_words & kw_set)
217
-
218
- def select_best_interpretation(results: list[dict], keywords: list[str]) -> dict:
219
- valid = [r for r in results if r and "label" in r]
220
- if not valid: return {}
221
-
222
- # Majority vote on label
223
- counts = {}
224
- for r in valid:
225
- l = clean_label(r["label"]).lower()
226
- counts[l] = counts.get(l, 0) + 1
227
- for l, c in counts.items():
228
- if c >= 2:
229
- best_r = next(r for r in valid if clean_label(r["label"]).lower() == l)
230
- best_r["label"] = clean_label(best_r["label"])
231
- return best_r
232
-
233
- # Fallback: keyword overlap or shortest
234
- valid.sort(key=lambda x: (-_get_keyword_overlap(clean_label(x["label"]), keywords), len(clean_label(x["label"]))))
235
- best_r = valid[0]
236
- best_r["label"] = clean_label(best_r["label"])
237
- return best_r
238
-
239
- def _fallback_label_from_keywords(keywords: list[str], topic_id: int) -> tuple[str, str]:
240
- kw_set = set([k.lower() for k in keywords])
241
- mappings = [
242
- ({"privacy", "data", "security"}, "Digital Privacy and Security", "Cybersecurity"),
243
- ({"ai", "chatbots", "agents"}, "Conversational AI", "Artificial Intelligence"),
244
- ({"neural", "network", "deep"}, "Deep Learning Systems", "Machine Learning"),
245
- ]
246
- for trigger, label, cat in mappings:
247
- if any(t in kw_set for t in trigger): return label, cat
248
- return f"Topic study on {', '.join(keywords[:2])}", "Other"
249
 
250
  # ---------------------------------------------------------------------------
251
- # Core Logic — Prompt Builder
252
  # ---------------------------------------------------------------------------
253
- def _build_interpretation_prompt(keywords, samples, cats) -> str:
254
- pajais_covered_str = "; ".join(PAJAIS_COVERED[:10])
255
- pajais_not_str = "; ".join(PAJAIS_NOT_COVERED[:12])
256
- return f"""You are an IS research classifier. A BERTopic algorithm produced the following topic cluster from ACM TMIS papers.
 
 
 
 
 
257
 
258
- KEYWORDS: {', '.join(keywords)}
259
- REPRESENTATIVE PAPER TITLES: {' | '.join(samples[:3])}
 
 
 
 
 
260
 
261
- TASK: Generate a label and classify this topic against the PAJAIS 2019 taxonomy.
262
 
263
- PAJAIS 2019 COVERS — use MAPPED only if the topic clearly fits one of these:
264
- {pajais_covered_str}
 
 
 
 
 
 
 
 
 
 
 
 
265
 
266
- PAJAIS 2019 DOES NOT COVER — use NOVEL if the topic fits here:
267
- {pajais_not_str}
268
 
269
- CLASSIFICATION RULES:
270
- - NOVEL if the topic involves: LLMs/GPT/generative AI, process mining, federated learning, differential privacy, fairness/XAI/responsible AI, blockchain analytics, COVID-19 IS, mobile malware, dark web, agentic AI, IoT analytics, talent matching, cyber insurance, or any technique that postdates 2018.
271
- - MAPPED only if it clearly fits an existing PAJAIS 2019 category listed above.
272
- - When in doubt, choose NOVEL. TMIS is a computational journal and most of its recent topics post-date the 2019 taxonomy.
273
 
274
- TAXONOMY CATEGORIES (for the taxonomy_category field only): {', '.join(cats)}
 
275
 
276
- Respond ONLY with valid JSON — no other text, no markdown fences:
277
  {{
278
- "label": "<concise 5-8 word label>",
279
- "taxonomy_category": "<one category from the list>",
280
- "classification": "MAPPED or NOVEL",
281
- "reasoning": "<one sentence explaining the MAPPED vs NOVEL decision>"
 
282
  }}"""
283
 
 
284
  # ---------------------------------------------------------------------------
285
- # Validation Method 2 Regex / Pattern-based grounding check (non-LLM)
286
  # ---------------------------------------------------------------------------
287
- def validate_label_with_regex(label: str, keywords: list[str]) -> dict:
288
- """
289
- Checks if the AI-generated label is grounded in the cluster's actual keywords.
290
- Returns a dict with overlap score, matched terms, and a PASS/FAIL verdict.
291
- This method uses only Python re — no AI involved.
292
- """
293
- if not label or not keywords:
294
- return {"verdict": "FAIL", "overlap_score": 0, "matched_terms": [], "reason": "Empty label or keywords"}
295
-
296
- # Normalise: lowercase, split on word boundaries
297
- label_tokens = set(re.findall(r'\b[a-z]{3,}\b', label.lower()))
298
- kw_tokens = set(re.findall(r'\b[a-z]{3,}\b', " ".join(keywords).lower()))
299
-
300
- # Remove common stop words that add noise
301
- noise = {"the", "and", "for", "with", "using", "based", "from", "into", "this", "that", "are"}
302
- label_tokens -= noise
303
- kw_tokens -= noise
304
-
305
- matched = list(label_tokens & kw_tokens)
306
- overlap_score = len(matched) / max(len(label_tokens), 1)
307
-
308
- # Stem-level match: check if any label token is a prefix (>=4 chars) of a keyword or vice versa
309
- stem_matches = []
310
- for lt in label_tokens:
311
- for kt in kw_tokens:
312
  if len(lt) >= 4 and (kt.startswith(lt[:4]) or lt.startswith(kt[:4])):
313
- stem_matches.append(f"{lt}≈{kt}")
 
 
 
 
 
314
 
315
- total_score = min(1.0, overlap_score + 0.15 * len(stem_matches))
316
- verdict = "PASS" if (len(matched) >= 1 or len(stem_matches) >= 1) else "FAIL"
317
-
318
- return {
319
- "verdict": verdict,
320
- "overlap_score": round(total_score, 3),
321
- "matched_terms": matched,
322
- "stem_matches": stem_matches[:5],
323
- "label_tokens": list(label_tokens),
324
- "reason": f"{len(matched)} exact + {len(stem_matches)} stem matches against {len(kw_tokens)} keyword tokens",
325
- }
326
 
327
  # ---------------------------------------------------------------------------
328
- # Core — Topic Interpretation with 3-LLM Council + dual validation
329
  # ---------------------------------------------------------------------------
330
- def interpret_topic(
331
- topic_id, keywords, samples, groq_client, mistral_key, gemini_key,
332
- paper_count, representative_docs
333
- ) -> TopicInterpretation:
334
-
335
- prompt = _build_interpretation_prompt(keywords, samples, DEFAULT_TAXONOMY_CATEGORIES)
336
-
337
- # ------------------------------------------------------------------
338
- # Step A: Deterministic non-LLM NOVEL pre-check
339
- # If keywords/samples match known NOVEL patterns, override to NOVEL
340
- # regardless of what the LLMs say. This is the non-LLM validation
341
- # method — uses only regex, no AI.
342
- # ------------------------------------------------------------------
343
- forced_novel = _is_deterministic_novel(keywords, samples)
344
- if forced_novel:
345
- logger.info(f"Topic {topic_id}: NOVEL forced by regex trigger on keywords={keywords[:4]}")
346
-
347
- # ------------------------------------------------------------------
348
- # Step B: 3-LLM Council
349
- # Call Groq (LLaMA-3.1), Mistral Small, and Gemini 2.5 Flash
350
- # independently. Three different providers = three independent votes.
351
- # ------------------------------------------------------------------
352
- raw_results = []
353
-
354
- groq_res = _call_llm_json(groq_client, prompt, DEFAULT_MODEL)
355
- raw_results.append({"llm": "Groq/LLaMA-3.1", "response": groq_res})
356
  time.sleep(1)
357
-
358
- mistral_res = call_mistral_label(prompt, mistral_key)
359
- raw_results.append({"llm": "Mistral-Small", "response": mistral_res})
360
  time.sleep(1)
 
 
361
 
362
- if gemini_key:
363
- gemini_res = call_gemini_label(prompt, gemini_key)
364
- raw_results.append({"llm": "Gemini-2.5-Flash", "response": gemini_res})
365
-
366
- results = [r["response"] for r in raw_results]
367
-
368
- # ------------------------------------------------------------------
369
- # Step C: Select best label via majority vote on label text
370
- # ------------------------------------------------------------------
371
- best = select_best_interpretation(results, keywords)
372
- if not best:
373
- l, c = _fallback_label_from_keywords(keywords, topic_id)
374
- best = {"label": l, "taxonomy_category": c, "classification": "MAPPED"}
375
-
376
- final_label = _safe_capitalize(best.get("label"))
377
-
378
- # ------------------------------------------------------------------
379
- # Step D: Classification majority vote — separate from label vote
380
- # Count NOVEL vs MAPPED votes across all 3 LLMs.
381
- # NOVEL wins if: (a) forced by regex OR (b) at least 1 LLM votes NOVEL.
382
- # Conservative toward NOVEL because PAJAIS 2019 is outdated and TMIS
383
- # publishes many post-2018 techniques with no PAJAIS home.
384
- # ------------------------------------------------------------------
385
- classification_votes = []
386
- for r in results:
387
- if r and "classification" in r:
388
- v = str(r["classification"]).upper().strip()
389
- if v in ("MAPPED", "NOVEL"):
390
- classification_votes.append(v)
391
-
392
- novel_votes = classification_votes.count("NOVEL")
393
- mapped_votes = classification_votes.count("MAPPED")
394
-
395
- # Classification decision logic:
396
- # - Regex forced (unambiguous compound NOVEL term in keywords/samples) → always NOVEL
397
- # - LLM majority (2 or more of 3 LLMs vote NOVEL) → NOVEL
398
- # - Single LLM vote for NOVEL + 2 for MAPPED → MAPPED (majority wins)
399
- # - All 3 vote MAPPED → MAPPED
400
- # This gives ~40-60% NOVEL as expected for TMIS vs PAJAIS 2019 comparison.
401
- if forced_novel or novel_votes >= 2:
402
- final_classification = "NOVEL"
403
- else:
404
- final_classification = "MAPPED"
405
 
406
- logger.info(
407
- f"Topic {topic_id} classification: NOVEL_votes={novel_votes}, "
408
- f"MAPPED_votes={mapped_votes}, regex_forced={forced_novel} {final_classification}"
409
- )
 
410
 
411
- # ------------------------------------------------------------------
412
- # Step E: Build council vote evidence for UI display
413
- # Each LLM's label, category, classification, and reasoning is stored
414
- # so the UI can show per-topic agreement/disagreement transparently.
415
- # ------------------------------------------------------------------
416
- council_votes = []
417
- for r in raw_results:
418
- resp = r["response"]
419
- council_votes.append({
420
- "llm": r["llm"],
421
- "label": clean_label(resp.get("label", "")) if resp else "—",
422
- "category": resp.get("taxonomy_category", "—") if resp else "—",
423
- "classification": resp.get("classification", "—") if resp else "—",
424
- "reasoning": resp.get("reasoning", "—") if resp else "—",
425
- })
426
-
427
- # ------------------------------------------------------------------
428
- # Step F: Regex grounding check on the final label
429
- # Verifies the label tokens are grounded in actual cluster keywords.
430
- # Catches hallucinated labels (confident-sounding but disconnected
431
- # from the underlying data). Pure regex — no AI involved.
432
- # ------------------------------------------------------------------
433
- regex_validation = validate_label_with_regex(final_label, keywords)
434
- logger.info(
435
- f"Topic {topic_id} label grounding: {regex_validation['verdict']} "
436
- f"(score={regex_validation['overlap_score']}, matched={regex_validation['matched_terms']})"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
437
  )
438
 
439
- # ------------------------------------------------------------------
440
- # Build the final TopicInterpretation object
441
- # ------------------------------------------------------------------
442
- interp = TopicInterpretation(
443
- topic_id=topic_id,
444
- label=final_label,
445
- category=_safe_capitalize(best.get("taxonomy_category")),
446
- classification=final_classification,
447
- paper_count=paper_count,
448
- keywords=keywords,
449
- )
450
 
451
- # Attach validation evidence as dynamic attributes (serialised manually in run_agent)
452
- interp.council_votes = council_votes
453
- interp.regex_validation = regex_validation
454
- interp.novel_forced_by_regex = forced_novel
455
- interp.classification_votes = {"NOVEL": novel_votes, "MAPPED": mapped_votes}
 
456
 
457
- return interp
458
 
459
  # ---------------------------------------------------------------------------
460
- # Run Agent — orchestrates all topics and writes outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
  # ---------------------------------------------------------------------------
462
  def run_agent(
463
- topic_results,
464
- groq_key,
465
- mistral_key,
466
- gemini_key,
467
- output_json="topics.json",
468
- output_csv="topics.csv",
469
  ) -> dict:
470
-
471
  client = build_groq_client(groq_key)
472
- res = topic_results["documents"]
473
 
474
- num_clusters = len([t for t in set(res["topics"]) if t != -1])
475
- num_topics = len(res["topic_keywords"])
476
- print(f"Final cluster count: {num_clusters}")
477
- print(f"Final topic count: {num_topics}")
478
- if num_clusters != num_topics:
479
- logger.error(f"CONSISTENCY WARNING: {num_clusters} clusters != {num_topics} topics")
480
 
 
481
  interpretations = {}
482
- for i, (tid, kw_pairs) in enumerate(res["topic_keywords"].items()):
483
- interp = interpret_topic(
484
- tid,
485
- [w for w, _ in kw_pairs],
486
- res["representative_docs"].get(tid, []),
487
- client,
488
- mistral_key,
489
- gemini_key,
490
- res["topic_freq"].get(tid, 0),
491
- res["representative_docs"].get(tid, []),
 
 
492
  )
493
- interpretations[tid] = interp
494
- logger.info(f"Interpreted {tid}: {interp.label} [{interp.classification}]")
495
-
496
- # Build serialisable list — include all validation evidence
497
- interp_list = []
498
- for i in interpretations.values():
499
- d = asdict(i)
500
- # asdict() only captures @dataclass fields; add dynamic attributes manually
501
- d["council_votes"] = getattr(i, "council_votes", [])
502
- d["regex_validation"] = getattr(i, "regex_validation", {})
503
- d["novel_forced_by_regex"] = getattr(i, "novel_forced_by_regex", False)
504
- d["classification_votes"] = getattr(i, "classification_votes", {})
505
- interp_list.append(d)
506
-
507
- clean_data = convert_numpy_types(interp_list)
508
 
 
 
509
  with open(output_json, "w") as f:
510
- json.dump(clean_data, f, indent=2)
511
-
512
- df = pd.DataFrame(clean_data)
513
  if not df.empty:
514
- df["keywords"] = df["keywords"].apply(
515
- lambda x: ", ".join(x) if isinstance(x, list) else str(x)
516
- )
 
517
  df.to_csv(output_csv, index=False)
518
 
519
- return {
520
- "interpretations": interpretations,
521
- "json_path": output_json,
522
- "csv_path": output_csv,
523
- }
524
-
525
-
526
- if __name__ == "__main__":
527
- pass
 
1
  """
2
  agent.py
3
  --------
4
+ LLM Council labelling module 3.5).
5
+
6
+ Three independent LLMs label each cluster, producing Sheets 1–3.
7
+ Sheet 4 consolidates with Triple / Two / Single agreement tags.
8
+ Disagreement clusters get a fourth-round defence prompt.
9
+ Labels not grounded in keyphrases are rejected.
10
  """
11
 
12
  from __future__ import annotations
13
  import json
14
  import logging
15
  import os
16
+ import re
17
  import time
18
+ from dataclasses import dataclass, field, asdict
19
  from typing import Optional
20
  import pandas as pd
21
+ import numpy as np
22
  import requests
 
23
  from groq import Groq
24
 
25
  # ---------------------------------------------------------------------------
 
31
  # ---------------------------------------------------------------------------
32
  # Constants
33
  # ---------------------------------------------------------------------------
34
+ GROQ_MODEL = "llama-3.1-8b-instant"
35
+ MISTRAL_MODEL = "mistral-small-latest"
36
+
37
+ DEFAULT_TAXONOMY = [
38
+ "Artificial Intelligence", "Machine Learning",
39
+ "Natural Language Processing", "Computer Vision",
40
+ "Information Systems", "Healthcare & Bioinformatics",
41
+ "Finance & Economics", "Cybersecurity",
42
+ "Human-Computer Interaction", "Robotics & Automation",
43
+ "Education Technology", "Environmental Science",
44
  "Social Sciences", "Data Engineering", "Other",
45
  ]
46
 
47
  # ---------------------------------------------------------------------------
48
+ # Data classes
49
  # ---------------------------------------------------------------------------
50
+ @dataclass
51
+ class LLMVote:
52
+ """One LLM's response for one cluster."""
53
+ llm_name: str
54
+ label: str = ""
55
+ description: str = ""
56
+ pacis_match: str = ""
57
+ confidence: float = 0.0
58
+ raw: dict = field(default_factory=dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
 
 
 
61
  @dataclass
62
+ class ClusterInterpretation:
63
+ """Consolidated interpretation for a single cluster."""
64
+ cluster_id: int
65
+ final_label: str = ""
66
+ final_description: str = ""
67
+ final_pacis_match: str = ""
68
+ final_confidence: float = 0.0
69
+ agreement: str = "" # Triple / Two / Single
70
+ sheet1: dict = field(default_factory=dict)
71
+ sheet2: dict = field(default_factory=dict)
72
+ sheet3: dict = field(default_factory=dict)
73
+ defence: dict = field(default_factory=dict) # 4th-round if needed
74
+ keyphrases: list = field(default_factory=list)
75
+ strong_count: int = 0
76
+ weak_count: int = 0
77
  paper_count: int = 0
78
+ grounding_check: dict = field(default_factory=dict)
79
+
80
 
81
  # ---------------------------------------------------------------------------
82
+ # API Clients
83
  # ---------------------------------------------------------------------------
84
  def build_groq_client(api_key: Optional[str] = None):
85
  key = api_key or os.getenv("GROQ_API_KEY")
86
  if not key:
87
+ raise ValueError("No Groq API key.")
88
  return Groq(api_key=key, max_retries=0)
89
 
90
+
91
+ def _call_groq(client, prompt: str) -> dict:
 
 
 
 
92
  try:
93
+ r = client.chat.completions.create(
94
+ model=GROQ_MODEL,
95
+ messages=[{"role": "user", "content": prompt}],
96
+ temperature=0.2, timeout=15,
97
+ )
98
+ return _parse_json(r.choices[0].message.content)
 
 
 
 
 
 
99
  except Exception as e:
100
+ logger.warning("Groq failed: %s", e)
101
  return {}
102
 
103
+
104
+ def _call_mistral(prompt: str, api_key: str) -> dict:
105
+ if not api_key:
106
+ return {}
107
  try:
108
+ r = requests.post(
109
  "https://api.mistral.ai/v1/chat/completions",
110
+ headers={"Authorization": f"Bearer {api_key}",
111
+ "Content-Type": "application/json"},
112
+ json={"model": MISTRAL_MODEL,
113
+ "messages": [{"role": "user", "content": prompt}],
114
+ "temperature": 0.2},
115
+ timeout=15,
 
116
  )
117
+ return _parse_json(r.json()["choices"][0]["message"]["content"])
 
 
 
 
118
  except Exception as e:
119
+ logger.warning("Mistral failed: %s", e)
120
  return {}
121
 
122
+
123
+ def _call_gemini(prompt: str, api_key: str) -> dict:
124
+ if not api_key:
125
+ return {}
126
+ url = (f"https://generativelanguage.googleapis.com/v1beta/models/"
127
+ f"gemini-2.5-flash:generateContent?key={api_key}")
128
  try:
129
+ r = requests.post(url,
130
+ headers={"Content-Type": "application/json"},
131
+ json={"contents": [{"parts": [{"text": prompt}]}],
132
+ "generationConfig": {"temperature": 0.2}},
133
+ timeout=15)
134
+ data = r.json()
135
+ if "candidates" not in data:
136
+ return {}
137
+ raw = data["candidates"][0]["content"]["parts"][0]["text"]
138
+ return _parse_json(raw)
139
  except Exception as e:
140
+ logger.warning("Gemini failed: %s", e)
141
  return {}
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
+ def _parse_json(raw: str) -> dict:
145
+ raw = raw.strip().replace("```json", "").replace("```", "").strip()
146
+ s, e = raw.find("{"), raw.rfind("}") + 1
147
+ if s != -1 and e > 0:
148
+ raw = raw[s:e]
149
+ try:
150
+ return json.loads(raw)
151
+ except Exception:
152
+ return {}
153
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
  # ---------------------------------------------------------------------------
156
+ # Prompt builders
157
  # ---------------------------------------------------------------------------
158
+ def _build_label_prompt(keyphrases: list, rep_abstracts: list) -> str:
159
+ kp_str = ", ".join(k if isinstance(k, str) else k[0]
160
+ for k in keyphrases[:5])
161
+ abs_str = " | ".join(a[:300] for a in rep_abstracts[:3])
162
+ return f"""You are a research-topic classifier.
163
+ A SPECTER-2 + HDBSCAN pipeline produced a topic cluster.
164
+
165
+ KEYPHRASES: {kp_str}
166
+ REPRESENTATIVE ABSTRACTS (truncated): {abs_str}
167
 
168
+ Return ONLY valid JSON (no markdown, no other text):
169
+ {{
170
+ "label": "<concise 5-8 word topic label>",
171
+ "description": "<one-sentence description of the topic>",
172
+ "pacis_match": "<closest PAJAIS 2019 category, or NOVEL if none>",
173
+ "confidence": <0.0-1.0 float>
174
+ }}"""
175
 
 
176
 
177
+ def _build_defence_prompt(
178
+ keyphrases: list,
179
+ rep_abstracts: list,
180
+ votes: list[dict],
181
+ ) -> str:
182
+ kp_str = ", ".join(k if isinstance(k, str) else k[0]
183
+ for k in keyphrases[:5])
184
+ abs_str = " | ".join(a[:300] for a in rep_abstracts[:3])
185
+ v_str = "\n".join(
186
+ f" LLM {i+1}: label=\"{v.get('label','?')}\", "
187
+ f"pacis=\"{v.get('pacis_match','?')}\""
188
+ for i, v in enumerate(votes)
189
+ )
190
+ return f"""You are a research-topic adjudicator resolving a labelling disagreement.
191
 
192
+ KEYPHRASES: {kp_str}
193
+ REPRESENTATIVE ABSTRACTS: {abs_str}
194
 
195
+ Three LLMs proposed different labels:
196
+ {v_str}
 
 
197
 
198
+ Your task: pick the single best label from the three, or synthesise a
199
+ better one. Justify your choice in one sentence.
200
 
201
+ Return ONLY valid JSON:
202
  {{
203
+ "label": "<best 5-8 word label>",
204
+ "description": "<one sentence>",
205
+ "pacis_match": "<PAJAIS category or NOVEL>",
206
+ "confidence": <0.0-1.0>,
207
+ "reasoning": "<one sentence justification>"
208
  }}"""
209
 
210
+
211
  # ---------------------------------------------------------------------------
212
+ # Grounding checkreject labels not supported by keyphrases (§3.5)
213
  # ---------------------------------------------------------------------------
214
+ def grounding_check(label: str, keyphrases: list) -> dict:
215
+ """Non-LLM regex check: label tokens must overlap keyphrases."""
216
+ if not label or not keyphrases:
217
+ return {"verdict": "FAIL", "score": 0, "matched": []}
218
+ label_toks = set(re.findall(r"\b[a-z]{3,}\b", label.lower()))
219
+ kp_toks = set()
220
+ for kp in keyphrases:
221
+ phrase = kp if isinstance(kp, str) else kp[0]
222
+ kp_toks.update(re.findall(r"\b[a-z]{3,}\b", phrase.lower()))
223
+ noise = {"the", "and", "for", "with", "using", "based", "from", "that",
224
+ "are", "this", "into", "its"}
225
+ label_toks -= noise
226
+ kp_toks -= noise
227
+ matched = list(label_toks & kp_toks)
228
+ # stem-level
229
+ stems = []
230
+ for lt in label_toks:
231
+ for kt in kp_toks:
 
 
 
 
 
 
 
232
  if len(lt) >= 4 and (kt.startswith(lt[:4]) or lt.startswith(kt[:4])):
233
+ stems.append(f"{lt}≈{kt}")
234
+ score = min(1.0, len(matched) / max(len(label_toks), 1)
235
+ + 0.15 * len(stems))
236
+ verdict = "PASS" if (matched or stems) else "FAIL"
237
+ return {"verdict": verdict, "score": round(score, 3),
238
+ "matched": matched, "stems": stems[:5]}
239
 
 
 
 
 
 
 
 
 
 
 
 
240
 
241
  # ---------------------------------------------------------------------------
242
+ # Core — interpret one cluster via 3-LLM council (§3.5)
243
  # ---------------------------------------------------------------------------
244
+ def interpret_cluster(
245
+ cluster_id: int,
246
+ keyphrases: list,
247
+ rep_docs: list,
248
+ strong: int,
249
+ weak: int,
250
+ groq_client,
251
+ mistral_key: str,
252
+ gemini_key: str,
253
+ ) -> ClusterInterpretation:
254
+
255
+ prompt = _build_label_prompt(keyphrases, rep_docs)
256
+
257
+ # Sheet 1 — Groq / LLaMA-3.1
258
+ s1 = _call_groq(groq_client, prompt)
 
 
 
 
 
 
 
 
 
 
 
259
  time.sleep(1)
260
+ # Sheet 2 — Mistral
261
+ s2 = _call_mistral(prompt, mistral_key)
 
262
  time.sleep(1)
263
+ # Sheet 3 — Gemini
264
+ s3 = _call_gemini(prompt, gemini_key)
265
 
266
+ votes = [s1, s2, s3]
267
+ valid = [v for v in votes if v and "label" in v]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
 
269
+ # --- Sheet 4: consolidate agreement ---
270
+ labels_lower = [_clean(v.get("label", "")).lower() for v in valid]
271
+ counts = {}
272
+ for l in labels_lower:
273
+ counts[l] = counts.get(l, 0) + 1
274
 
275
+ best_label = ""
276
+ agreement = "Single"
277
+ defence = {}
278
+
279
+ if any(c >= 3 for c in counts.values()):
280
+ agreement = "Triple"
281
+ winner = max(counts, key=counts.get)
282
+ best_label = next(v["label"] for v in valid
283
+ if _clean(v["label"]).lower() == winner)
284
+ elif any(c >= 2 for c in counts.values()):
285
+ agreement = "Two"
286
+ winner = max(counts, key=counts.get)
287
+ best_label = next(v["label"] for v in valid
288
+ if _clean(v["label"]).lower() == winner)
289
+ else:
290
+ agreement = "Single"
291
+ # Fourth-round defence prompt (§3.5)
292
+ defence_prompt = _build_defence_prompt(keyphrases, rep_docs, votes)
293
+ defence = _call_groq(groq_client, defence_prompt)
294
+ if defence and "label" in defence:
295
+ best_label = defence["label"]
296
+ elif valid:
297
+ best_label = valid[0]["label"]
298
+
299
+ best_label = _clean(best_label)
300
+
301
+ # Grounding check — reject if not supported by keyphrases
302
+ gc = grounding_check(best_label, keyphrases)
303
+ if gc["verdict"] == "FAIL" and valid:
304
+ # Fall back to most keyphrase-grounded label
305
+ scored = [(v, len(set(re.findall(r"\b[a-z]{3,}\b",
306
+ v.get("label", "").lower()))
307
+ & set(re.findall(r"\b[a-z]{3,}\b",
308
+ " ".join(k if isinstance(k, str) else k[0]
309
+ for k in keyphrases).lower()))))
310
+ for v in valid]
311
+ scored.sort(key=lambda x: -x[1])
312
+ best_label = _clean(scored[0][0]["label"])
313
+ gc = grounding_check(best_label, keyphrases)
314
+ logger.info("Cluster %d: label rejected by grounding, "
315
+ "fell back to '%s'", cluster_id, best_label)
316
+
317
+ # Best metadata
318
+ best_v = next((v for v in valid
319
+ if _clean(v.get("label", "")).lower()
320
+ == best_label.lower()), valid[0] if valid else {})
321
+
322
+ return ClusterInterpretation(
323
+ cluster_id=cluster_id,
324
+ final_label=best_label,
325
+ final_description=best_v.get("description", ""),
326
+ final_pacis_match=best_v.get("pacis_match", ""),
327
+ final_confidence=best_v.get("confidence", 0.0),
328
+ agreement=agreement,
329
+ sheet1=s1, sheet2=s2, sheet3=s3,
330
+ defence=defence,
331
+ keyphrases=[k if isinstance(k, str) else k[0]
332
+ for k in keyphrases[:5]],
333
+ strong_count=strong,
334
+ weak_count=weak,
335
+ paper_count=strong + weak,
336
+ grounding_check=gc,
337
  )
338
 
 
 
 
 
 
 
 
 
 
 
 
339
 
340
+ def _clean(s: str) -> str:
341
+ s = str(s or "").replace("\n", " ").strip()
342
+ s = " ".join(s.split())
343
+ if len(s) > 60:
344
+ s = s[:60].rsplit(" ", 1)[0] if " " in s[:60] else s[:60]
345
+ return s.rstrip(" .")
346
 
 
347
 
348
  # ---------------------------------------------------------------------------
349
+ # Numpy-safe serialisation
350
+ # ---------------------------------------------------------------------------
351
+ def _convert(obj):
352
+ if isinstance(obj, dict):
353
+ return {k: _convert(v) for k, v in obj.items()}
354
+ if isinstance(obj, list):
355
+ return [_convert(v) for v in obj]
356
+ if isinstance(obj, (np.integer,)):
357
+ return int(obj)
358
+ if isinstance(obj, (np.floating,)):
359
+ return float(obj)
360
+ return obj
361
+
362
+
363
+ # ---------------------------------------------------------------------------
364
+ # Run agent — orchestrate all clusters
365
  # ---------------------------------------------------------------------------
366
  def run_agent(
367
+ topic_results: dict,
368
+ groq_key: str,
369
+ mistral_key: str,
370
+ gemini_key: str,
371
+ output_json: str = "topics.json",
372
+ output_csv: str = "topics.csv",
373
  ) -> dict:
 
374
  client = build_groq_client(groq_key)
 
375
 
376
+ labels_list = topic_results["labels"]
377
+ keyphrases = topic_results["keyphrases"]
378
+ rep_docs = topic_results["representative_docs"]
379
+ membership = topic_results["membership"]
 
 
380
 
381
+ cluster_ids = sorted(keyphrases.keys())
382
  interpretations = {}
383
+
384
+ for cid in cluster_ids:
385
+ sw = membership.get(cid, {"strong": 0, "weak": 0})
386
+ interp = interpret_cluster(
387
+ cluster_id=cid,
388
+ keyphrases=keyphrases.get(cid, []),
389
+ rep_docs=rep_docs.get(cid, []),
390
+ strong=sw["strong"],
391
+ weak=sw["weak"],
392
+ groq_client=client,
393
+ mistral_key=mistral_key,
394
+ gemini_key=gemini_key,
395
  )
396
+ interpretations[cid] = interp
397
+ logger.info("Cluster %d %s [%s] (%d strong, %d weak)",
398
+ cid, interp.final_label, interp.agreement,
399
+ interp.strong_count, interp.weak_count)
 
 
 
 
 
 
 
 
 
 
 
400
 
401
+ # Serialise
402
+ records = [_convert(asdict(i)) for i in interpretations.values()]
403
  with open(output_json, "w") as f:
404
+ json.dump(records, f, indent=2)
405
+ df = pd.DataFrame(records)
 
406
  if not df.empty:
407
+ for col in ["sheet1", "sheet2", "sheet3", "defence",
408
+ "keyphrases", "grounding_check"]:
409
+ if col in df.columns:
410
+ df[col] = df[col].apply(str)
411
  df.to_csv(output_csv, index=False)
412
 
413
+ return dict(interpretations=interpretations,
414
+ json_path=output_json, csv_path=output_csv)