anujjuna commited on
Commit
0a39f3a
·
verified ·
1 Parent(s): 0a624b3

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +187 -737
agent.py CHANGED
@@ -1,29 +1,19 @@
1
  """
2
  agent.py
3
  --------
4
- LLM-driven topic interpretation and classification module.
5
-
6
- For each BERTopic-discovered topic this agent:
7
- 1. Generates a concise, human-readable label.
8
- 2. Assigns the topic to a taxonomy category.
9
- 3. Classifies the topic as MAPPED or NOVEL.
10
-
11
- It then cross-compares title-derived and abstract-derived topics and writes:
12
- - taxonomy_map.json – full classification for every topic
13
- - comparison.csv – side-by-side diff of title vs. abstract topics
14
  """
15
 
16
  from __future__ import annotations
17
-
18
  import json
19
  import logging
20
  import os
21
  import time
22
  from dataclasses import dataclass, asdict
23
  from typing import Optional
24
-
25
  import pandas as pd
26
  import requests
 
27
  from groq import Groq
28
 
29
  # ---------------------------------------------------------------------------
@@ -36,781 +26,241 @@ logger = logging.getLogger(__name__)
36
  # Constants
37
  # ---------------------------------------------------------------------------
38
  DEFAULT_MODEL = "llama-3.1-8b-instant"
39
- MISTRAL_DEFAULT_MODEL = "mistral-small-latest" # --- Dual LLM Validation ---
40
  DEFAULT_TAXONOMY_CATEGORIES = [
41
- "Artificial Intelligence",
42
- "Machine Learning",
43
- "Natural Language Processing",
44
- "Computer Vision",
45
- "Information Systems",
46
- "Healthcare & Bioinformatics",
47
- "Finance & Economics",
48
- "Cybersecurity",
49
- "Human-Computer Interaction",
50
- "Robotics & Automation",
51
- "Education Technology",
52
- "Environmental Science",
53
- "Social Sciences",
54
- "Data Engineering",
55
- "Other",
56
  ]
57
 
58
- CLASSIFICATION_OPTIONS = ("MAPPED", "NOVEL")
59
-
60
  # ---------------------------------------------------------------------------
61
  # Data Classes
62
  # ---------------------------------------------------------------------------
63
  @dataclass
64
  class TopicInterpretation:
65
  """Structured interpretation for a single topic."""
66
- source: str
67
  topic_id: int
68
- keywords: list[str]
69
  label: str
70
- taxonomy_category: str
71
  classification: str
72
- reasoning: str
73
- # --- Dual LLM Validation ---
74
- validation_status: str = "PENDING" # AGREED | DISAGREED | REVIEW_REQUIRED
75
- confidence: str = "MEDIUM" # HIGH | MEDIUM
76
- label_source: str = "groq" # groq | fallback
77
-
78
-
79
- @dataclass
80
- class ComparisonRow:
81
- """One row in the title-vs-abstract comparison table."""
82
- topic_id: int
83
- title_label: str
84
- title_category: str
85
- title_classification: str
86
- abstract_label: str
87
- abstract_category: str
88
- abstract_classification: str
89
- overlap_keywords: str # comma-separated shared keywords
90
- difference_note: str # LLM-generated note on differences
91
-
92
 
93
  # ---------------------------------------------------------------------------
94
- # OpenAI Client
95
  # ---------------------------------------------------------------------------
96
- def build_openai_client(api_key: Optional[str] = None):
97
  key = api_key or os.getenv("GROQ_API_KEY")
98
  if not key:
99
- raise ValueError(
100
- "No Groq API key provided. "
101
- "Pass api_key= or set the GROQ_API_KEY environment variable."
102
- )
103
  return Groq(api_key=key, max_retries=0)
104
 
105
-
106
-
107
-
108
- # ---------------------------------------------------------------------------
109
- # Helpers
110
- # ---------------------------------------------------------------------------
111
- def _ensure_string(x) -> str:
112
- """Safely convert any input (list, None, etc.) to a string."""
113
- if isinstance(x, list):
114
- return " ".join(str(i) for i in x)
115
- if x is None:
116
- return ""
117
- return str(x)
118
-
119
- def _safe_capitalize(s: str) -> str:
120
- """Capitalize only the first letter, keeping the rest as is (unlike .capitalize())."""
121
- s = _ensure_string(s).strip()
122
- if not s:
123
- return ""
124
- return s[0].upper() + s[1:]
125
-
126
- # ---------------------------------------------------------------------------
127
- # Prompt Builders
128
- # ---------------------------------------------------------------------------
129
- def _build_interpretation_prompt(
130
- keywords: list[str],
131
- sample_texts: list[str],
132
- taxonomy_categories: list[str],
133
- ) -> str:
134
- """Return the user prompt for labelling and classifying a single topic."""
135
- kw_str = ", ".join(keywords)
136
- samples_str = "\n".join(f" - {t}" for t in sample_texts[:5])
137
- cats_str = "\n".join(f" - {c}" for c in taxonomy_categories)
138
-
139
- return f"""You are an expert research analyst. A topic modelling algorithm has produced the following topic.
140
-
141
- TOP KEYWORDS:
142
- {kw_str}
143
-
144
- SAMPLE DOCUMENTS FOR THIS TOPIC:
145
- {samples_str}
146
-
147
- AVAILABLE TAXONOMY CATEGORIES:
148
- {cats_str}
149
-
150
- Your task:
151
- 1. Write a concise label (≤8 words) that captures the essence of this topic.
152
- 2. Assign it to ONE category from the list above. Use "Other" only as a last resort.
153
- 3. Classify it as MAPPED (fits an existing, well-established research area) or NOVEL (represents an emerging or cross-disciplinary theme not well-represented in standard taxonomies).
154
- 4. Provide one sentence of reasoning.
155
-
156
- Respond ONLY with valid JSON in exactly this schema – no markdown fences:
157
- {{
158
- "label": "<short label>",
159
- "taxonomy_category": "<one of the listed categories>",
160
- "classification": "MAPPED" | "NOVEL",
161
- "reasoning": "<one sentence>"
162
- }}"""
163
-
164
-
165
- # --- Dual LLM Validation ---
166
- def _fallback_label_from_keywords(keywords: list[str], topic_id: int) -> tuple[str, str]:
167
- """Deterministic keyword-to-label heuristic fallback."""
168
- kw_set = set([k.lower() for k in keywords])
169
-
170
- # Mapping heuristics
171
- mappings = [
172
- ({"privacy", "data", "security", "protection"}, "Digital Privacy and Security Risks", "Cybersecurity"),
173
- ({"ai", "chatbots", "agents", "conversational", "interaction", "assistant"}, "Conversational AI and Human Interaction", "Artificial Intelligence"),
174
- ({"gaming", "players", "video", "games", "engagement"}, "Gaming and User Engagement Patterns", "Human-Computer Interaction"),
175
- ({"vr", "virtual", "immersive", "training", "reality"}, "Virtual Reality and Immersive Training", "Robotics & Automation"),
176
- ({"patient", "healthcare", "medical", "clinical", "hospital"}, "Healthcare Technology and Patient Care", "Healthcare & Bioinformatics"),
177
- ({"shopping", "commerce", "purchase", "ecommerce", "consumer"}, "E-commerce and Consumer Behavior", "Finance & Economics"),
178
- ({"internet", "addiction", "adolescents", "youth", "behavior"}, "Internet Addiction and Adolescent Behavior", "Social Sciences"),
179
- ({"gamification", "learning", "education", "student", "classroom"}, "Gamification in Learning and Interaction", "Education Technology"),
180
- ({"neural", "network", "deep", "learning", "cnn", "transformer"}, "Deep Learning Architectures", "Machine Learning"),
181
- ({"graph", "knowledge", "relational", "embedding"}, "Knowledge Graphs and Relational Data", "Data Engineering"),
182
- ]
183
-
184
- for trigger_kws, fallback_label, fallback_cat in mappings:
185
- if any(tk in kw_set for tk in trigger_kws):
186
- return fallback_label, fallback_cat
187
-
188
- # Generic fallback if no specific rule matches
189
- main_kws = ", ".join(_safe_capitalize(k) for k in keywords[:2])
190
- label = f"Study on {', '.join(keywords[:3])}"
191
- return label, "Other"
192
-
193
- def _build_validation_prompt(keywords, groq_label, groq_category):
194
- return f"""
195
- You are reviewing topic classification for research papers.
196
-
197
- Keywords: {', '.join(keywords[:8])}
198
- Proposed label: {groq_label}
199
- Proposed category: {groq_category}
200
-
201
- Instructions:
202
- - If label and category reasonably match the keywords → say YES
203
- - If there is a clear mismatch → say NO
204
- - Small wording differences are OK
205
- - Be balanced: do not be too strict or too lenient
206
-
207
- Respond ONLY in JSON:
208
- {{
209
- "AGREEMENT": "YES" or "NO",
210
- "CONFIDENCE": "HIGH", "MEDIUM", or "LOW",
211
- "REASON": "<short explanation>"
212
- }}
213
- """
214
-
215
-
216
- def _call_mistral_validation(
217
- mistral_api_key,
218
- keywords,
219
- groq_label,
220
- groq_category,
221
- model="mistral-small-latest",
222
- ):
223
- if not mistral_api_key:
224
  return {}
225
 
226
- prompt = _build_validation_prompt(keywords, groq_label, groq_category)
227
-
 
228
  try:
229
  response = requests.post(
230
  "https://api.mistral.ai/v1/chat/completions",
231
- headers={
232
- "Authorization": f"Bearer {mistral_api_key}",
233
- "Content-Type": "application/json",
234
- },
235
  json={
236
- "model": model,
237
  "messages": [{"role": "user", "content": prompt}],
238
- "temperature": 0.1,
239
  },
240
- timeout=20,
241
  )
242
-
243
  data = response.json()
244
  raw = data["choices"][0]["message"]["content"].strip()
245
-
246
  raw = raw.replace("```json", "").replace("```", "").strip()
247
  start, end = raw.find("{"), raw.rfind("}") + 1
248
  return json.loads(raw[start:end])
249
-
250
  except Exception as e:
251
- logger.warning(f"Mistral validation failed: {e}")
252
- return {} # Ensure fallback logic triggers correctly
253
-
254
-
255
- def _build_comparison_prompt(
256
- topic_id: int,
257
- title_interp: TopicInterpretation,
258
- abstract_interp: TopicInterpretation,
259
- ) -> str:
260
- """Return the user prompt for comparing a title topic to an abstract topic."""
261
- return f"""You are comparing two topic representations for Topic ID {topic_id}.
262
-
263
- TITLE-BASED TOPIC
264
- Label : {title_interp.label}
265
- Category : {title_interp.taxonomy_category}
266
- Class : {title_interp.classification}
267
- Keywords : {', '.join(title_interp.keywords)}
268
-
269
- ABSTRACT-BASED TOPIC
270
- Label : {abstract_interp.label}
271
- Category : {abstract_interp.taxonomy_category}
272
- Class : {abstract_interp.classification}
273
- Keywords : {', '.join(abstract_interp.keywords)}
274
-
275
- In one concise sentence, describe the most meaningful difference (or similarity) between these two topic representations.
276
- Respond with ONLY the sentence – no JSON, no markdown."""
277
-
278
-
279
- # ---------------------------------------------------------------------------
280
- # LLM Calls
281
- # ---------------------------------------------------------------------------
282
- def _call_llm_json(
283
- client,
284
- prompt: str,
285
- model: str,
286
- retries: int = 1,
287
- backoff: float = 1.0,
288
- ) -> dict:
289
- """
290
- Call the OpenAI chat completion endpoint and parse the response as JSON.
291
-
292
- Parameters
293
- ----------
294
- client : OpenAI
295
- prompt : str
296
- model : str
297
- retries : int
298
- backoff : float
299
- Seconds to wait between retries (exponential).
300
-
301
- Returns
302
- -------
303
- dict
304
- Parsed JSON response.
305
- """
306
- for attempt in range(1, retries + 1):
307
- try:
308
- response = client.chat.completions.create(
309
- model=model,
310
- messages=[{"role": "user", "content": prompt}],
311
- temperature=0.2,
312
- timeout=8,
313
- )
314
- raw = response.choices[0].message.content.strip()
315
- raw = raw.replace("```json", "").replace("```", "").strip()
316
- start = raw.find("{")
317
- end = raw.rfind("}") + 1
318
- if start == -1 or end == 0:
319
- raise ValueError("No JSON object found in response")
320
- return json.loads(raw[start:end])
321
-
322
- except (json.JSONDecodeError, ValueError) as exc:
323
- logger.warning("Attempt %d – Parse error: %s", attempt, exc)
324
- except Exception as exc:
325
- logger.warning("Attempt %d – API error: %s", attempt, exc)
326
- if "rate limit" in str(exc).lower():
327
- time.sleep(1)
328
- if attempt < retries:
329
- time.sleep(0.5)
330
-
331
- return {}
332
-
333
 
334
- def _call_llm_text(
335
- client,
336
- prompt: str,
337
- model: str,
338
- ) -> str:
339
- """Call the OpenAI endpoint and return plain text."""
340
  try:
341
  response = client.chat.completions.create(
342
- model=model,
343
- messages=[{"role": "user", "content": prompt}],
344
- temperature=0.3,
345
  )
346
- return response.choices[0].message.content.strip()
347
- except Exception as exc:
348
- logger.warning("LLM text call failed: %s", exc)
349
- return ""
350
-
351
-
352
- # --- Dual LLM Validation ---
353
-
354
-
355
-
356
- def _decide_validation(groq_category: str, mistral_result: dict) -> tuple[str, str]:
357
- """
358
- Decision logic – Groq is authoritative, Mistral is validator.
359
- """
360
-
361
- if not mistral_result:
362
- return "AGREED", "LOW"
363
-
364
- agreement = mistral_result.get("AGREEMENT", "NO").upper()
365
- confidence = mistral_result.get("CONFIDENCE", "MEDIUM").upper()
366
- suggested = mistral_result.get("SUGGESTED_CATEGORY", groq_category).strip()
367
-
368
- # Extract root categories
369
- groq_root = groq_category.split("&")[0].strip().lower()
370
- suggested_root = suggested.split("&")[0].strip().lower()
371
-
372
- # ✅ Case 1: Agreement
373
- if agreement == "YES":
374
- return "AGREED", confidence
375
-
376
- # ✅ Case 2: Disagreement (handle smartly)
377
- if agreement == "NO":
378
-
379
- # Strong disagreement → flag clearly
380
- if confidence == "HIGH":
381
- if groq_root != suggested_root:
382
- return "REVIEW_REQUIRED", "HIGH"
383
- return "DISAGREED", "HIGH"
384
-
385
- # Medium disagreement → partial trust
386
- if confidence == "MEDIUM":
387
- if groq_root != suggested_root:
388
- return "REVIEW_REQUIRED", "MEDIUM"
389
- return "DISAGREED", "MEDIUM"
390
-
391
- # Low confidence → be lenient
392
- return "AGREED", "LOW"
393
-
394
- return "AGREED", "LOW"
395
-
396
 
397
  # ---------------------------------------------------------------------------
398
- # Core Interpretation
399
  # ---------------------------------------------------------------------------
400
- def interpret_topic(
401
- client,
402
- source: str,
403
- topic_id: int,
404
- keywords: list[str],
405
- sample_texts: list[str],
406
- taxonomy_categories: list[str],
407
- model: str = DEFAULT_MODEL,
408
- mistral_api_key: Optional[str] = None,
409
- mistral_model: str = MISTRAL_DEFAULT_MODEL,
410
- ) -> TopicInterpretation:
411
- # Step 1: Groq generates label / category / classification
412
- prompt = _build_interpretation_prompt(keywords, sample_texts, taxonomy_categories)
413
- data = _call_llm_json(client, prompt, model, retries=2)
414
 
415
- label_source = "groq"
416
- if not data:
417
- # Fallback to heuristic if Groq fails
418
- fallback_label, fallback_cat = _fallback_label_from_keywords(keywords, topic_id)
419
- label = fallback_label
420
- category = fallback_cat
421
- classification = "MAPPED"
422
- reasoning = "Generated via keyword heuristics due to LLM timeout."
423
- label_source = "fallback"
424
- else:
425
- label = _ensure_string(data.get("label", "Unknown Topic"))
426
- category = _ensure_string(data.get("taxonomy_category", "Other"))
427
- classification = _ensure_string(data.get("classification", "MAPPED")).upper()
428
- reasoning = _ensure_string(data.get("reasoning", ""))
429
-
430
- if label == "Unknown Topic":
431
- fallback_label, fallback_cat = _fallback_label_from_keywords(keywords, topic_id)
432
- label = fallback_label
433
- category = fallback_cat
434
- label_source = "fallback"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
435
 
436
- # Final normalization and safe capitalization
437
- label = _safe_capitalize(label)
438
- category = _safe_capitalize(category)
 
 
 
 
 
 
 
439
 
440
- if classification not in CLASSIFICATION_OPTIONS:
441
- classification = "MAPPED"
 
 
 
 
 
 
442
 
443
- # Step 2 & 3: Mistral validates – Groq stays authoritative
444
- mistral_result = _call_mistral_validation(
445
- mistral_api_key, keywords, label, category, mistral_model
446
- )
447
- validation_status, confidence = _decide_validation(category, mistral_result)
 
 
448
 
449
- logger.info(
450
- "[%s] Topic %d → '%s' (%s) | %s | val=%s conf=%s",
451
- source, topic_id, label, label_source, category, validation_status, confidence,
452
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
453
  return TopicInterpretation(
454
- source=source,
455
  topic_id=topic_id,
456
- keywords=keywords,
457
- label=label,
458
- taxonomy_category=category,
459
- classification=classification,
460
- reasoning=reasoning,
461
- validation_status=validation_status,
462
- confidence=confidence,
463
- label_source=label_source
464
  )
465
 
466
-
467
- def interpret_all_topics(
468
- client,
469
- source: str,
470
- topic_keywords: dict[int, list[tuple[str, float]]],
471
- topic_docs: dict[int, list[str]],
472
- taxonomy_categories: list[str] = DEFAULT_TAXONOMY_CATEGORIES,
473
- model: str = DEFAULT_MODEL,
474
- mistral_api_key: Optional[str] = None, # --- Dual LLM Validation ---
475
- mistral_model: str = MISTRAL_DEFAULT_MODEL,
476
- ) -> dict[int, TopicInterpretation]:
477
- """Interpret every topic for a given source with optional Mistral validation."""
478
- interpretations: dict[int, TopicInterpretation] = {}
479
-
480
- MAX_TOPICS = 200 # Increased for fuller comparison
481
- selected_topics = dict(list(topic_keywords.items())[:MAX_TOPICS])
482
-
483
- for topic_id, kw_pairs in selected_topics.items():
484
- keywords = [w for w, _ in kw_pairs]
485
- samples = topic_docs.get(topic_id, [])[:5]
486
-
487
  interp = interpret_topic(
488
- client=client,
489
- source=source,
490
- topic_id=topic_id,
491
- keywords=keywords,
492
- sample_texts=samples,
493
- taxonomy_categories=taxonomy_categories,
494
- model=model,
495
- mistral_api_key=mistral_api_key,
496
- mistral_model=mistral_model,
497
- )
498
-
499
- interpretations[topic_id] = interp
500
- time.sleep(2) # API rate limiting
501
-
502
- return interpretations
503
-
504
-
505
- # ---------------------------------------------------------------------------
506
- # Cross-Source Comparison
507
- # ---------------------------------------------------------------------------
508
- def _get_overlap_keywords(a: TopicInterpretation, b: TopicInterpretation) -> list[str]:
509
- """Return keywords shared between two topic interpretations."""
510
- return list(set(a.keywords) & set(b.keywords))
511
-
512
-
513
- def compare_topics(
514
- client,
515
- title_interpretations: dict[int, TopicInterpretation],
516
- abstract_interpretations: dict[int, TopicInterpretation],
517
- model: str = DEFAULT_MODEL,
518
- ) -> list[ComparisonRow]:
519
- """
520
- Pair topics that share the same topic_id across title and abstract sources
521
- and produce a comparison row for each shared ID.
522
-
523
- Parameters
524
- ----------
525
- client : OpenAI
526
- title_interpretations : dict[int, TopicInterpretation]
527
- abstract_interpretations : dict[int, TopicInterpretation]
528
- model : str
529
-
530
- Returns
531
- -------
532
- list[ComparisonRow]
533
- """
534
- shared_ids = sorted(
535
- set(title_interpretations) & set(abstract_interpretations)
536
- )
537
- rows: list[ComparisonRow] = []
538
-
539
- for tid in shared_ids:
540
- t_interp = title_interpretations[tid]
541
- a_interp = abstract_interpretations[tid]
542
- overlap = _get_overlap_keywords(t_interp, a_interp)
543
- diff_note = _call_llm_text(
544
- client,
545
- _build_comparison_prompt(tid, t_interp, a_interp),
546
- model,
547
  )
548
- if not diff_note or len(diff_note.strip()) < 5:
549
- diff_note = "Minor or no significant difference"
550
-
551
- rows.append(
552
- ComparisonRow(
553
- topic_id=tid,
554
- title_label=t_interp.label,
555
- title_category=t_interp.taxonomy_category,
556
- title_classification=t_interp.classification,
557
- abstract_label=a_interp.label,
558
- abstract_category=a_interp.taxonomy_category,
559
- abstract_classification=a_interp.classification,
560
- overlap_keywords=", ".join(overlap) if overlap else "none",
561
- difference_note=diff_note,
562
- )
563
- )
564
- logger.info("Compared topic %d across sources.", tid)
565
-
566
- return rows
567
-
568
-
569
- # ---------------------------------------------------------------------------
570
- # Output Writers
571
- # ---------------------------------------------------------------------------
572
- def build_taxonomy_map(
573
- title_interpretations: dict[int, TopicInterpretation],
574
- abstract_interpretations: dict[int, TopicInterpretation],
575
- ) -> dict:
576
- """
577
- Merge title and abstract interpretations into a single taxonomy map dict.
578
-
579
- Returns
580
- -------
581
- dict
582
- Structured taxonomy map ready for JSON serialisation.
583
- """
584
- def _serialize(interps: dict[int, TopicInterpretation]) -> list[dict]:
585
- return [asdict(v) for v in interps.values()]
586
-
587
- return {
588
- "titles": _serialize(title_interpretations),
589
- "abstracts": _serialize(abstract_interpretations),
590
- }
591
-
592
-
593
- def save_taxonomy_map(taxonomy_map: dict, output_path: str = "taxonomy_map.json") -> None:
594
- """
595
- Write the taxonomy map to a JSON file.
596
-
597
- Parameters
598
- ----------
599
- taxonomy_map : dict
600
- output_path : str
601
- """
602
- with open(output_path, "w", encoding="utf-8") as fh:
603
- json.dump(taxonomy_map, fh, indent=2, ensure_ascii=False)
604
- logger.info("Taxonomy map saved → %s", output_path)
605
-
606
-
607
- def save_comparison_csv(
608
- comparison_rows: list[ComparisonRow],
609
- output_path: str = "comparison.csv",
610
- ) -> None:
611
- """
612
- Write the comparison rows to a CSV file.
613
-
614
- Parameters
615
- ----------
616
- comparison_rows : list[ComparisonRow]
617
- output_path : str
618
- """
619
- if not comparison_rows:
620
- logger.warning("No comparison rows to save.")
621
- return
622
-
623
- df = pd.DataFrame([asdict(r) for r in comparison_rows])
624
- df.to_csv(output_path, index=False)
625
- logger.info("Comparison CSV saved → %s", output_path)
626
-
627
-
628
- # ---------------------------------------------------------------------------
629
- # Helper: Build topic_docs mapping from BERTopic output
630
- # ---------------------------------------------------------------------------
631
- def build_topic_docs_map(
632
- raw_texts: list[str],
633
- topic_assignments: list[int],
634
- ) -> dict[int, list[str]]:
635
- """
636
- Group raw documents by their assigned topic ID.
637
-
638
- Parameters
639
- ----------
640
- raw_texts : list[str]
641
- Original (unprocessed) text documents.
642
- topic_assignments : list[int]
643
- Topic ID assigned to each document by BERTopic (parallel to raw_texts).
644
-
645
- Returns
646
- -------
647
- dict[int, list[str]]
648
- Mapping of topic_id → list of documents belonging to that topic.
649
- """
650
- mapping: dict[int, list[str]] = {}
651
- for doc, tid in zip(raw_texts, topic_assignments):
652
- if tid == -1:
653
- continue
654
- mapping.setdefault(tid, []).append(doc)
655
- return mapping
656
-
657
-
658
- # ---------------------------------------------------------------------------
659
- # High-Level Pipeline
660
- # ---------------------------------------------------------------------------
661
- def run_agent(
662
- title_topic_keywords: dict[int, list[tuple[str, float]]],
663
- abstract_topic_keywords: dict[int, list[tuple[str, float]]],
664
- title_topic_assignments: list[int],
665
- abstract_topic_assignments: list[int],
666
- raw_titles: list[str],
667
- raw_abstracts: list[str],
668
- api_key: Optional[str] = None,
669
- model: str = DEFAULT_MODEL,
670
- taxonomy_categories: list[str] = DEFAULT_TAXONOMY_CATEGORIES,
671
- taxonomy_map_path: str = "taxonomy_map.json",
672
- comparison_csv_path: str = "comparison.csv",
673
- mistral_api_key: Optional[str] = None, # --- Dual LLM Validation ---
674
- mistral_model: str = MISTRAL_DEFAULT_MODEL,
675
- ) -> dict:
676
- """
677
- End-to-end agent pipeline:
678
- 1. Interpret title topics via LLM
679
- 2. Interpret abstract topics via LLM
680
- 3. Compare cross-source topics
681
- 4. Write taxonomy_map.json and comparison.csv
682
-
683
- Parameters
684
- ----------
685
- title_topic_keywords : dict
686
- Output of tools.extract_topics()["topic_keywords"] for titles.
687
- abstract_topic_keywords : dict
688
- Output of tools.extract_topics()["topic_keywords"] for abstracts.
689
- title_topic_assignments : list[int]
690
- Output of tools.extract_topics()["topics"] for titles.
691
- abstract_topic_assignments : list[int]
692
- Output of tools.extract_topics()["topics"] for abstracts.
693
- raw_titles : list[str]
694
- Original (unprocessed) title strings.
695
- raw_abstracts : list[str]
696
- Original (unprocessed) abstract strings.
697
- api_key : str, optional
698
- OpenAI API key (falls back to OPENAI_API_KEY env var).
699
- model : str
700
- OpenAI model to use (default gpt-4o-mini).
701
- taxonomy_categories : list[str]
702
- Taxonomy buckets the LLM may assign topics to.
703
- taxonomy_map_path : str
704
- Output path for taxonomy_map.json.
705
- comparison_csv_path : str
706
- Output path for comparison.csv.
707
-
708
- Returns
709
- -------
710
- dict with keys
711
- title_interpretations – dict[int, TopicInterpretation]
712
- abstract_interpretations – dict[int, TopicInterpretation]
713
- comparison_rows – list[ComparisonRow]
714
- taxonomy_map – dict (JSON-serialisable)
715
- """
716
- client = build_openai_client(api_key)
717
- mistral_api_key = mistral_api_key or os.getenv("MISTRAL_API_KEY")
718
-
719
- # --- Build raw-text lookup maps ---
720
- title_docs_map = build_topic_docs_map(raw_titles, title_topic_assignments)
721
- abstract_docs_map = build_topic_docs_map(raw_abstracts, abstract_topic_assignments)
722
-
723
- # --- Interpret topics ---
724
- logger.info("Interpreting TITLE topics …")
725
- title_interps = interpret_all_topics(
726
- client=client,
727
- source="titles",
728
- topic_keywords=title_topic_keywords,
729
- topic_docs=title_docs_map,
730
- taxonomy_categories=taxonomy_categories,
731
- model=model,
732
- mistral_api_key=mistral_api_key,
733
- mistral_model=mistral_model,
734
- )
735
-
736
- logger.info("Interpreting ABSTRACT topics …")
737
- abstract_interps = interpret_all_topics(
738
- client=client,
739
- source="abstracts",
740
- topic_keywords=abstract_topic_keywords,
741
- topic_docs=abstract_docs_map,
742
- taxonomy_categories=taxonomy_categories,
743
- model=model,
744
- mistral_api_key=mistral_api_key,
745
- mistral_model=mistral_model,
746
- )
747
-
748
- # --- Compare ---
749
- logger.info("Comparing title vs. abstract topics …")
750
- comparison_rows = compare_topics(client, title_interps, abstract_interps, model)
751
-
752
- # --- Persist ---
753
- taxonomy_map = build_taxonomy_map(title_interps, abstract_interps)
754
- save_taxonomy_map(taxonomy_map, taxonomy_map_path)
755
- save_comparison_csv(comparison_rows, comparison_csv_path)
756
-
757
- return {
758
- "title_interpretations": title_interps,
759
- "abstract_interpretations": abstract_interps,
760
- "comparison_rows": comparison_rows,
761
- "taxonomy_map": taxonomy_map,
762
- }
763
-
764
-
765
- # ---------------------------------------------------------------------------
766
- # CLI Entry Point
767
- # ---------------------------------------------------------------------------
768
- if __name__ == "__main__":
769
- """
770
- Demo / smoke-test: runs agent on synthetic topic data.
771
- Set OPENAI_API_KEY in your environment before running.
772
- """
773
- DEMO_TITLE_KEYWORDS: dict[int, list[tuple[str, float]]] = {
774
- 0: [("neural", 0.9), ("network", 0.85), ("deep", 0.8), ("learning", 0.75), ("training", 0.7)],
775
- 1: [("blockchain", 0.88), ("transaction", 0.82), ("ledger", 0.78), ("consensus", 0.74), ("crypto", 0.7)],
776
- }
777
- DEMO_ABSTRACT_KEYWORDS: dict[int, list[tuple[str, float]]] = {
778
- 0: [("deep", 0.91), ("model", 0.87), ("classification", 0.82), ("accuracy", 0.78), ("dataset", 0.74)],
779
- 1: [("distributed", 0.86), ("blockchain", 0.81), ("smart", 0.77), ("contract", 0.73), ("peer", 0.68)],
780
- }
781
-
782
- sample_titles = [
783
- "Deep Learning for Image Classification",
784
- "Neural Networks in Healthcare",
785
- "Blockchain and Distributed Ledger Technology",
786
- "Smart Contracts in Finance",
787
- ]
788
- sample_abstracts = [
789
- "We propose a deep learning model achieving state-of-the-art accuracy on benchmark datasets.",
790
- "A convolutional network trained for medical image classification.",
791
- "This paper surveys blockchain consensus mechanisms and distributed ledger architectures.",
792
- "We implement smart contracts for automated financial transactions on a public blockchain.",
793
- ]
794
-
795
- title_assignments = [0, 0, 1, 1]
796
- abstract_assignments = [0, 0, 1, 1]
797
-
798
- results = run_agent(
799
- title_topic_keywords=DEMO_TITLE_KEYWORDS,
800
- abstract_topic_keywords=DEMO_ABSTRACT_KEYWORDS,
801
- title_topic_assignments=title_assignments,
802
- abstract_topic_assignments=abstract_assignments,
803
- raw_titles=sample_titles,
804
- raw_abstracts=sample_abstracts,
805
- taxonomy_map_path="taxonomy_map.json",
806
- comparison_csv_path="comparison.csv",
807
- )
808
-
809
- print("\n=== Taxonomy Map (titles) ===")
810
- for interp in results["taxonomy_map"]["titles"]:
811
- print(f" [{interp['topic_id']}] {interp['label']} | {interp['taxonomy_category']} | {interp['classification']}")
812
 
813
- print("\n=== Comparison Rows ===")
814
- for row in results["comparison_rows"]:
815
- print(f" Topic {row.topic_id}: '{row.title_label}' vs '{row.abstract_label}'")
816
- print(f" Note: {row.difference_note}")
 
1
  """
2
  agent.py
3
  --------
4
+ LLM-driven topic interpretation and classification module using a 3-LLM ensemble.
 
 
 
 
 
 
 
 
 
5
  """
6
 
7
  from __future__ import annotations
 
8
  import json
9
  import logging
10
  import os
11
  import time
12
  from dataclasses import dataclass, asdict
13
  from typing import Optional
 
14
  import pandas as pd
15
  import requests
16
+ import re
17
  from groq import Groq
18
 
19
  # ---------------------------------------------------------------------------
 
26
  # Constants
27
  # ---------------------------------------------------------------------------
28
  DEFAULT_MODEL = "llama-3.1-8b-instant"
29
+ MISTRAL_DEFAULT_MODEL = "mistral-small-latest"
30
  DEFAULT_TAXONOMY_CATEGORIES = [
31
+ "Artificial Intelligence", "Machine Learning", "Natural Language Processing",
32
+ "Computer Vision", "Information Systems", "Healthcare & Bioinformatics",
33
+ "Finance & Economics", "Cybersecurity", "Human-Computer Interaction",
34
+ "Robotics & Automation", "Education Technology", "Environmental Science",
35
+ "Social Sciences", "Data Engineering", "Other",
 
 
 
 
 
 
 
 
 
 
36
  ]
37
 
 
 
38
  # ---------------------------------------------------------------------------
39
  # Data Classes
40
  # ---------------------------------------------------------------------------
41
  @dataclass
42
  class TopicInterpretation:
43
  """Structured interpretation for a single topic."""
 
44
  topic_id: int
 
45
  label: str
46
+ category: str
47
  classification: str
48
+ paper_count: int = 0
49
+ keywords: list[str] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  # ---------------------------------------------------------------------------
52
+ # API Clients & Calls
53
  # ---------------------------------------------------------------------------
54
+ def build_groq_client(api_key: Optional[str] = None):
55
  key = api_key or os.getenv("GROQ_API_KEY")
56
  if not key:
57
+ raise ValueError("No Groq API key provided.")
 
 
 
58
  return Groq(api_key=key, max_retries=0)
59
 
60
+ def call_gemini_label(prompt: str, api_key: str) -> dict:
61
+ """Call Google AI Studio (Gemini) API."""
62
+ if not api_key: return {}
63
+ url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent?key={api_key}"
64
+ headers = {"Content-Type": "application/json"}
65
+ payload = {"contents": [{"parts": [{"text": prompt}]}], "generationConfig": {"temperature": 0.2}}
66
+ try:
67
+ response = requests.post(url, headers=headers, json=payload, timeout=10)
68
+ data = response.json()
69
+ if "error" in data or "candidates" not in data:
70
+ logger.error(f"Gemini error / missing candidates. Response: {data}")
71
+ return {}
72
+ raw = data["candidates"][0]["content"]["parts"][0]["text"].strip()
73
+ raw = raw.replace("```json", "").replace("```", "").strip()
74
+ start = raw.find("{")
75
+ end = raw.rfind("}") + 1
76
+ if start != -1 and end != 0:
77
+ raw = raw[start:end]
78
+ return json.loads(raw)
79
+ except Exception as e:
80
+ logger.warning(f"Gemini call failed: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  return {}
82
 
83
+ def call_mistral_label(prompt: str, api_key: str) -> dict:
84
+ """Call Mistral API."""
85
+ if not api_key: return {}
86
  try:
87
  response = requests.post(
88
  "https://api.mistral.ai/v1/chat/completions",
89
+ headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
 
 
 
90
  json={
91
+ "model": "mistral-small-latest",
92
  "messages": [{"role": "user", "content": prompt}],
93
+ "temperature": 0.2,
94
  },
95
+ timeout=10,
96
  )
 
97
  data = response.json()
98
  raw = data["choices"][0]["message"]["content"].strip()
 
99
  raw = raw.replace("```json", "").replace("```", "").strip()
100
  start, end = raw.find("{"), raw.rfind("}") + 1
101
  return json.loads(raw[start:end])
 
102
  except Exception as e:
103
+ logger.warning(f"Mistral call failed: {e}")
104
+ return {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
+ def _call_llm_json(client, prompt: str, model: str) -> dict:
107
+ """Call Groq API with robust JSON parsing."""
 
 
 
 
108
  try:
109
  response = client.chat.completions.create(
110
+ model=model, messages=[{"role": "user", "content": prompt}], temperature=0.2, timeout=10,
 
 
111
  )
112
+ raw = response.choices[0].message.content.strip()
113
+ raw = raw.replace("```json", "").replace("```", "").strip()
114
+ start = raw.find("{")
115
+ end = raw.rfind("}") + 1
116
+ if start != -1 and end != 0:
117
+ raw = raw[start:end]
118
+ return json.loads(raw)
119
+ except Exception as e:
120
+ logger.warning(f"Groq call failed: {e}")
121
+ return {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  # ---------------------------------------------------------------------------
124
+ # Logic Helpers
125
  # ---------------------------------------------------------------------------
126
+ def convert_numpy_types(obj):
127
+ """Recursively convert numpy types to native Python types for JSON serialisation."""
128
+ import numpy as np
129
+ if isinstance(obj, dict):
130
+ return {k: convert_numpy_types(v) for k, v in obj.items()}
131
+ elif isinstance(obj, list):
132
+ return [convert_numpy_types(v) for v in obj]
133
+ elif isinstance(obj, np.integer):
134
+ return int(obj)
135
+ elif isinstance(obj, np.floating):
136
+ return float(obj)
137
+ return obj
 
 
138
 
139
+ def _safe_capitalize(s: str) -> str:
140
+ s = str(s or "").strip()
141
+ return s[0].upper() + s[1:] if s else ""
142
+
143
+ def clean_label(label: str) -> str:
144
+ if not label: return ""
145
+ label = label.replace("\n", " ").strip()
146
+ label = " ".join(label.split())
147
+ label = label.rstrip(" .")
148
+ if len(label) > 60:
149
+ label = label[:60].rsplit(" ", 1)[0] if " " in label[:60] else label[:60]
150
+ return label.strip()
151
+
152
+ def _get_keyword_overlap(label: str, keywords: list[str]) -> int:
153
+ label_words = set(label.lower().split())
154
+ kw_set = set(k.lower() for k in keywords)
155
+ return len(label_words & kw_set)
156
+
157
+ def select_best_interpretation(results: list[dict], keywords: list[str]) -> dict:
158
+ valid = [r for r in results if r and "label" in r]
159
+ if not valid: return {}
160
+
161
+ # Majority vote
162
+ counts = {}
163
+ for r in valid:
164
+ l = clean_label(r["label"]).lower()
165
+ counts[l] = counts.get(l, 0) + 1
166
+ for l, c in counts.items():
167
+ if c >= 2:
168
+ best_r = next(r for r in valid if clean_label(r["label"]).lower() == l)
169
+ best_r["label"] = clean_label(best_r["label"])
170
+ return best_r
171
+
172
+ # Fallback: keyword overlap or shortest
173
+ valid.sort(key=lambda x: (-_get_keyword_overlap(clean_label(x["label"]), keywords), len(clean_label(x["label"]))))
174
+ best_r = valid[0]
175
+ best_r["label"] = clean_label(best_r["label"])
176
+ return best_r
177
 
178
+ def _fallback_label_from_keywords(keywords: list[str], topic_id: int) -> tuple[str, str]:
179
+ kw_set = set([k.lower() for k in keywords])
180
+ mappings = [
181
+ ({"privacy", "data", "security"}, "Digital Privacy and Security", "Cybersecurity"),
182
+ ({"ai", "chatbots", "agents"}, "Conversational AI", "Artificial Intelligence"),
183
+ ({"neural", "network", "deep"}, "Deep Learning Systems", "Machine Learning"),
184
+ ]
185
+ for trigger, label, cat in mappings:
186
+ if any(t in kw_set for t in trigger): return label, cat
187
+ return f"Topic study on {', '.join(keywords[:2])}", "Other"
188
 
189
+ # ---------------------------------------------------------------------------
190
+ # Core Logic
191
+ # ---------------------------------------------------------------------------
192
+ def _build_interpretation_prompt(keywords, samples, cats) -> str:
193
+ return f"""A topic modelling algorithm produced this topic.
194
+ KEYWORDS: {', '.join(keywords)}
195
+ SAMPLES: {' | '.join(samples[:3])}
196
+ CATEGORIES: {', '.join(cats)}
197
 
198
+ Respond ONLY in JSON:
199
+ {{
200
+ "label": "<8 words label>",
201
+ "taxonomy_category": "<one of the categories>",
202
+ "classification": "MAPPED" | "NOVEL",
203
+ "reasoning": "<one sentence>"
204
+ }}"""
205
 
206
+ def interpret_topic(topic_id, keywords, samples, groq_client, mistral_key, gemini_key, paper_count, representative_docs) -> TopicInterpretation:
207
+ prompt = _build_interpretation_prompt(keywords, samples, DEFAULT_TAXONOMY_CATEGORIES)
208
+
209
+ # Ensemble — Gemini key will be None if rate-limited by caller
210
+ results = []
211
+ results.append(_call_llm_json(groq_client, prompt, DEFAULT_MODEL))
212
+ time.sleep(1)
213
+ results.append(call_mistral_label(prompt, mistral_key))
214
+ time.sleep(1)
215
+ if gemini_key:
216
+ results.append(call_gemini_label(prompt, gemini_key))
217
+
218
+ best = select_best_interpretation(results, keywords)
219
+ if not best:
220
+ l, c = _fallback_label_from_keywords(keywords, topic_id)
221
+ best = {"label": l, "taxonomy_category": c, "classification": "MAPPED"}
222
+
223
  return TopicInterpretation(
 
224
  topic_id=topic_id,
225
+ label=_safe_capitalize(best.get("label")),
226
+ category=_safe_capitalize(best.get("taxonomy_category")),
227
+ classification=best.get("classification", "MAPPED").upper(),
228
+ paper_count=paper_count,
229
+ keywords=keywords
 
 
 
230
  )
231
 
232
+ def run_agent(topic_results, groq_key, mistral_key, gemini_key, output_json="topics.json", output_csv="topics.csv") -> dict:
233
+ client = build_groq_client(groq_key)
234
+ res = topic_results["documents"]
235
+
236
+ num_clusters = len([t for t in set(res["topics"]) if t != -1])
237
+ num_topics = len(res["topic_keywords"])
238
+ print(f"Final cluster count: {num_clusters}")
239
+ print(f"Final topic count: {num_topics}")
240
+ if num_clusters != num_topics:
241
+ logger.error(f"CONSISTENCY WARNING: {num_clusters} clusters != {num_topics} topics")
242
+
243
+ interpretations = {}
244
+ for i, (tid, kw_pairs) in enumerate(res["topic_keywords"].items()):
245
+ # Full 3-LLM council for every topic (Groq + Mistral + Gemini)
 
 
 
 
 
 
 
246
  interp = interpret_topic(
247
+ tid, [w for w, _ in kw_pairs], res["representative_docs"].get(tid, []),
248
+ client, mistral_key, gemini_key, res["topic_freq"].get(tid, 0),
249
+ res["representative_docs"].get(tid, [])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  )
251
+ interpretations[tid] = interp
252
+ logger.info(f"Interpreted {tid}: {interp.label}")
253
+
254
+ interp_list = [asdict(i) for i in interpretations.values()]
255
+ # Fix numpy serialisation before saving
256
+ clean_data = convert_numpy_types(interp_list)
257
+ with open(output_json, "w") as f:
258
+ json.dump(clean_data, f, indent=2)
259
+ df = pd.DataFrame(clean_data)
260
+ if not df.empty:
261
+ df["keywords"] = df["keywords"].apply(lambda x: ", ".join(x) if isinstance(x, list) else str(x))
262
+ df.to_csv(output_csv, index=False)
263
+
264
+ return {"interpretations": interpretations, "json_path": output_json, "csv_path": output_csv}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
 
266
+ if __name__ == "__main__": pass