# reyansh2005's picture
# nice
# 62e2807
"""
agent.py — Pipeline Orchestrator
Controls the full topic modelling pipeline:
load → preprocess → model titles → model abstracts → label →
compare β†’ map β†’ generate narrative β†’ generate reflection β†’ save outputs
All NLP/ML logic is delegated to tools.py.
This module handles sequencing, progress reporting, and file I/O.
"""
from __future__ import annotations
import os
import json
import pandas as pd
from pathlib import Path
from tools import (
preprocess_dataframe,
vectorize_texts,
run_topic_model,
extract_keywords,
label_topics_batch,
generate_label_from_keywords,
map_to_taxonomy,
compare_title_abstract_themes,
generate_narrative,
generate_reflection,
save_prompts,
PAJAIS_TAXONOMY,
)
# ── .env file loader (no python-dotenv dependency) ────────────────────────
def _load_env(env_path: str | os.PathLike[str] | None = None) -> None:
    """Copy KEY=VALUE pairs from a .env file into ``os.environ``.

    Existing non-empty environment variables are never overwritten, and
    entries with an empty value are ignored.

    Args:
        env_path: File to read. Defaults to ``.env`` in the directory that
            contains this module.
    """
    path = Path(env_path) if env_path is not None else Path(__file__).parent / ".env"
    # is_file() (rather than exists()) also skips a directory named ".env".
    if not path.is_file():
        return

    # utf-8-sig reads plain UTF-8 too, and drops a BOM that would otherwise
    # become part of the first key.
    for raw_line in path.read_text(encoding="utf-8-sig").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue

        key, _, val = line.partition("=")
        key = key.strip()
        val = val.strip()

        # Remove ONE pair of matching surrounding quotes. Stripping both quote
        # characters unconditionally would eat quotes that belong to the value
        # (e.g. "say 'hi'" would lose its trailing apostrophe).
        if len(val) >= 2 and val[0] == val[-1] and val[0] in "\"'":
            val = val[1:-1]

        # An empty key (line such as "=value") cannot be assigned to
        # os.environ; skip it instead of failing at import time.
        if not key or not val:
            continue

        # A variable that is unset or set to "" counts as missing.
        if not os.getenv(key):
            os.environ[key] = val


# Populate the environment once, at import time.
_load_env()
# ════════════════════════════════════════════════════════════════════════════
# Pipeline Agent
# ════════════════════════════════════════════════════════════════════════════
class TopicModellingAgent:
    """Orchestrate the full analysis pipeline from CSV upload to all outputs.

    The NLP/ML steps are delegated to the functions imported from ``tools``;
    this class sequences them, reports progress and writes the output files.

    Attributes:
        api_key: Optional LLM API key (Groq / Mistral / OpenAI).
        provider: Optional provider name ('groq', 'mistral', 'openai').
        output_dir: Directory the output files are written to and looked up
            in. Defaults to the current working directory.
        df: Loaded and preprocessed DataFrame.
        title_topics: Topics extracted from paper titles.
        abstract_topics: Topics extracted from paper abstracts.
        all_topics: Combined title + abstract topics.
        taxonomy_map: PAJAIS mapping results.
        comparison_df: Title vs abstract comparison DataFrame.
        narrative: Generated ~500-word narrative text.
        reflection: Generated ~250-word reflection text.
        logs: Pipeline execution log messages.
    """

    # Files produced by run_pipeline(), in the order get_download_files()
    # reports them. Single source of truth for the file names.
    OUTPUT_FILES = (
        "comparison.csv",
        "taxonomy_map.json",
        "narrative.txt",
        "reflection.txt",
        "prompts.txt",
    )
    # Minimum number of non-empty cleaned texts needed to model a source.
    MIN_DOCUMENTS = 5
    # Lower bound applied to every requested topic count.
    MIN_TOPICS = 10

    def __init__(
        self,
        api_key: str | None = None,
        provider: str | None = None,
        *,
        output_dir: str | os.PathLike[str] | None = None,
    ):
        """Create an agent with empty results.

        Args:
            api_key: Optional LLM API key passed through to the labelling and
                text-generation helpers.
            provider: Optional LLM provider name passed through alongside it.
            output_dir: Directory for generated files; ``None`` means the
                current working directory.
        """
        self.api_key = api_key
        self.provider = provider
        self.output_dir: Path = Path(output_dir) if output_dir is not None else Path()
        self.df: pd.DataFrame | None = None
        self.title_topics: list[dict] = []
        self.abstract_topics: list[dict] = []
        self.all_topics: list[dict] = []
        self.taxonomy_map: list[dict] = []
        self.comparison_df: pd.DataFrame | None = None
        self.narrative: str = ""
        self.reflection: str = ""
        self.logs: list[str] = []

    # ── Logging ───────────────────────────────────────────────────────────
    def log(self, msg: str) -> None:
        """Append a log message and print it to stdout."""
        self.logs.append(msg)
        try:
            print(msg)
        except UnicodeEncodeError:
            # Consoles with a legacy code page (e.g. Windows cp1252) cannot
            # encode emoji; print an ASCII rendering with '?' placeholders.
            print(msg.encode("ascii", errors="replace").decode("ascii"))

    # ── Internal helpers ──────────────────────────────────────────────────
    def _output_path(self, name: str) -> Path:
        """Return the path of output file *name* inside ``output_dir``."""
        return self.output_dir / name

    @classmethod
    def _bounded_topic_count(
        cls, desired: int, n_features: int, n_documents: int
    ) -> int:
        """Cap *desired* by the data size, then apply the MIN_TOPICS floor.

        NOTE(review): the floor is applied last, so for very small inputs the
        result can exceed ``n_documents - 1`` / ``n_features - 1``. This
        mirrors the original behaviour; run_topic_model() also returns an
        "actual" topic count, so it presumably copes — confirm in tools.py.
        """
        capped = min(desired, n_features - 1, n_documents - 1)
        return max(capped, cls.MIN_TOPICS)

    def _extract_topics(
        self,
        texts: list[str],
        *,
        source: str,
        max_features: int,
        desired: int,
        id_offset: int,
    ) -> list[dict]:
        """Vectorize *texts*, fit an NMF topic model and return tagged topics.

        Args:
            texts: Non-empty cleaned texts for one source.
            source: ``"title"`` or ``"abstract"``; stored on every topic and
                used in the error message.
            max_features: Vocabulary cap passed to ``vectorize_texts``.
            desired: Requested topic count before bounding.
            id_offset: Number of topics already numbered; IDs continue at
                ``id_offset + 1``.

        Raises:
            ValueError: If fewer than ``MIN_DOCUMENTS`` texts are supplied.
        """
        if len(texts) < self.MIN_DOCUMENTS:
            raise ValueError(
                f"Only {len(texts)} non-empty {source}s after cleaning. "
                f"Need at least {self.MIN_DOCUMENTS} papers with valid {source}s."
            )

        matrix, vectorizer = vectorize_texts(texts, max_features=max_features)
        n_topics = self._bounded_topic_count(desired, matrix.shape[1], len(texts))
        # The second return value (actual topic count) is not needed here;
        # the number of extracted topics is taken from the keyword list.
        model, _ = run_topic_model(matrix, n_topics=n_topics, method="nmf")
        topics = extract_keywords(model, vectorizer, n_words=10)

        # 1-based IDs, continuing after any previously numbered topics.
        for topic_id, topic in enumerate(topics, start=id_offset + 1):
            topic["topic_id"] = topic_id
            topic["source"] = source
        return topics

    # ── Step 1: Load & Validate ───────────────────────────────────────────
    def load_and_validate(self, csv_path: str) -> pd.DataFrame:
        """Load the CSV and validate that it has 'title' and 'abstract'.

        Column names are stripped and lower-cased before the check. Missing
        values in the two text columns become ``""`` and all values are
        converted to ``str``. Rows whose title and abstract are both empty or
        whitespace-only are dropped.

        Args:
            csv_path: Path to the CSV file (UTF-8, optional BOM).

        Returns:
            The validated DataFrame, also stored on ``self.df``.

        Raises:
            ValueError: If a required column is missing or no row has any
                title or abstract text.
        """
        self.log("📂 Loading CSV file...")
        df = pd.read_csv(csv_path, encoding="utf-8-sig")
        df.columns = df.columns.str.strip().str.lower()

        required = ("title", "abstract")
        # Sorted so the error message is deterministic across runs.
        missing = sorted(set(required) - set(df.columns))
        if missing:
            raise ValueError(
                f"Missing required columns: {missing}\n"
                f"Found columns: {list(df.columns)}\n"
                f"Please ensure your CSV has 'title' and 'abstract' columns."
            )

        # Normalise both text columns to plain strings so numeric-looking
        # columns and NaN cells cannot break later string handling.
        for column in required:
            df[column] = df[column].map(lambda v: "" if pd.isna(v) else str(v))

        # A row is useless only when BOTH fields are blank; whitespace-only
        # cells count as blank. Stored values are not stripped.
        both_blank = df["title"].str.strip().eq("") & df["abstract"].str.strip().eq("")
        df = df.loc[~both_blank].copy()

        if df.empty:
            raise ValueError("CSV has no valid rows with title or abstract data.")

        self.df = df
        self.log(f"✅ Loaded {len(df)} papers | Columns: {list(df.columns)}")
        return df

    # ── Full Pipeline ─────────────────────────────────────────────────────
    def run_pipeline(self, csv_path: str, progress_callback=None) -> dict:
        """Execute the full 9-step analysis pipeline.

        Args:
            csv_path: Path to the uploaded CSV file.
            progress_callback: Optional Gradio progress function for UI
                updates; called as ``progress_callback(fraction, desc=msg)``.

        Returns:
            Summary dict with keys ``total_topics``, ``title_topics``,
            ``abstract_topics``, ``mapped`` and ``novel``.

        Raises:
            ValueError: If the CSV is invalid or too few titles/abstracts
                remain after cleaning.
        """
        callback_failure_logged = False

        def update(progress_val: float, msg: str) -> None:
            """Log *msg* and forward it to the optional progress callback."""
            nonlocal callback_failure_logged
            self.log(msg)
            if not progress_callback:
                return
            try:
                progress_callback(progress_val, desc=msg)
            except Exception as exc:
                # UI progress is best-effort and must never abort the
                # pipeline; record the first failure so it is not invisible.
                if not callback_failure_logged:
                    callback_failure_logged = True
                    self.log(f"Progress callback failed (ignored): {exc!r}")

        # ── 1. Load & Validate ───────────────────────────────────────
        update(0.05, "📂 Step 1/9: Loading CSV...")
        self.load_and_validate(csv_path)
        update(0.10, f"✅ Step 1/9: Loaded {len(self.df)} papers")

        # ── 2. Preprocess ────────────────────────────────────────────
        update(0.12, "🔄 Step 2/9: Preprocessing text...")
        self.df = preprocess_dataframe(self.df)
        n_et = sum(1 for t in self.df["clean_title"] if not t.strip())
        n_ea = sum(1 for t in self.df["clean_abstract"] if not t.strip())
        update(
            0.18,
            f"✅ Step 2/9: Preprocessed ({n_et} empty titles, {n_ea} empty abstracts)",
        )

        # ── 3. Topic Model on Titles ─────────────────────────────────
        update(0.20, "🔄 Step 3/9: Running NMF on titles (target: 50 topics)...")
        title_texts = [t for t in self.df["clean_title"].tolist() if t.strip()]
        self.title_topics = self._extract_topics(
            title_texts,
            source="title",
            max_features=3000,
            desired=50,
            id_offset=0,
        )
        update(0.35, f"✅ Step 3/9: Extracted {len(self.title_topics)} title topics")

        # ── 4. Topic Model on Abstracts ──────────────────────────────
        update(0.37, "🔄 Step 4/9: Running NMF on abstracts (target: 50 topics)...")
        abstract_texts = [t for t in self.df["clean_abstract"].tolist() if t.strip()]
        self.abstract_topics = self._extract_topics(
            abstract_texts,
            source="abstract",
            max_features=5000,
            # Aim for 100 topics in total, but never request fewer than 50.
            desired=max(50, 100 - len(self.title_topics)),
            # Abstract topic IDs continue after the title topics.
            id_offset=len(self.title_topics),
        )
        update(
            0.50,
            f"✅ Step 4/9: Extracted {len(self.abstract_topics)} abstract topics",
        )

        # ── 5. Combine & Label ───────────────────────────────────────
        self.all_topics = self.title_topics + self.abstract_topics
        total = len(self.all_topics)
        update(0.52, f"🔄 Step 5/9: Labelling {total} topics...")
        self.all_topics = label_topics_batch(
            self.all_topics,
            batch_size=10,
            api_key=self.api_key,
            provider=self.provider,
        )
        # Sync back to title/abstract lists
        self.title_topics = [t for t in self.all_topics if t["source"] == "title"]
        self.abstract_topics = [
            t for t in self.all_topics if t["source"] == "abstract"
        ]
        # Heuristic: if any of the first three labels differs from the
        # keyword-derived label, the labels are reported as LLM-enhanced.
        llm_used = any(
            t.get("label", "") != generate_label_from_keywords(t["keywords"])
            for t in self.all_topics[:3]
        )
        label_method = "LLM-enhanced" if llm_used else "heuristic"
        update(0.65, f"✅ Step 5/9: All {total} topics labelled ({label_method})")

        # ── 6. PAJAIS Mapping ────────────────────────────────────────
        update(0.67, "🔄 Step 6/9: Mapping to PAJAIS taxonomy...")
        self.taxonomy_map = map_to_taxonomy(self.all_topics)
        n_mapped = sum(1 for m in self.taxonomy_map if m["status"] == "MAPPED")
        n_novel = sum(1 for m in self.taxonomy_map if m["status"] == "NOVEL")
        update(0.72, f"✅ Step 6/9: {n_mapped} MAPPED, {n_novel} NOVEL")

        # ── 7. Comparison CSV (C6) ───────────────────────────────────
        update(0.74, "🔄 Step 7/9: Generating comparison.csv (C6)...")
        self.comparison_df = compare_title_abstract_themes(
            self.title_topics, self.abstract_topics
        )
        # First write of the run: make sure the target directory exists.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.comparison_df.to_csv(
            self._output_path("comparison.csv"), index=False, encoding="utf-8-sig"
        )
        update(0.78, "✅ Step 7/9: comparison.csv saved")

        # ── 8. Taxonomy Map JSON (C7) ────────────────────────────────
        update(0.80, "🔄 Step 8/9: Saving taxonomy_map.json (C7)...")
        taxonomy_json = {
            "metadata": {
                "total_topics": len(self.all_topics),
                "title_topics": len(self.title_topics),
                "abstract_topics": len(self.abstract_topics),
                "mapped_count": n_mapped,
                "novel_count": n_novel,
                "taxonomy_used": "PAJAIS 25-Category",
            },
            "mappings": self.taxonomy_map,
            "taxonomy_categories": PAJAIS_TAXONOMY,
        }
        self._output_path("taxonomy_map.json").write_text(
            json.dumps(taxonomy_json, indent=2, ensure_ascii=False),
            encoding="utf-8",
        )
        update(0.83, "✅ Step 8/9: taxonomy_map.json saved")

        # ── 9. Narrative + Reflection + Prompts ──────────────────────
        update(0.85, "🔄 Step 9/9: Generating narrative, reflection & prompts...")
        # Build summary strings for generation prompts.
        # NOTE(review): title topics come first in all_topics, so when 20 or
        # more title topics exist this slice may contain no abstract topics —
        # confirm that is intended.
        top_themes = self.all_topics[:20]
        themes_summary = "\n".join(
            f"  - [{t['source'].upper()}] Topic {t['topic_id']}: {t['label']} "
            f"(keywords: {', '.join(t['keywords'][:5])})"
            for t in top_themes
        )
        mapped_cats = {
            m["pajais_category"]
            for m in self.taxonomy_map
            if m["status"] == "MAPPED"
        }
        # Taxonomy categories that no topic was mapped to.
        gaps = [cat for cat in PAJAIS_TAXONOMY if cat not in mapped_cats]
        taxonomy_gaps = ", ".join(gaps) if gaps else "All categories covered"

        # ── Narrative (C8)
        self.narrative = generate_narrative(
            themes_summary,
            taxonomy_gaps,
            len(self.df),
            self.api_key,
            self.provider,
        )
        self._output_path("narrative.txt").write_text(
            self.narrative, encoding="utf-8"
        )
        update(0.90, f"✅ narrative.txt saved ({len(self.narrative.split())} words)")

        # ── Reflection (C10)
        comparison_summary = (
            f"Title-based analysis produced {len(self.title_topics)} topics. "
            f"Abstract-based analysis produced {len(self.abstract_topics)} topics. "
            f"Total: {len(self.all_topics)} unique topics generated. "
            f"PAJAIS mapping: {n_mapped} MAPPED, {n_novel} NOVEL."
        )
        self.reflection = generate_reflection(
            themes_summary,
            comparison_summary,
            self.api_key,
            self.provider,
        )
        self._output_path("reflection.txt").write_text(
            self.reflection, encoding="utf-8"
        )
        update(0.95, f"✅ reflection.txt saved ({len(self.reflection.split())} words)")

        # ── Prompts (C9)
        save_prompts(str(self._output_path("prompts.txt")))
        update(0.97, "✅ prompts.txt saved (C9)")

        # ── Done ─────────────────────────────────────────────────────
        summary = (
            f"\n{'=' * 50}\n"
            f"✅ PIPELINE COMPLETE\n"
            f"{'=' * 50}\n"
            f"📊 Total topics: {total} "
            f"({len(self.title_topics)} title + {len(self.abstract_topics)} abstract)\n"
            f"🗺️ PAJAIS mapping: {n_mapped} MAPPED, {n_novel} NOVEL\n"
            f"📁 Output files: {', '.join(self.OUTPUT_FILES)}"
        )
        update(1.0, summary)

        return {
            "total_topics": total,
            "title_topics": len(self.title_topics),
            "abstract_topics": len(self.abstract_topics),
            "mapped": n_mapped,
            "novel": n_novel,
        }

    # ── Result Accessors ──────────────────────────────────────────────────
    def get_review_table(self) -> pd.DataFrame:
        """Return the review table as a DataFrame (C4).

        Columns: topic_id, source, keywords, label. An empty frame with these
        columns is returned when no topics exist yet.
        """
        if not self.all_topics:
            return pd.DataFrame(columns=["topic_id", "source", "keywords", "label"])
        rows = [
            {
                "topic_id": t["topic_id"],
                "source": t.get("source", ""),
                "keywords": t.get("keyword_str", ""),
                "label": t.get("label", ""),
            }
            for t in self.all_topics
        ]
        return pd.DataFrame(rows)

    def get_mapping_table(self) -> pd.DataFrame:
        """Return the PAJAIS mapping table as a DataFrame (C5).

        Columns: topic_id, source, label, pajais_category, status, confidence.
        When mappings exist the frame is built directly from
        ``self.taxonomy_map``, so its columns are whatever keys those dicts
        carry.
        """
        if not self.taxonomy_map:
            return pd.DataFrame(
                columns=[
                    "topic_id",
                    "source",
                    "label",
                    "pajais_category",
                    "status",
                    "confidence",
                ]
            )
        return pd.DataFrame(self.taxonomy_map)

    def get_download_files(self) -> list[str]:
        """Return absolute paths of the output files that currently exist.

        Files are looked up in ``output_dir`` and reported in the order of
        ``OUTPUT_FILES``; missing files are skipped.
        """
        candidates = (self._output_path(name) for name in self.OUTPUT_FILES)
        return [str(path.resolve()) for path in candidates if path.exists()]