anujjuna committed on
Commit
7753f36
·
verified ·
1 Parent(s): e5b7184

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +477 -816
agent.py CHANGED
@@ -1,816 +1,477 @@
1
- """
2
- agent.py
3
- --------
4
- LLM-driven topic interpretation and classification module.
5
-
6
- For each BERTopic-discovered topic this agent:
7
- 1. Generates a concise, human-readable label.
8
- 2. Assigns the topic to a taxonomy category.
9
- 3. Classifies the topic as MAPPED or NOVEL.
10
-
11
- It then cross-compares title-derived and abstract-derived topics and writes:
12
- - taxonomy_map.json – full classification for every topic
13
- - comparison.csv – side-by-side diff of title vs. abstract topics
14
- """
15
-
16
- from __future__ import annotations
17
-
18
- import json
19
- import logging
20
- import os
21
- import time
22
- from dataclasses import dataclass, asdict
23
- from typing import Optional
24
-
25
- import pandas as pd
26
- import requests
27
- from groq import Groq
28
-
29
- # ---------------------------------------------------------------------------
30
- # Logging
31
- # ---------------------------------------------------------------------------
32
- logging.basicConfig(level=logging.INFO, format="%(levelname)s | %(message)s")
33
- logger = logging.getLogger(__name__)
34
-
35
- # ---------------------------------------------------------------------------
36
- # Constants
37
- # ---------------------------------------------------------------------------
38
# Model name passed to the Groq client for labelling and comparison calls.
DEFAULT_MODEL = "llama-3.1-8b-instant"

# Model name sent to the Mistral endpoint for the validation pass.
MISTRAL_DEFAULT_MODEL = "mistral-small-latest"

# Taxonomy buckets listed in the interpretation prompt.
DEFAULT_TAXONOMY_CATEGORIES = [
    "Artificial Intelligence",
    "Machine Learning",
    "Natural Language Processing",
    "Computer Vision",
    "Information Systems",
    "Healthcare & Bioinformatics",
    "Finance & Economics",
    "Cybersecurity",
    "Human-Computer Interaction",
    "Robotics & Automation",
    "Education Technology",
    "Environmental Science",
    "Social Sciences",
    "Data Engineering",
    "Other",
]

# The only values accepted for a topic's classification.
CLASSIFICATION_OPTIONS = ("MAPPED", "NOVEL")
59
-
60
- # ---------------------------------------------------------------------------
61
- # Data Classes
62
- # ---------------------------------------------------------------------------
63
@dataclass
class TopicInterpretation:
    """Interpretation of one topic from one source.

    Instances are converted with ``dataclasses.asdict`` when the taxonomy
    map is built, so the field names become the JSON keys.
    """

    # --- identity ---
    source: str
    topic_id: int
    keywords: list[str]

    # --- interpretation ---
    label: str
    taxonomy_category: str
    classification: str  # MAPPED | NOVEL
    reasoning: str

    # --- dual-LLM validation ---
    validation_status: str = "PENDING"  # AGREED | DISAGREED | REVIEW_REQUIRED
    confidence: str = "MEDIUM"  # HIGH | MEDIUM | LOW
    label_source: str = "groq"  # groq | fallback
77
-
78
-
79
@dataclass
class ComparisonRow:
    """One row of the title-vs-abstract comparison table.

    Rows are converted with ``dataclasses.asdict`` when the comparison CSV
    is written, so the field order is the column order.
    """

    topic_id: int

    # Title-derived interpretation.
    title_label: str
    title_category: str
    title_classification: str

    # Abstract-derived interpretation.
    abstract_label: str
    abstract_category: str
    abstract_classification: str

    overlap_keywords: str  # comma-separated shared keywords
    difference_note: str  # LLM-generated note on differences
91
-
92
-
93
- # ---------------------------------------------------------------------------
94
- # OpenAI Client
95
- # ---------------------------------------------------------------------------
96
def build_openai_client(api_key: Optional[str] = None):
    """Create the chat client used for every LLM call.

    Despite the name, the object returned is a ``groq.Groq`` client.

    Parameters
    ----------
    api_key : str, optional
        Groq API key. When falsy, the ``GROQ_API_KEY`` environment variable
        is used instead.

    Raises
    ------
    ValueError
        If no key is passed and ``GROQ_API_KEY`` is unset or empty.
    """
    resolved_key = api_key if api_key else os.getenv("GROQ_API_KEY")
    if resolved_key:
        # max_retries=0 turns off the SDK's own retry loop.
        return Groq(api_key=resolved_key, max_retries=0)
    raise ValueError(
        "No Groq API key provided. "
        "Pass api_key= or set the GROQ_API_KEY environment variable."
    )
104
-
105
-
106
-
107
-
108
- # ---------------------------------------------------------------------------
109
- # Helpers
110
- # ---------------------------------------------------------------------------
111
def _ensure_string(x) -> str:
    """Coerce *x* to ``str``: lists are space-joined, ``None`` becomes ``""``."""
    if x is None:
        return ""
    if not isinstance(x, list):
        return str(x)
    return " ".join(map(str, x))
118
-
119
def _safe_capitalize(s: str) -> str:
    """Upper-case only the first character of the stripped text.

    Unlike ``str.capitalize``, the remaining characters keep their case.
    Non-string input is first converted with ``_ensure_string``.
    """
    text = _ensure_string(s).strip()
    # Slicing keeps the empty-string case safe without a separate branch.
    return text[:1].upper() + text[1:]
125
-
126
- # ---------------------------------------------------------------------------
127
- # Prompt Builders
128
- # ---------------------------------------------------------------------------
129
def _build_interpretation_prompt(
    keywords: list[str],
    sample_texts: list[str],
    taxonomy_categories: list[str],
) -> str:
    """Build the user prompt that asks the LLM to label and classify one topic.

    At most the first five sample texts are included. The prompt requests a
    JSON object with the keys ``label``, ``taxonomy_category``,
    ``classification`` and ``reasoning``.
    """
    keyword_line = ", ".join(keywords)
    sample_block = "\n".join([f" - {doc}" for doc in sample_texts[:5]])
    category_block = "\n".join([f" - {cat}" for cat in taxonomy_categories])

    return f"""You are an expert research analyst. A topic modelling algorithm has produced the following topic.

TOP KEYWORDS:
{keyword_line}

SAMPLE DOCUMENTS FOR THIS TOPIC:
{sample_block}

AVAILABLE TAXONOMY CATEGORIES:
{category_block}

Your task:
1. Write a concise label (≤8 words) that captures the essence of this topic.
2. Assign it to ONE category from the list above. Use "Other" only as a last resort.
3. Classify it as MAPPED (fits an existing, well-established research area) or NOVEL (represents an emerging or cross-disciplinary theme not well-represented in standard taxonomies).
4. Provide one sentence of reasoning.

Respond ONLY with valid JSON in exactly this schema – no markdown fences:
{{
"label": "<short label>",
"taxonomy_category": "<one of the listed categories>",
"classification": "MAPPED" | "NOVEL",
"reasoning": "<one sentence>"
}}"""
163
-
164
-
165
- # --- Dual LLM Validation ---
166
# (trigger keywords, label, taxonomy category). Rules are tried in order and
# the first rule sharing at least one keyword with the topic wins.
_FALLBACK_RULES = (
    (frozenset({"privacy", "data", "security", "protection"}),
     "Digital Privacy and Security Risks", "Cybersecurity"),
    (frozenset({"ai", "chatbots", "agents", "conversational", "interaction", "assistant"}),
     "Conversational AI and Human Interaction", "Artificial Intelligence"),
    (frozenset({"gaming", "players", "video", "games", "engagement"}),
     "Gaming and User Engagement Patterns", "Human-Computer Interaction"),
    (frozenset({"vr", "virtual", "immersive", "training", "reality"}),
     "Virtual Reality and Immersive Training", "Robotics & Automation"),
    (frozenset({"patient", "healthcare", "medical", "clinical", "hospital"}),
     "Healthcare Technology and Patient Care", "Healthcare & Bioinformatics"),
    (frozenset({"shopping", "commerce", "purchase", "ecommerce", "consumer"}),
     "E-commerce and Consumer Behavior", "Finance & Economics"),
    (frozenset({"internet", "addiction", "adolescents", "youth", "behavior"}),
     "Internet Addiction and Adolescent Behavior", "Social Sciences"),
    (frozenset({"gamification", "learning", "education", "student", "classroom"}),
     "Gamification in Learning and Interaction", "Education Technology"),
    (frozenset({"neural", "network", "deep", "learning", "cnn", "transformer"}),
     "Deep Learning Architectures", "Machine Learning"),
    (frozenset({"graph", "knowledge", "relational", "embedding"}),
     "Knowledge Graphs and Relational Data", "Data Engineering"),
)


def _fallback_label_from_keywords(keywords: list[str], topic_id: int) -> tuple[str, str]:
    """Derive a (label, category) pair from keywords without calling an LLM.

    Parameters
    ----------
    keywords : list[str]
        Topic keywords; matching against the rule table is case-insensitive.
    topic_id : int
        Used only to name the topic when ``keywords`` is empty.

    Returns
    -------
    tuple[str, str]
        The label and category of the first matching rule. When no rule
        matches, a generic label built from the first three keywords (or from
        ``topic_id`` when there are none) paired with ``"Other"``.
    """
    lowered = {str(k).lower() for k in keywords}

    for triggers, label, category in _FALLBACK_RULES:
        if not triggers.isdisjoint(lowered):
            return label, category

    if not keywords:
        # Avoids the dangling "Study on " label for a topic with no keywords.
        return f"Topic {topic_id}", "Other"

    return f"Study on {', '.join(str(k) for k in keywords[:3])}", "Other"
192
-
193
def _build_validation_prompt(keywords, groq_label, groq_category):
    """Build the prompt that asks the validator to review a proposed label.

    Only the first eight keywords are shown. The requested JSON includes
    ``SUGGESTED_CATEGORY`` because ``_decide_validation`` compares that field
    with the proposed category; without it a disagreement could never be
    escalated to ``REVIEW_REQUIRED``.
    """
    shown_keywords = ", ".join(str(k) for k in keywords[:8])
    return f"""
You are reviewing topic classification for research papers.

Keywords: {shown_keywords}
Proposed label: {groq_label}
Proposed category: {groq_category}

Instructions:
- If label and category reasonably match the keywords → say YES
- If there is a clear mismatch → say NO
- Small wording differences are OK
- Be balanced: do not be too strict or too lenient

Respond ONLY in JSON:
{{
"AGREEMENT": "YES" or "NO",
"CONFIDENCE": "HIGH", "MEDIUM", or "LOW",
"SUGGESTED_CATEGORY": "<category you would assign; repeat the proposed category if you agree>",
"REASON": "<short explanation>"
}}
"""
214
-
215
-
216
def _call_mistral_validation(
    mistral_api_key,
    keywords,
    groq_label,
    groq_category,
    model="mistral-small-latest",
):
    """Ask the Mistral chat API to review a proposed label and category.

    Validation is best-effort: any failure is logged and reported as an empty
    dict so the caller can carry on without a second opinion.

    Parameters
    ----------
    mistral_api_key : str or None
        Bearer token. When falsy, no request is made.
    keywords, groq_label, groq_category
        Passed through to ``_build_validation_prompt``.
    model : str
        Mistral model name.

    Returns
    -------
    dict
        The JSON object parsed from the reply, or ``{}`` when there is no key
        or the call or parse failed.
    """
    if not mistral_api_key:
        return {}

    prompt = _build_validation_prompt(keywords, groq_label, groq_category)

    try:
        response = requests.post(
            "https://api.mistral.ai/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {mistral_api_key}",
                "Content-Type": "application/json",
            },
            json={
                "model": model,
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.1,
            },
            timeout=20,
        )
        # Surface HTTP errors as such instead of a confusing KeyError on "choices".
        response.raise_for_status()

        raw = response.json()["choices"][0]["message"]["content"] or ""
        # Drop markdown fences the model may add around the JSON.
        raw = raw.replace("```json", "").replace("```", "").strip()
        start, end = raw.find("{"), raw.rfind("}") + 1
        if start == -1 or end <= start:
            raise ValueError("No JSON object found in Mistral response")
        return json.loads(raw[start:end])

    except Exception as exc:
        # Deliberately broad: the validator is optional and must never abort
        # the main pipeline.
        logger.warning("Mistral validation failed: %s", exc)
        return {}
253
-
254
-
255
def _build_comparison_prompt(
    topic_id: int,
    title_interp: TopicInterpretation,
    abstract_interp: TopicInterpretation,
) -> str:
    """Build the prompt asking for a one-sentence title-vs-abstract comparison."""

    def _describe(heading, interp):
        # One labelled section per source, fields in a fixed order.
        return "\n".join(
            [
                heading,
                f"Label : {interp.label}",
                f"Category : {interp.taxonomy_category}",
                f"Class : {interp.classification}",
                f"Keywords : {', '.join(interp.keywords)}",
            ]
        )

    sections = [
        f"You are comparing two topic representations for Topic ID {topic_id}.",
        _describe("TITLE-BASED TOPIC", title_interp),
        _describe("ABSTRACT-BASED TOPIC", abstract_interp),
        "In one concise sentence, describe the most meaningful difference "
        "(or similarity) between these two topic representations.\n"
        "Respond with ONLY the sentence – no JSON, no markdown.",
    ]
    # Sections are separated by a single blank line.
    return "\n\n".join(sections)
277
-
278
-
279
- # ---------------------------------------------------------------------------
280
- # LLM Calls
281
- # ---------------------------------------------------------------------------
282
def _call_llm_json(
    client,
    prompt: str,
    model: str,
    retries: int = 1,
    backoff: float = 1.0,
) -> dict:
    """
    Call the chat completion endpoint and parse the reply as a JSON object.

    Parameters
    ----------
    client
        Object exposing ``chat.completions.create`` (the Groq client).
    prompt : str
        Sent as a single user message.
    model : str
        Model name.
    retries : int
        Total number of attempts.
    backoff : float
        Base delay in seconds between attempts; doubled after every failed
        attempt (``backoff``, ``2 * backoff``, ...).

    Returns
    -------
    dict
        Parsed JSON object, or ``{}`` when every attempt failed.
    """
    for attempt in range(1, retries + 1):
        rate_limited = False
        try:
            response = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.2,
                timeout=8,
            )
            # A reply may carry no content; treat that as an unparseable answer.
            raw = (response.choices[0].message.content or "").strip()
            raw = raw.replace("```json", "").replace("```", "").strip()
            start = raw.find("{")
            end = raw.rfind("}") + 1
            if start == -1 or end <= start:
                raise ValueError("No JSON object found in response")
            return json.loads(raw[start:end])

        except ValueError as exc:  # includes json.JSONDecodeError
            logger.warning("Attempt %d – Parse error: %s", attempt, exc)
        except Exception as exc:
            logger.warning("Attempt %d API error: %s", attempt, exc)
            rate_limited = "rate limit" in str(exc).lower()

        if rate_limited:
            # Extra pause so the next request is less likely to be rejected.
            time.sleep(1)
        if attempt < retries:
            time.sleep(max(0.0, backoff) * 2 ** (attempt - 1))

    return {}
332
-
333
-
334
def _call_llm_text(
    client,
    prompt: str,
    model: str,
) -> str:
    """Send *prompt* as a single user message and return the stripped reply.

    Any exception is logged as a warning and reported as an empty string.
    """
    request = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.3,
    }
    try:
        completion = client.chat.completions.create(**request)
        reply = completion.choices[0].message.content
        return reply.strip()
    except Exception as exc:
        logger.warning("LLM text call failed: %s", exc)
    return ""
350
-
351
-
352
- # --- Dual LLM Validation ---
353
-
354
-
355
-
356
def _decide_validation(groq_category: str, mistral_result: dict) -> tuple[str, str]:
    """
    Turn the validator's verdict into a ``(validation_status, confidence)`` pair.

    Groq stays authoritative; Mistral only flags the result.

    Parameters
    ----------
    groq_category : str
        Category proposed by the primary model.
    mistral_result : dict
        Parsed validator reply. The keys read are ``AGREEMENT``,
        ``CONFIDENCE`` and ``SUGGESTED_CATEGORY``. An empty or non-dict value
        means no validation was available.

    Returns
    -------
    tuple[str, str]
        Status is ``AGREED``, ``DISAGREED`` or ``REVIEW_REQUIRED``; confidence
        is ``HIGH``, ``MEDIUM`` or ``LOW``.
    """
    if not mistral_result or not isinstance(mistral_result, dict):
        return "AGREED", "LOW"

    def _text(key: str, default: str) -> str:
        # Models sometimes return null, booleans or numbers; only real strings
        # are trusted, anything else falls back to the default.
        value = mistral_result.get(key, default)
        return value.strip() if isinstance(value, str) else default

    agreement = _text("AGREEMENT", "NO").upper()
    confidence = _text("CONFIDENCE", "MEDIUM").upper()
    if confidence not in ("HIGH", "MEDIUM", "LOW"):
        # Unrecognised confidence wording is treated as the weakest signal.
        confidence = "LOW"
    suggested = _text("SUGGESTED_CATEGORY", groq_category)

    if agreement == "YES":
        return "AGREED", confidence

    # Anything other than a clear "NO", or a low-confidence "NO", is treated
    # leniently.
    if agreement != "NO" or confidence == "LOW":
        return "AGREED", "LOW"

    # Compare only the part before "&" so "A & B" and "A & C" count as the
    # same root category.
    groq_root = groq_category.split("&")[0].strip().lower()
    suggested_root = suggested.split("&")[0].strip().lower()
    status = "REVIEW_REQUIRED" if groq_root != suggested_root else "DISAGREED"
    return status, confidence
395
-
396
-
397
- # ---------------------------------------------------------------------------
398
- # Core Interpretation
399
- # ---------------------------------------------------------------------------
400
def interpret_topic(
    client,
    source: str,
    topic_id: int,
    keywords: list[str],
    sample_texts: list[str],
    taxonomy_categories: list[str],
    model: str = DEFAULT_MODEL,
    mistral_api_key: Optional[str] = None,
    mistral_model: str = MISTRAL_DEFAULT_MODEL,
) -> TopicInterpretation:
    """Label, categorise and validate a single topic.

    The primary model proposes label, category and classification. If it
    returns nothing usable, a keyword heuristic supplies label and category.
    The validator is then consulted and its verdict is recorded; it never
    overrides the primary result.

    Parameters
    ----------
    client
        Chat client passed to ``_call_llm_json``.
    source : str
        Name of the text source, stored on the result and used in the log line.
    topic_id : int
    keywords : list[str]
    sample_texts : list[str]
    taxonomy_categories : list[str]
        Categories offered in the prompt. A returned category that matches one
        of these ignoring case is replaced by the listed spelling.
    model : str
    mistral_api_key : str, optional
        When falsy, validation is skipped.
    mistral_model : str

    Returns
    -------
    TopicInterpretation
    """
    # Step 1: primary model generates label / category / classification.
    prompt = _build_interpretation_prompt(keywords, sample_texts, taxonomy_categories)
    data = _call_llm_json(client, prompt, model, retries=2)

    label_source = "groq"
    if data:
        label = _ensure_string(data.get("label", "Unknown Topic"))
        category = _ensure_string(data.get("taxonomy_category", "Other"))
        classification = _ensure_string(data.get("classification", "MAPPED")).strip().upper()
        reasoning = _ensure_string(data.get("reasoning", ""))

        # A blank label is as useless as the "Unknown Topic" placeholder.
        if not label.strip() or label == "Unknown Topic":
            label, category = _fallback_label_from_keywords(keywords, topic_id)
            label_source = "fallback"
    else:
        # No usable reply: fall back to the keyword heuristic.
        label, category = _fallback_label_from_keywords(keywords, topic_id)
        classification = "MAPPED"
        reasoning = "Generated via keyword heuristics due to LLM timeout."
        label_source = "fallback"

    # Final normalisation and safe capitalisation.
    label = _safe_capitalize(label)
    category = _safe_capitalize(category) or "Other"

    # Snap case-only variants ("machine learning") onto the listed spelling so
    # the same category is not recorded under several names.
    listed = {str(c).lower(): str(c) for c in taxonomy_categories}
    category = listed.get(category.lower(), category)

    if classification not in CLASSIFICATION_OPTIONS:
        classification = "MAPPED"

    # Steps 2 and 3: the validator reviews; the primary result stays authoritative.
    mistral_result = _call_mistral_validation(
        mistral_api_key, keywords, label, category, mistral_model
    )
    validation_status, confidence = _decide_validation(category, mistral_result)

    logger.info(
        "[%s] Topic %d → '%s' (%s) | %s | val=%s conf=%s",
        source, topic_id, label, label_source, category, validation_status, confidence,
    )
    return TopicInterpretation(
        source=source,
        topic_id=topic_id,
        keywords=keywords,
        label=label,
        taxonomy_category=category,
        classification=classification,
        reasoning=reasoning,
        validation_status=validation_status,
        confidence=confidence,
        label_source=label_source,
    )
465
-
466
-
467
def interpret_all_topics(
    client,
    source: str,
    topic_keywords: dict[int, list[tuple[str, float]]],
    topic_docs: dict[int, list[str]],
    taxonomy_categories: list[str] = DEFAULT_TAXONOMY_CATEGORIES,
    model: str = DEFAULT_MODEL,
    mistral_api_key: Optional[str] = None,
    mistral_model: str = MISTRAL_DEFAULT_MODEL,
) -> dict[int, TopicInterpretation]:
    """Run ``interpret_topic`` over the topics of one source.

    Only the first 200 entries of ``topic_keywords`` (in dict order) are
    processed, and the loop sleeps two seconds after every topic.

    Returns
    -------
    dict[int, TopicInterpretation]
        Interpretations keyed by topic id.
    """
    MAX_TOPICS = 200  # Increased for fuller comparison
    capped_items = list(topic_keywords.items())[:MAX_TOPICS]

    results: dict[int, TopicInterpretation] = {}
    for topic_id, weighted_words in capped_items:
        # Keep the words and drop their scores.
        words = [word for word, _score in weighted_words]
        examples = topic_docs.get(topic_id, [])[:5]

        results[topic_id] = interpret_topic(
            client=client,
            source=source,
            topic_id=topic_id,
            keywords=words,
            sample_texts=examples,
            taxonomy_categories=taxonomy_categories,
            model=model,
            mistral_api_key=mistral_api_key,
            mistral_model=mistral_model,
        )
        time.sleep(2)  # API rate limiting

    return results
503
-
504
-
505
- # ---------------------------------------------------------------------------
506
- # Cross-Source Comparison
507
- # ---------------------------------------------------------------------------
508
def _get_overlap_keywords(a: TopicInterpretation, b: TopicInterpretation) -> list[str]:
    """Return the keywords present in both interpretations.

    The result follows the order of ``a.keywords`` and contains no
    duplicates, so repeated runs produce the same output.
    """
    in_b = set(b.keywords)
    # dict.fromkeys removes duplicates while keeping first-seen order.
    return list(dict.fromkeys(k for k in a.keywords if k in in_b))
511
-
512
-
513
def compare_topics(
    client,
    title_interpretations: dict[int, TopicInterpretation],
    abstract_interpretations: dict[int, TopicInterpretation],
    model: str = DEFAULT_MODEL,
) -> list[ComparisonRow]:
    """
    Build one comparison row per topic id present in both sources.

    Parameters
    ----------
    client
        Chat client passed to ``_call_llm_text`` for the difference note.
    title_interpretations : dict[int, TopicInterpretation]
    abstract_interpretations : dict[int, TopicInterpretation]
    model : str

    Returns
    -------
    list[ComparisonRow]
        Rows in ascending topic-id order.
    """
    common_ids = sorted(set(title_interpretations).intersection(abstract_interpretations))
    table: list[ComparisonRow] = []

    for tid in common_ids:
        from_title = title_interpretations[tid]
        from_abstract = abstract_interpretations[tid]
        shared = _get_overlap_keywords(from_title, from_abstract)

        comparison_prompt = _build_comparison_prompt(tid, from_title, from_abstract)
        note = _call_llm_text(client, comparison_prompt, model)
        # Replies that are missing or too short to mean anything get a stock note.
        usable = bool(note) and len(note.strip()) >= 5
        if not usable:
            note = "Minor or no significant difference"

        row = ComparisonRow(
            topic_id=tid,
            title_label=from_title.label,
            title_category=from_title.taxonomy_category,
            title_classification=from_title.classification,
            abstract_label=from_abstract.label,
            abstract_category=from_abstract.taxonomy_category,
            abstract_classification=from_abstract.classification,
            overlap_keywords=", ".join(shared) if shared else "none",
            difference_note=note,
        )
        table.append(row)
        logger.info("Compared topic %d across sources.", tid)

    return table
567
-
568
-
569
- # ---------------------------------------------------------------------------
570
- # Output Writers
571
- # ---------------------------------------------------------------------------
572
def build_taxonomy_map(
    title_interpretations: dict[int, TopicInterpretation],
    abstract_interpretations: dict[int, TopicInterpretation],
) -> dict:
    """
    Combine title and abstract interpretations into one serialisable dict.

    Returns
    -------
    dict
        ``{"titles": [...], "abstracts": [...]}``, where each list holds the
        ``asdict`` form of the interpretations in dict order.
    """
    sources = (
        ("titles", title_interpretations),
        ("abstracts", abstract_interpretations),
    )
    return {
        name: [asdict(interp) for interp in interps.values()]
        for name, interps in sources
    }
591
-
592
-
593
def save_taxonomy_map(taxonomy_map: dict, output_path: str = "taxonomy_map.json") -> None:
    """
    Write *taxonomy_map* to *output_path* as UTF-8 JSON.

    The file is indented by two spaces and non-ASCII characters are written
    as-is rather than escaped.

    Parameters
    ----------
    taxonomy_map : dict
    output_path : str
    """
    with open(output_path, "w", encoding="utf-8") as handle:
        handle.write(json.dumps(taxonomy_map, indent=2, ensure_ascii=False))
    logger.info("Taxonomy map saved → %s", output_path)
605
-
606
-
607
def save_comparison_csv(
    comparison_rows: list[ComparisonRow],
    output_path: str = "comparison.csv",
) -> None:
    """
    Write the comparison rows to a CSV file.

    Nothing is written when *comparison_rows* is empty; a warning is logged
    instead.

    Parameters
    ----------
    comparison_rows : list[ComparisonRow]
    output_path : str
    """
    if comparison_rows:
        records = list(map(asdict, comparison_rows))
        pd.DataFrame(records).to_csv(output_path, index=False)
        logger.info("Comparison CSV saved → %s", output_path)
        return

    logger.warning("No comparison rows to save.")
626
-
627
-
628
- # ---------------------------------------------------------------------------
629
- # Helper: Build topic_docs mapping from BERTopic output
630
- # ---------------------------------------------------------------------------
631
def build_topic_docs_map(
    raw_texts: list[str],
    topic_assignments: list[int],
) -> dict[int, list[str]]:
    """
    Group documents by the topic id assigned to them.

    Parameters
    ----------
    raw_texts : list[str]
        Documents, in the same order as ``topic_assignments``.
    topic_assignments : list[int]
        Topic id for each document. Documents assigned ``-1`` are left out.

    Returns
    -------
    dict[int, list[str]]
        Topic id → documents for that topic, in input order. The two inputs
        are paired with ``zip``, so extra items in the longer one are ignored.
    """
    grouped: dict[int, list[str]] = {}
    for document, assigned in zip(raw_texts, topic_assignments):
        if assigned == -1:
            continue
        bucket = grouped.get(assigned)
        if bucket is None:
            grouped[assigned] = [document]
        else:
            bucket.append(document)
    return grouped
656
-
657
-
658
- # ---------------------------------------------------------------------------
659
- # High-Level Pipeline
660
- # ---------------------------------------------------------------------------
661
def run_agent(
    title_topic_keywords: dict[int, list[tuple[str, float]]],
    abstract_topic_keywords: dict[int, list[tuple[str, float]]],
    title_topic_assignments: list[int],
    abstract_topic_assignments: list[int],
    raw_titles: list[str],
    raw_abstracts: list[str],
    api_key: Optional[str] = None,
    model: str = DEFAULT_MODEL,
    taxonomy_categories: list[str] = DEFAULT_TAXONOMY_CATEGORIES,
    taxonomy_map_path: str = "taxonomy_map.json",
    comparison_csv_path: str = "comparison.csv",
    mistral_api_key: Optional[str] = None,
    mistral_model: str = MISTRAL_DEFAULT_MODEL,
) -> dict:
    """
    End-to-end agent pipeline:
      1. Interpret title topics via LLM
      2. Interpret abstract topics via LLM
      3. Compare cross-source topics
      4. Write the taxonomy map JSON and the comparison CSV

    Parameters
    ----------
    title_topic_keywords : dict
        Topic id → list of (keyword, score) pairs for titles.
    abstract_topic_keywords : dict
        Topic id → list of (keyword, score) pairs for abstracts.
    title_topic_assignments : list[int]
        Topic id per title, parallel to ``raw_titles``.
    abstract_topic_assignments : list[int]
        Topic id per abstract, parallel to ``raw_abstracts``.
    raw_titles : list[str]
        Original (unprocessed) title strings.
    raw_abstracts : list[str]
        Original (unprocessed) abstract strings.
    api_key : str, optional
        Groq API key (falls back to the GROQ_API_KEY env var).
    model : str
        Groq model to use (default ``DEFAULT_MODEL``).
    taxonomy_categories : list[str]
        Taxonomy buckets the LLM may assign topics to.
    taxonomy_map_path : str
        Output path for the taxonomy map JSON.
    comparison_csv_path : str
        Output path for the comparison CSV.
    mistral_api_key : str, optional
        Mistral API key for the validation pass (falls back to the
        MISTRAL_API_KEY env var). Without a key, validation is skipped.
    mistral_model : str
        Mistral model used for validation.

    Returns
    -------
    dict with keys
        title_interpretations    – dict[int, TopicInterpretation]
        abstract_interpretations – dict[int, TopicInterpretation]
        comparison_rows          – list[ComparisonRow]
        taxonomy_map             – dict (JSON-serialisable)

    Raises
    ------
    ValueError
        If no Groq API key is available.
    """
    client = build_openai_client(api_key)
    mistral_api_key = mistral_api_key or os.getenv("MISTRAL_API_KEY")

    # --- Build raw-text lookup maps ---
    title_docs_map = build_topic_docs_map(raw_titles, title_topic_assignments)
    abstract_docs_map = build_topic_docs_map(raw_abstracts, abstract_topic_assignments)

    # --- Interpret topics ---
    logger.info("Interpreting TITLE topics …")
    title_interps = interpret_all_topics(
        client=client,
        source="titles",
        topic_keywords=title_topic_keywords,
        topic_docs=title_docs_map,
        taxonomy_categories=taxonomy_categories,
        model=model,
        mistral_api_key=mistral_api_key,
        mistral_model=mistral_model,
    )

    logger.info("Interpreting ABSTRACT topics …")
    abstract_interps = interpret_all_topics(
        client=client,
        source="abstracts",
        topic_keywords=abstract_topic_keywords,
        topic_docs=abstract_docs_map,
        taxonomy_categories=taxonomy_categories,
        model=model,
        mistral_api_key=mistral_api_key,
        mistral_model=mistral_model,
    )

    # --- Compare ---
    logger.info("Comparing title vs. abstract topics …")
    comparison_rows = compare_topics(client, title_interps, abstract_interps, model)

    # --- Persist ---
    taxonomy_map = build_taxonomy_map(title_interps, abstract_interps)
    save_taxonomy_map(taxonomy_map, taxonomy_map_path)
    save_comparison_csv(comparison_rows, comparison_csv_path)

    return {
        "title_interpretations": title_interps,
        "abstract_interpretations": abstract_interps,
        "comparison_rows": comparison_rows,
        "taxonomy_map": taxonomy_map,
    }
763
-
764
-
765
- # ---------------------------------------------------------------------------
766
- # CLI Entry Point
767
- # ---------------------------------------------------------------------------
768
if __name__ == "__main__":
    # Demo / smoke test: runs the agent on synthetic topic data.
    # Set GROQ_API_KEY in your environment before running; MISTRAL_API_KEY is
    # optional and enables the validation pass.
    DEMO_TITLE_KEYWORDS: dict[int, list[tuple[str, float]]] = {
        0: [("neural", 0.9), ("network", 0.85), ("deep", 0.8), ("learning", 0.75), ("training", 0.7)],
        1: [("blockchain", 0.88), ("transaction", 0.82), ("ledger", 0.78), ("consensus", 0.74), ("crypto", 0.7)],
    }
    DEMO_ABSTRACT_KEYWORDS: dict[int, list[tuple[str, float]]] = {
        0: [("deep", 0.91), ("model", 0.87), ("classification", 0.82), ("accuracy", 0.78), ("dataset", 0.74)],
        1: [("distributed", 0.86), ("blockchain", 0.81), ("smart", 0.77), ("contract", 0.73), ("peer", 0.68)],
    }

    sample_titles = [
        "Deep Learning for Image Classification",
        "Neural Networks in Healthcare",
        "Blockchain and Distributed Ledger Technology",
        "Smart Contracts in Finance",
    ]
    sample_abstracts = [
        "We propose a deep learning model achieving state-of-the-art accuracy on benchmark datasets.",
        "A convolutional network trained for medical image classification.",
        "This paper surveys blockchain consensus mechanisms and distributed ledger architectures.",
        "We implement smart contracts for automated financial transactions on a public blockchain.",
    ]

    # One topic id per document, parallel to the sample lists above.
    title_assignments = [0, 0, 1, 1]
    abstract_assignments = [0, 0, 1, 1]

    results = run_agent(
        title_topic_keywords=DEMO_TITLE_KEYWORDS,
        abstract_topic_keywords=DEMO_ABSTRACT_KEYWORDS,
        title_topic_assignments=title_assignments,
        abstract_topic_assignments=abstract_assignments,
        raw_titles=sample_titles,
        raw_abstracts=sample_abstracts,
        taxonomy_map_path="taxonomy_map.json",
        comparison_csv_path="comparison.csv",
    )

    print("\n=== Taxonomy Map (titles) ===")
    for interp in results["taxonomy_map"]["titles"]:
        print(f" [{interp['topic_id']}] {interp['label']} | {interp['taxonomy_category']} | {interp['classification']}")

    print("\n=== Comparison Rows ===")
    for row in results["comparison_rows"]:
        print(f" Topic {row.topic_id}: '{row.title_label}' vs '{row.abstract_label}'")
        print(f" Note: {row.difference_note}")
 
1
+ """
2
+ agent.py
3
+ --------
4
+ LLM-driven topic interpretation and classification module.
5
+ Heavy imports are lazy-loaded inside functions to stay within 2GB RAM on free HF Spaces.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import json
11
+ import logging
12
+ import os
13
+ import time
14
+ from dataclasses import dataclass, asdict
15
+ from typing import Optional
16
+
17
+ import pandas as pd
18
+ import requests
19
+
20
+ # ---------------------------------------------------------------------------
21
+ # Logging
22
+ # ---------------------------------------------------------------------------
23
+ logging.basicConfig(level=logging.INFO, format="%(levelname)s | %(message)s")
24
+ logger = logging.getLogger(__name__)
25
+
26
+ # ---------------------------------------------------------------------------
27
+ # Constants
28
+ # ---------------------------------------------------------------------------
29
# Default model name for the primary (Groq) LLM.
DEFAULT_MODEL = "llama-3.1-8b-instant"

# Default model name for the Mistral LLM.
MISTRAL_DEFAULT_MODEL = "mistral-small-latest"

# Default taxonomy buckets.
DEFAULT_TAXONOMY_CATEGORIES = [
    "Artificial Intelligence",
    "Machine Learning",
    "Natural Language Processing",
    "Computer Vision",
    "Information Systems",
    "Healthcare & Bioinformatics",
    "Finance & Economics",
    "Cybersecurity",
    "Human-Computer Interaction",
    "Robotics & Automation",
    "Education Technology",
    "Environmental Science",
    "Social Sciences",
    "Data Engineering",
    "Other",
]

# Allowed classification values.
CLASSIFICATION_OPTIONS = ("MAPPED", "NOVEL")
49
+
50
+
51
+ # ---------------------------------------------------------------------------
52
+ # Data Classes
53
+ # ---------------------------------------------------------------------------
54
@dataclass
class TopicInterpretation:
    """Interpretation record for a single topic."""

    # --- identity ---
    source: str
    topic_id: int
    keywords: list[str]

    # --- interpretation ---
    label: str
    taxonomy_category: str
    classification: str
    reasoning: str

    # --- validation metadata (defaults apply until set) ---
    validation_status: str = "PENDING"
    confidence: str = "MEDIUM"
    label_source: str = "groq"
66
+
67
+
68
@dataclass
class ComparisonRow:
    """One row pairing the title-side and abstract-side view of a topic."""

    topic_id: int

    # Title side.
    title_label: str
    title_category: str
    title_classification: str

    # Abstract side.
    abstract_label: str
    abstract_category: str
    abstract_classification: str

    overlap_keywords: str
    difference_note: str
79
+
80
+
81
+ # ---------------------------------------------------------------------------
82
+ # Groq Client — lazy import
83
+ # ---------------------------------------------------------------------------
84
def build_groq_client(api_key: Optional[str] = None):
    """Create a Groq API client.

    Args:
        api_key: Explicit API key. When falsy, the ``GROQ_API_KEY``
            environment variable is used instead.

    Returns:
        A ``groq.Groq`` client constructed with ``max_retries=0``.

    Raises:
        ValueError: If no key is passed and ``GROQ_API_KEY`` is unset/empty.
    """
    key = api_key or os.getenv("GROQ_API_KEY")
    if not key:
        # Validate before importing so a configuration error is reported
        # immediately, without loading the SDK.
        raise ValueError(
            "No Groq API key provided. "
            "Pass api_key= or set the GROQ_API_KEY environment variable."
        )

    # Imported lazily to keep module import light.
    from groq import Groq

    return Groq(api_key=key, max_retries=0)
93
+
94
+
95
+ # ---------------------------------------------------------------------------
96
+ # Helpers
97
+ # ---------------------------------------------------------------------------
98
+ def _ensure_string(x) -> str:
99
+ if isinstance(x, list):
100
+ return " ".join(str(i) for i in x)
101
+ if x is None:
102
+ return ""
103
+ return str(x)
104
+
105
+
106
def _safe_capitalize(s: str) -> str:
    """Return *s* stripped, with only its first character upper-cased.

    Unlike ``str.capitalize`` the remaining characters are left untouched.
    Non-string input is first normalised by ``_ensure_string``; empty input
    yields ``""``.
    """
    text = _ensure_string(s).strip()
    # Slicing keeps this safe for the empty string: "" + "" == "".
    return text[:1].upper() + text[1:]
111
+
112
+
113
+ # ---------------------------------------------------------------------------
114
+ # Prompt Builders
115
+ # ---------------------------------------------------------------------------
116
+ def _build_interpretation_prompt(keywords, sample_texts, taxonomy_categories) -> str:
117
+ kw_str = ", ".join(keywords)
118
+ samples_str = "\n".join(f" - {t}" for t in sample_texts[:5])
119
+ cats_str = "\n".join(f" - {c}" for c in taxonomy_categories)
120
+ return f"""You are an expert research analyst. A topic modelling algorithm has produced the following topic.
121
+
122
+ TOP KEYWORDS:
123
+ {kw_str}
124
+
125
+ SAMPLE DOCUMENTS FOR THIS TOPIC:
126
+ {samples_str}
127
+
128
+ AVAILABLE TAXONOMY CATEGORIES:
129
+ {cats_str}
130
+
131
+ Your task:
132
+ 1. Write a concise label (≤8 words) that captures the essence of this topic.
133
+ 2. Assign it to ONE category from the list above. Use "Other" only as a last resort.
134
+ 3. Classify it as MAPPED (fits an existing, well-established research area) or NOVEL (represents an emerging or cross-disciplinary theme not well-represented in standard taxonomies).
135
+ 4. Provide one sentence of reasoning.
136
+
137
+ Respond ONLY with valid JSON in exactly this schema – no markdown fences:
138
+ {{
139
+ "label": "<short label>",
140
+ "taxonomy_category": "<one of the listed categories>",
141
+ "classification": "MAPPED" | "NOVEL",
142
+ "reasoning": "<one sentence>"
143
+ }}"""
144
+
145
+
146
+ def _fallback_label_from_keywords(keywords, topic_id):
147
+ kw_set = set(k.lower() for k in keywords)
148
+ mappings = [
149
+ ({"privacy", "data", "security", "protection"}, "Digital Privacy and Security Risks", "Cybersecurity"),
150
+ ({"ai", "chatbots", "agents", "conversational"}, "Conversational AI and Human Interaction", "Artificial Intelligence"),
151
+ ({"gaming", "players", "video", "games"}, "Gaming and User Engagement Patterns", "Human-Computer Interaction"),
152
+ ({"vr", "virtual", "immersive", "training"}, "Virtual Reality and Immersive Training", "Robotics & Automation"),
153
+ ({"patient", "healthcare", "medical", "clinical"}, "Healthcare Technology and Patient Care", "Healthcare & Bioinformatics"),
154
+ ({"shopping", "commerce", "purchase", "ecommerce"}, "E-commerce and Consumer Behavior", "Finance & Economics"),
155
+ ({"internet", "addiction", "adolescents", "youth"}, "Internet Addiction and Adolescent Behavior","Social Sciences"),
156
+ ({"gamification", "learning", "education", "student"}, "Gamification in Learning and Interaction", "Education Technology"),
157
+ ({"neural", "network", "deep", "learning", "transformer"},"Deep Learning Architectures", "Machine Learning"),
158
+ ({"graph", "knowledge", "relational", "embedding"}, "Knowledge Graphs and Relational Data", "Data Engineering"),
159
+ ]
160
+ for trigger_kws, fallback_label, fallback_cat in mappings:
161
+ if any(tk in kw_set for tk in trigger_kws):
162
+ return fallback_label, fallback_cat
163
+ label = f"Study on {', '.join(keywords[:3])}"
164
+ return label, "Other"
165
+
166
+
167
+ def _build_validation_prompt(keywords, groq_label, groq_category) -> str:
168
+ return f"""
169
+ You are reviewing topic classification for research papers.
170
+
171
+ Keywords: {', '.join(keywords[:8])}
172
+ Proposed label: {groq_label}
173
+ Proposed category: {groq_category}
174
+
175
+ Instructions:
176
+ - If label and category reasonably match the keywords say YES
177
+ - If there is a clear mismatch say NO
178
+ - Small wording differences are OK
179
+
180
+ Respond ONLY in JSON:
181
+ {{
182
+ "AGREEMENT": "YES" or "NO",
183
+ "CONFIDENCE": "HIGH", "MEDIUM", or "LOW",
184
+ "REASON": "<short explanation>"
185
+ }}
186
+ """
187
+
188
+
189
+ def _build_comparison_prompt(topic_id, title_interp, abstract_interp) -> str:
190
+ return f"""You are comparing two topic representations for Topic ID {topic_id}.
191
+
192
+ TITLE-BASED TOPIC
193
+ Label : {title_interp.label}
194
+ Category : {title_interp.taxonomy_category}
195
+ Class : {title_interp.classification}
196
+ Keywords : {', '.join(title_interp.keywords)}
197
+
198
+ ABSTRACT-BASED TOPIC
199
+ Label : {abstract_interp.label}
200
+ Category : {abstract_interp.taxonomy_category}
201
+ Class : {abstract_interp.classification}
202
+ Keywords : {', '.join(abstract_interp.keywords)}
203
+
204
+ In one concise sentence, describe the most meaningful difference (or similarity) between these two topic representations.
205
+ Respond with ONLY the sentence no JSON, no markdown."""
206
+
207
+
208
+ # ---------------------------------------------------------------------------
209
+ # Mistral Validation — uses requests (already imported at top)
210
+ # ---------------------------------------------------------------------------
211
+ def _call_mistral_validation(mistral_api_key, keywords, groq_label, groq_category, model="mistral-small-latest"):
212
+ if not mistral_api_key:
213
+ return {}
214
+ prompt = _build_validation_prompt(keywords, groq_label, groq_category)
215
+ try:
216
+ response = requests.post(
217
+ "https://api.mistral.ai/v1/chat/completions",
218
+ headers={
219
+ "Authorization": f"Bearer {mistral_api_key}",
220
+ "Content-Type": "application/json",
221
+ },
222
+ json={
223
+ "model": model,
224
+ "messages": [{"role": "user", "content": prompt}],
225
+ "temperature": 0.1,
226
+ },
227
+ timeout=20,
228
+ )
229
+ data = response.json()
230
+ raw = data["choices"][0]["message"]["content"].strip()
231
+ raw = raw.replace("```json", "").replace("```", "").strip()
232
+ start, end = raw.find("{"), raw.rfind("}") + 1
233
+ return json.loads(raw[start:end])
234
+ except Exception as e:
235
+ logger.warning("Mistral validation failed: %s", e)
236
+ return {}
237
+
238
+
239
+ def _decide_validation(groq_category, mistral_result):
240
+ if not mistral_result:
241
+ return "AGREED", "LOW"
242
+ agreement = mistral_result.get("AGREEMENT", "NO").upper()
243
+ confidence = mistral_result.get("CONFIDENCE", "MEDIUM").upper()
244
+ suggested = mistral_result.get("SUGGESTED_CATEGORY", groq_category).strip()
245
+ groq_root = groq_category.split("&")[0].strip().lower()
246
+ suggested_root = suggested.split("&")[0].strip().lower()
247
+
248
+ if agreement == "YES":
249
+ return "AGREED", confidence
250
+ if agreement == "NO":
251
+ if confidence == "HIGH":
252
+ return ("REVIEW_REQUIRED" if groq_root != suggested_root else "DISAGREED"), "HIGH"
253
+ if confidence == "MEDIUM":
254
+ return ("REVIEW_REQUIRED" if groq_root != suggested_root else "DISAGREED"), "MEDIUM"
255
+ return "AGREED", "LOW"
256
+ return "AGREED", "LOW"
257
+
258
+
259
+ # ---------------------------------------------------------------------------
260
+ # LLM Calls
261
+ # ---------------------------------------------------------------------------
262
+ def _call_llm_json(client, prompt, model, retries=1, backoff=1.0) -> dict:
263
+ for attempt in range(1, retries + 1):
264
+ try:
265
+ response = client.chat.completions.create(
266
+ model=model,
267
+ messages=[{"role": "user", "content": prompt}],
268
+ temperature=0.2,
269
+ timeout=8,
270
+ )
271
+ raw = response.choices[0].message.content.strip()
272
+ raw = raw.replace("```json", "").replace("```", "").strip()
273
+ start, end = raw.find("{"), raw.rfind("}") + 1
274
+ if start == -1 or end == 0:
275
+ raise ValueError("No JSON object found in response")
276
+ return json.loads(raw[start:end])
277
+ except (json.JSONDecodeError, ValueError) as exc:
278
+ logger.warning("Attempt %d – Parse error: %s", attempt, exc)
279
+ except Exception as exc:
280
+ logger.warning("Attempt %d – API error: %s", attempt, exc)
281
+ if "rate limit" in str(exc).lower():
282
+ time.sleep(1)
283
+ if attempt < retries:
284
+ time.sleep(0.5)
285
+ return {}
286
+
287
+
288
+ def _call_llm_text(client, prompt, model) -> str:
289
+ try:
290
+ response = client.chat.completions.create(
291
+ model=model,
292
+ messages=[{"role": "user", "content": prompt}],
293
+ temperature=0.3,
294
+ )
295
+ return response.choices[0].message.content.strip()
296
+ except Exception as exc:
297
+ logger.warning("LLM text call failed: %s", exc)
298
+ return ""
299
+
300
+
301
+ # ---------------------------------------------------------------------------
302
+ # Core Interpretation
303
+ # ---------------------------------------------------------------------------
304
def interpret_topic(
    client,
    source, topic_id, keywords, sample_texts,
    taxonomy_categories, model=DEFAULT_MODEL,
    mistral_api_key=None, mistral_model=MISTRAL_DEFAULT_MODEL,
) -> TopicInterpretation:
    """Label, categorise and validate a single topic.

    The primary model is asked for a label, category, classification and
    reasoning. If it returns nothing usable, or an empty/placeholder label,
    keyword heuristics supply the label and category instead. The result is
    then passed through the optional Mistral validation.

    Args:
        client: Chat client for the primary model.
        source: Name of the corpus the topic came from.
        topic_id: Topic identifier.
        keywords: Topic keywords.
        sample_texts: Example documents for the topic.
        taxonomy_categories: Category names offered to the model.
        model: Primary model name.
        mistral_api_key: Key for the validation model; validation is skipped
            when falsy.
        mistral_model: Validation model name.

    Returns:
        The populated ``TopicInterpretation``.
    """
    prompt = _build_interpretation_prompt(keywords, sample_texts, taxonomy_categories)
    data = _call_llm_json(client, prompt, model, retries=2)

    label_source = "groq"
    if not data:
        label, category = _fallback_label_from_keywords(keywords, topic_id)
        classification = "MAPPED"
        reasoning = "Generated via keyword heuristics due to LLM timeout."
        label_source = "fallback"
    else:
        label = _ensure_string(data.get("label", "Unknown Topic")).strip()
        # A blank category would otherwise reach the output as "".
        category = _ensure_string(data.get("taxonomy_category", "Other")).strip() or "Other"
        # Strip so values like " NOVEL" are not silently reset to MAPPED below.
        classification = _ensure_string(data.get("classification", "MAPPED")).strip().upper()
        reasoning = _ensure_string(data.get("reasoning", ""))
        # An empty label is as unusable as the placeholder.
        if not label or label == "Unknown Topic":
            label, category = _fallback_label_from_keywords(keywords, topic_id)
            label_source = "fallback"

    label = _safe_capitalize(label)
    category = _safe_capitalize(category)
    if classification not in CLASSIFICATION_OPTIONS:
        classification = "MAPPED"

    mistral_result = _call_mistral_validation(mistral_api_key, keywords, label, category, mistral_model)
    validation_status, confidence = _decide_validation(category, mistral_result)

    logger.info("[%s] Topic %d → '%s' (%s) | val=%s conf=%s", source, topic_id, label, label_source, validation_status, confidence)
    return TopicInterpretation(
        source=source, topic_id=topic_id, keywords=keywords,
        label=label, taxonomy_category=category,
        classification=classification, reasoning=reasoning,
        validation_status=validation_status, confidence=confidence,
        label_source=label_source,
    )
344
+
345
+
346
def interpret_all_topics(
    client, source, topic_keywords, topic_docs,
    taxonomy_categories=DEFAULT_TAXONOMY_CATEGORIES,
    model=DEFAULT_MODEL, mistral_api_key=None,
    mistral_model=MISTRAL_DEFAULT_MODEL,
) -> dict[int, TopicInterpretation]:
    """Interpret every topic of one source, one LLM round-trip at a time.

    Args:
        client: Chat client for the primary model.
        source: Name of the corpus being processed.
        topic_keywords: Mapping of topic id to ``(word, score)`` pairs.
        topic_docs: Mapping of topic id to the documents assigned to it.
        taxonomy_categories: Category names offered to the model.
        model: Primary model name.
        mistral_api_key: Key for the validation model (optional).
        mistral_model: Validation model name.

    Returns:
        Mapping of topic id to its ``TopicInterpretation``. Only the first
        200 topics (in mapping order) are processed.
    """
    results = {}
    for position, (topic_id, kw_pairs) in enumerate(topic_keywords.items()):
        if position >= 200:
            # Hard cap on the number of topics sent to the LLM.
            break

        words = []
        for word, _score in kw_pairs:
            words.append(word)
        examples = topic_docs.get(topic_id, [])[:5]

        results[topic_id] = interpret_topic(
            client=client,
            source=source,
            topic_id=topic_id,
            keywords=words,
            sample_texts=examples,
            taxonomy_categories=taxonomy_categories,
            model=model,
            mistral_api_key=mistral_api_key,
            mistral_model=mistral_model,
        )
        # Fixed pause after every topic to space out API calls.
        time.sleep(2)
    return results
365
+
366
+
367
+ # ---------------------------------------------------------------------------
368
+ # Cross-Source Comparison
369
+ # ---------------------------------------------------------------------------
370
+ def _get_overlap_keywords(a: TopicInterpretation, b: TopicInterpretation) -> list[str]:
371
+ return list(set(a.keywords) & set(b.keywords))
372
+
373
+
374
def compare_topics(client, title_interpretations, abstract_interpretations, model=DEFAULT_MODEL) -> list[ComparisonRow]:
    """Compare title- and abstract-derived topics that share a topic id.

    Args:
        client: Chat client used to write the difference note.
        title_interpretations: Mapping of topic id to title interpretation.
        abstract_interpretations: Mapping of topic id to abstract
            interpretation.
        model: Model name for the difference note.

    Returns:
        One ``ComparisonRow`` per shared topic id, in ascending id order.
    """
    common_ids = set(title_interpretations).intersection(abstract_interpretations)
    table = []
    for tid in sorted(common_ids):
        from_title = title_interpretations[tid]
        from_abstract = abstract_interpretations[tid]

        shared = _get_overlap_keywords(from_title, from_abstract)
        if shared:
            shared_text = ", ".join(shared)
        else:
            shared_text = "none"

        prompt = _build_comparison_prompt(tid, from_title, from_abstract)
        note = _call_llm_text(client, prompt, model)
        # Fall back to a stock note when the model returned (almost) nothing.
        if not note or len(note.strip()) < 5:
            note = "Minor or no significant difference"

        table.append(
            ComparisonRow(
                topic_id=tid,
                title_label=from_title.label,
                title_category=from_title.taxonomy_category,
                title_classification=from_title.classification,
                abstract_label=from_abstract.label,
                abstract_category=from_abstract.taxonomy_category,
                abstract_classification=from_abstract.classification,
                overlap_keywords=shared_text,
                difference_note=note,
            )
        )
        logger.info("Compared topic %d across sources.", tid)
    return table
395
+
396
+
397
+ # ---------------------------------------------------------------------------
398
+ # Output Builders
399
+ # ---------------------------------------------------------------------------
400
def build_taxonomy_map(title_interpretations, abstract_interpretations) -> dict:
    """Convert both interpretation mappings into a JSON-serialisable dict.

    Returns:
        ``{"titles": [...], "abstracts": [...]}`` where each list holds the
        ``dataclasses.asdict`` form of the mapping's values, in mapping order.
    """
    sections = (
        ("titles", title_interpretations),
        ("abstracts", abstract_interpretations),
    )
    taxonomy = {}
    for section_name, interps in sections:
        taxonomy[section_name] = [asdict(item) for item in interps.values()]
    return taxonomy
404
+
405
+
406
def save_taxonomy_map(taxonomy_map, output_path="taxonomy_map.json"):
    """Write *taxonomy_map* to *output_path* as indented UTF-8 JSON.

    Non-ASCII characters are written verbatim rather than escaped.
    """
    with open(output_path, "w", encoding="utf-8") as out_file:
        json.dump(taxonomy_map, out_file, indent=2, ensure_ascii=False)
    logger.info("Taxonomy map saved → %s", output_path)
410
+
411
+
412
def save_comparison_csv(comparison_rows, output_path="comparison.csv"):
    """Write the comparison rows to *output_path* as CSV without an index.

    Nothing is written when *comparison_rows* is empty; a warning is logged
    instead.
    """
    if not comparison_rows:
        logger.warning("No comparison rows to save.")
        return
    records = [asdict(row) for row in comparison_rows]
    frame = pd.DataFrame(records)
    frame.to_csv(output_path, index=False)
    logger.info("Comparison CSV saved %s", output_path)
418
+
419
+
420
def build_topic_docs_map(raw_texts, topic_assignments) -> dict[int, list[str]]:
    """Group documents by the topic id assigned to them.

    Args:
        raw_texts: Documents, aligned position-by-position with
            *topic_assignments*.
        topic_assignments: Topic id for each document; ``-1`` entries are
            skipped.

    Returns:
        Mapping of topic id to its documents, in input order.
    """
    grouped: dict[int, list[str]] = {}
    for text, topic in zip(raw_texts, topic_assignments):
        if topic != -1:
            if topic not in grouped:
                grouped[topic] = []
            grouped[topic].append(text)
    return grouped
427
+
428
+
429
+ # ---------------------------------------------------------------------------
430
+ # High-Level Pipeline — Groq imported lazily here
431
+ # ---------------------------------------------------------------------------
432
def run_agent(
    title_topic_keywords, abstract_topic_keywords,
    title_topic_assignments, abstract_topic_assignments,
    raw_titles, raw_abstracts,
    api_key=None, model=DEFAULT_MODEL,
    taxonomy_categories=DEFAULT_TAXONOMY_CATEGORIES,
    taxonomy_map_path="taxonomy_map.json",
    comparison_csv_path="comparison.csv",
    mistral_api_key=None,
    mistral_model=MISTRAL_DEFAULT_MODEL,
) -> dict:
    """Run the full pipeline: interpret, compare, and write both outputs.

    Args:
        title_topic_keywords: Topic id -> ``(word, score)`` pairs for titles.
        abstract_topic_keywords: Same, for abstracts.
        title_topic_assignments: Topic id per title, aligned with
            *raw_titles*.
        abstract_topic_assignments: Topic id per abstract, aligned with
            *raw_abstracts*.
        raw_titles: Title texts.
        raw_abstracts: Abstract texts.
        api_key: Groq API key (falls back to ``GROQ_API_KEY``).
        model: Primary model name.
        taxonomy_categories: Category names offered to the model.
        taxonomy_map_path: Destination of the JSON taxonomy map.
        comparison_csv_path: Destination of the comparison CSV.
        mistral_api_key: Mistral key (falls back to ``MISTRAL_API_KEY``).
        mistral_model: Validation model name.

    Returns:
        Dict with the keys ``title_interpretations``,
        ``abstract_interpretations``, ``comparison_rows`` and
        ``taxonomy_map``.
    """
    # The groq SDK is imported inside build_groq_client.
    client = build_groq_client(api_key)
    mistral_api_key = mistral_api_key or os.getenv("MISTRAL_API_KEY")

    docs_by_title_topic = build_topic_docs_map(raw_titles, title_topic_assignments)
    docs_by_abstract_topic = build_topic_docs_map(raw_abstracts, abstract_topic_assignments)

    # Settings shared by both interpretation passes.
    common_kwargs = dict(
        client=client,
        taxonomy_categories=taxonomy_categories,
        model=model,
        mistral_api_key=mistral_api_key,
        mistral_model=mistral_model,
    )

    logger.info("Interpreting TITLE topics …")
    title_interps = interpret_all_topics(
        source="titles",
        topic_keywords=title_topic_keywords,
        topic_docs=docs_by_title_topic,
        **common_kwargs,
    )

    logger.info("Interpreting ABSTRACT topics …")
    abstract_interps = interpret_all_topics(
        source="abstracts",
        topic_keywords=abstract_topic_keywords,
        topic_docs=docs_by_abstract_topic,
        **common_kwargs,
    )

    logger.info("Comparing title vs. abstract topics …")
    comparison_rows = compare_topics(client, title_interps, abstract_interps, model)

    taxonomy_map = build_taxonomy_map(title_interps, abstract_interps)
    save_taxonomy_map(taxonomy_map, taxonomy_map_path)
    save_comparison_csv(comparison_rows, comparison_csv_path)

    result = {}
    result["title_interpretations"] = title_interps
    result["abstract_interpretations"] = abstract_interps
    result["comparison_rows"] = comparison_rows
    result["taxonomy_map"] = taxonomy_map
    return result