# topic-model / agent.py
# Source: uploaded by reyansh2005 ("Upload 7 files", commit 24ae147).
"""
agent.py — TopicAgent orchestrates the end-to-end topic modeling workflow.
This module defines the TopicAgent class, which:
1. Loads and validates the CSV dataset.
2. Preprocesses text for Titles and Abstracts separately.
3. Runs topic modeling on each corpus (at least ``min_topics`` topics requested per corpus; default 100).
4. Generates human-readable labels for every topic.
5. Compares dominant themes across Title and Abstract topics.
6. Produces a taxonomy map (MAPPED / NOVEL classification).
7. Exports structured outputs: topics table, comparison CSV, taxonomy JSON.
Usage:
agent = TopicAgent(csv_path="dataset.csv")
results = agent.run()
"""
import os
import json
import logging
from dataclasses import dataclass, field
from typing import Dict, Any, Optional
import pandas as pd
from tools import (
load_csv,
preprocess_text,
run_topic_modeling,
generate_labels,
compare_themes,
create_taxonomy_map,
)
# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Structured result container
# ---------------------------------------------------------------------------
def _empty_frame_field():
    """Return a dataclass field whose default is a new, empty DataFrame."""
    # default_factory (rather than a shared default value) gives every
    # AgentResult instance its own DataFrame object.
    return field(default_factory=pd.DataFrame)


@dataclass
class AgentResult:
    """
    Everything a single TopicAgent run produces.

    Attributes
    ----------
    title_topics, abstract_topics : pd.DataFrame
        Per-source topic tables (Titles and Abstracts respectively).
    combined_topics : pd.DataFrame
        Title and abstract topics concatenated into one table.
    comparison : pd.DataFrame
        Title-vs-abstract theme comparison.
    taxonomy_map : dict
        JSON-serialisable taxonomy (MAPPED / NOVEL classification).
    status : str
        Run status; starts as ``"pending"``.
    steps_completed : list
        Names of the pipeline steps that finished, in execution order.
    errors : list
        Error messages collected while running.
    exported_files : dict
        Artefact name -> path of the file written to disk.
    """

    # --- Topic tables -----------------------------------------------------
    title_topics: pd.DataFrame = _empty_frame_field()
    abstract_topics: pd.DataFrame = _empty_frame_field()
    combined_topics: pd.DataFrame = _empty_frame_field()
    comparison: pd.DataFrame = _empty_frame_field()

    # --- Taxonomy (plain dict so it can be dumped straight to JSON) -------
    taxonomy_map: Dict[str, Any] = field(default_factory=dict)

    # --- Run bookkeeping --------------------------------------------------
    status: str = "pending"
    steps_completed: list = field(default_factory=list)
    errors: list = field(default_factory=list)

    # --- Artefacts written to disk ----------------------------------------
    exported_files: Dict[str, str] = field(default_factory=dict)
# ---------------------------------------------------------------------------
# TopicAgent
# ---------------------------------------------------------------------------
class TopicAgent:
    """
    Orchestrates the research-paper topic modeling pipeline.

    Parameters
    ----------
    csv_path : str
        Path to the input CSV file; its ``Title`` and ``Abstract`` columns
        are preprocessed and modeled separately.
    output_dir : str
        Directory to write output files. It is created (if missing) when the
        agent is constructed.
    min_topics : int
        Minimum number of topics to request per source (default 100); passed
        through to ``run_topic_modeling``.
    use_llm_labels : bool
        Whether to use a Groq LLM for label generation.
    groq_api_key : str, optional
        API key for Groq (used only when ``use_llm_labels`` is True).
    """

    def __init__(
        self,
        csv_path: str,
        output_dir: str = "outputs",
        min_topics: int = 100,
        use_llm_labels: bool = False,
        groq_api_key: Optional[str] = None,
    ):
        self.csv_path = csv_path
        self.output_dir = output_dir
        self.min_topics = min_topics
        self.use_llm_labels = use_llm_labels
        self.groq_api_key = groq_api_key
        # Create the output directory up front so an unusable path fails at
        # construction time instead of at the final export step.
        os.makedirs(self.output_dir, exist_ok=True)
        # Placeholder so ``_result`` always exists; run() replaces it per call.
        self._result = AgentResult()

    # -----------------------------------------------------------------
    # Public interface
    # -----------------------------------------------------------------
    def run(self) -> AgentResult:
        """
        Execute the full pipeline step by step.

        Every call starts from a fresh ``AgentResult``. Repeated calls
        therefore do not accumulate ``steps_completed``/``errors`` from an
        earlier run, and a result returned by a previous call is never
        mutated afterwards.

        Exceptions raised by a step are not propagated. The pipeline stops at
        the failing step, the traceback is logged, ``status`` is set to
        ``"failed"`` and the message is appended to ``errors``.
        ``steps_completed`` shows how far the run got.

        Returns
        -------
        AgentResult
            Structured results including all DataFrames, taxonomy, and file paths.
        """
        # Fresh container per run: reusing the one from __init__ made a second
        # run() double steps_completed, keep stale errors/data, and mutate the
        # object already handed back to the caller.
        self._result = AgentResult()

        logger.info("=" * 60)
        logger.info("TopicAgent β€” Starting pipeline")
        logger.info("=" * 60)

        # Executed strictly in order; each step reads state set by earlier ones.
        pipeline = (
            self._step_load_csv,         # 1. ingest CSV
            self._step_preprocess,       # 2. clean Title / Abstract text
            self._step_model_titles,     # 3. topic model on Titles
            self._step_model_abstracts,  # 4. topic model on Abstracts
            self._step_generate_labels,  # 5. human-readable labels
            self._step_combine_topics,   # 6. combined topics table
            self._step_compare_themes,   # 7. title vs abstract comparison
            self._step_taxonomy_map,     # 8. MAPPED / NOVEL taxonomy
            self._step_export,           # 9. write artefacts to disk
        )
        try:
            for step in pipeline:
                step()
        except Exception as exc:  # top-level boundary: record failure, don't raise
            self._result.status = "failed"
            self._result.errors.append(str(exc))
            logger.exception("Pipeline failed: %s", exc)
        else:
            self._result.status = "success"
            logger.info("Pipeline completed successfully.")
        return self._result

    # -----------------------------------------------------------------
    # Pipeline steps
    # -----------------------------------------------------------------
    def _step_load_csv(self):
        """Step 1 — Ingest the CSV dataset into ``self._df`` via ``load_csv``."""
        logger.info("Step 1/9: Loading CSV …")
        self._df = load_csv(self.csv_path)
        self._result.steps_completed.append("load_csv")
        logger.info(" β†’ %d papers loaded.", len(self._df))

    def _step_preprocess(self):
        """Step 2 — Preprocess the ``Title`` and ``Abstract`` columns separately."""
        logger.info("Step 2/9: Preprocessing text …")
        self._titles_clean = preprocess_text(self._df["Title"].tolist())
        self._abstracts_clean = preprocess_text(self._df["Abstract"].tolist())
        self._result.steps_completed.append("preprocess_text")
        logger.info(" β†’ Titles preprocessed: %d docs", len(self._titles_clean))
        logger.info(" β†’ Abstracts preprocessed: %d docs", len(self._abstracts_clean))

    def _step_model_titles(self):
        """Step 3 — Topic modeling on the preprocessed Titles."""
        logger.info("Step 3/9: Topic modeling on Titles …")
        self._title_topics_df, self._title_model = run_topic_modeling(
            self._titles_clean,
            source_label="Titles",
            min_topics=self.min_topics,
        )
        self._result.steps_completed.append("topic_modeling_titles")
        logger.info(" β†’ %d title topics discovered.", len(self._title_topics_df))

    def _step_model_abstracts(self):
        """Step 4 — Topic modeling on the preprocessed Abstracts."""
        logger.info("Step 4/9: Topic modeling on Abstracts …")
        self._abstract_topics_df, self._abstract_model = run_topic_modeling(
            self._abstracts_clean,
            source_label="Abstracts",
            min_topics=self.min_topics,
        )
        self._result.steps_completed.append("topic_modeling_abstracts")
        logger.info(" β†’ %d abstract topics discovered.", len(self._abstract_topics_df))

    def _step_generate_labels(self):
        """Step 5 — Attach human-readable labels to both topic tables.

        The labelled tables replace the unlabelled ones and are published on
        the result as ``title_topics`` / ``abstract_topics``.
        """
        logger.info("Step 5/9: Generating topic labels …")
        self._title_topics_df = generate_labels(
            self._title_topics_df,
            use_llm=self.use_llm_labels,
            groq_api_key=self.groq_api_key,
        )
        self._abstract_topics_df = generate_labels(
            self._abstract_topics_df,
            use_llm=self.use_llm_labels,
            groq_api_key=self.groq_api_key,
        )
        self._result.title_topics = self._title_topics_df
        self._result.abstract_topics = self._abstract_topics_df
        self._result.steps_completed.append("generate_labels")
        logger.info(" β†’ Labels generated for all topics.")

    def _step_combine_topics(self):
        """Step 6 — Stack title then abstract topics into one table.

        A ``global_id`` column (0..n-1, in stacked order) is added so every
        topic has an identifier that is unique across both sources.
        """
        logger.info("Step 6/9: Building combined topics table …")
        combined = pd.concat(
            [self._title_topics_df, self._abstract_topics_df],
            ignore_index=True,
        )
        combined["global_id"] = range(len(combined))
        self._result.combined_topics = combined
        self._result.steps_completed.append("combine_topics")
        logger.info(" β†’ Combined table: %d topics total.", len(combined))

    def _step_compare_themes(self):
        """Step 7 — Compare title vs abstract themes via ``compare_themes``."""
        logger.info("Step 7/9: Comparing title vs abstract themes …")
        comparison = compare_themes(self._title_topics_df, self._abstract_topics_df)
        self._result.comparison = comparison
        self._result.steps_completed.append("compare_themes")
        logger.info(" β†’ Comparison table: %d rows.", len(comparison))

    def _step_taxonomy_map(self):
        """Step 8 — Create the taxonomy map (MAPPED / NOVEL) from combined topics.

        The returned dict is expected to contain
        ``["metadata"]["mapped_count"]`` and ``["metadata"]["novel_count"]``,
        which are logged.
        """
        logger.info("Step 8/9: Building taxonomy map …")
        taxonomy = create_taxonomy_map(self._result.combined_topics)
        self._result.taxonomy_map = taxonomy
        self._result.steps_completed.append("create_taxonomy_map")
        logger.info(
            " β†’ MAPPED: %d, NOVEL: %d",
            taxonomy["metadata"]["mapped_count"],
            taxonomy["metadata"]["novel_count"],
        )

    def _step_export(self):
        """Step 9 — Export all outputs to ``output_dir``.

        Writes, in order: ``topics_table.csv``, ``comparison.csv``,
        ``taxonomy_map.json``, ``title_topics.csv`` and
        ``abstract_topics.csv``, recording each path in
        ``result.exported_files``.
        """
        logger.info("Step 9/9: Exporting outputs …")
        result = self._result
        # (a) Combined topics table CSV
        self._export_csv("topics_table", "topics_table.csv", result.combined_topics)
        # (b) Comparison CSV
        self._export_csv("comparison", "comparison.csv", result.comparison)
        # (c) Taxonomy map JSON; ensure_ascii=False keeps non-ASCII labels
        #     readable in the UTF-8 file instead of escaping them.
        taxonomy_path = os.path.join(self.output_dir, "taxonomy_map.json")
        with open(taxonomy_path, "w", encoding="utf-8") as f:
            json.dump(result.taxonomy_map, f, indent=2, ensure_ascii=False)
        result.exported_files["taxonomy_map"] = taxonomy_path
        logger.info(" β†’ Saved: %s", taxonomy_path)
        # (d) Title topics CSV
        self._export_csv("title_topics", "title_topics.csv", result.title_topics)
        # (e) Abstract topics CSV
        self._export_csv("abstract_topics", "abstract_topics.csv", result.abstract_topics)
        result.steps_completed.append("export")
        logger.info(" β†’ All outputs exported successfully.")

    def _export_csv(self, key: str, filename: str, df: pd.DataFrame) -> None:
        """Write *df* (without its index) to ``output_dir/filename``.

        The path is recorded under *key* in ``result.exported_files`` and logged.
        """
        path = os.path.join(self.output_dir, filename)
        df.to_csv(path, index=False)
        self._result.exported_files[key] = path
        logger.info(" β†’ Saved: %s", path)