# topic-modelling / agent.py
# Committed by vvinayakkk in a1d17f8 ("Initial clean commit with LFS").
import json
import logging
import re
from typing import Any, Dict, Optional

from langchain_core.tools import tool
from langgraph.prebuilt import create_react_agent

import tools as pipeline_tools
_AGENT_CONTEXT: Dict[str, Any] = {
"csv_path": None,
"output_dir": ".",
}
def _topic_summary(result: pipeline_tools.AnalysisResult) -> Dict[str, Any]:
table = result.topic_table
top_labels = table.sort_values("Count", ascending=False)["Label"].head(10).tolist() if not table.empty else []
return {
"source": result.source_column,
"papers": result.paper_count,
"analysis_units": result.unit_count,
"topic_count": int(len(table)),
"top_labels": top_labels,
}
@tool
def run_abstract_analysis_tool() -> str:
    """Run Braun and Clarke thematic pipeline on the Abstract column."""
    # @tool publishes the docstring above as the tool description the LLM sees,
    # so it is kept word-for-word. The tool takes no arguments; its CSV path and
    # output directory come from _AGENT_CONTEXT.
    settings = _AGENT_CONTEXT
    abstract_result = pipeline_tools.run_single_analysis(
        mode="abstract",
        csv_path=settings.get("csv_path"),
        output_dir=settings.get("output_dir", "."),
    )
    return json.dumps(
        {"status": "ok", "summary": _topic_summary(abstract_result)},
        indent=2,
    )
@tool
def run_title_analysis_tool() -> str:
    """Run Braun and Clarke thematic pipeline on the Title column."""
    # The docstring doubles as the LLM-facing tool description; leave it as-is.
    # Inputs are read from the shared _AGENT_CONTEXT set by run_agent_command.
    settings = _AGENT_CONTEXT
    title_result = pipeline_tools.run_single_analysis(
        mode="title",
        csv_path=settings.get("csv_path"),
        output_dir=settings.get("output_dir", "."),
    )
    return json.dumps(
        {"status": "ok", "summary": _topic_summary(title_result)},
        indent=2,
    )
@tool
def run_full_pipeline_tool() -> str:
    """Run full end-to-end pipeline for abstract and title, then generate comparison and narrative files."""
    # The docstring is the tool description exposed to the LLM by @tool; keep it verbatim.
    outcome = pipeline_tools.run_full_pipeline(
        csv_path=_AGENT_CONTEXT.get("csv_path"),
        output_dir=_AGENT_CONTEXT.get("output_dir", "."),
    )
    # Report only counts and paths; the full tables stay on disk.
    report = dict(
        status="ok",
        csv_path=outcome["csv_path"],
        abstract_topics=int(len(outcome["abstract"].topic_table)),
        title_topics=int(len(outcome["title"].topic_table)),
        comparison_rows=int(len(outcome["comparison"])),
        files=outcome["files"],
    )
    return json.dumps(report, indent=2)
@tool
def get_output_files_tool() -> str:
    """Get generated deliverable file paths."""
    # Docstring is the LLM-facing tool description; do not reword it.
    target_dir = _AGENT_CONTEXT.get("output_dir", ".")
    return json.dumps(pipeline_tools.ensure_output_artifacts(target_dir), indent=2)
def _fallback_router(message: str, csv_path: Optional[str], output_dir: str) -> str:
lowered = message.lower()
if "full" in lowered or "end to end" in lowered or "pipeline" in lowered or "compare" in lowered:
result = pipeline_tools.run_full_pipeline(csv_path=csv_path, output_dir=output_dir)
return (
"Full pipeline complete. "
f"Abstract topics: {len(result['abstract'].topic_table)} | "
f"Title topics: {len(result['title'].topic_table)} | "
f"Comparison rows: {len(result['comparison'])}."
)
if "abstract" in lowered:
result = pipeline_tools.run_single_analysis(mode="abstract", csv_path=csv_path, output_dir=output_dir)
return (
"Abstract analysis complete. "
f"Identified {len(result.topic_table)} topics from {result.unit_count} cleaned analysis units."
)
if "title" in lowered:
result = pipeline_tools.run_single_analysis(mode="title", csv_path=csv_path, output_dir=output_dir)
return (
"Title analysis complete. "
f"Identified {len(result.topic_table)} topics from {result.unit_count} cleaned analysis units."
)
files = pipeline_tools.ensure_output_artifacts(output_dir)
return (
"I can run 'abstract analysis', 'title analysis', or 'full pipeline'. "
f"Current output files are available at: {files}"
)
def run_agent_command(message: str, csv_path: Optional[str] = None, output_dir: str = ".") -> str:
_AGENT_CONTEXT["csv_path"] = csv_path
_AGENT_CONTEXT["output_dir"] = output_dir
llm = pipeline_tools.create_groq_llm(temperature=0.1)
if llm is None:
return _fallback_router(message, csv_path=csv_path, output_dir=output_dir)
tools = [
run_abstract_analysis_tool,
run_title_analysis_tool,
run_full_pipeline_tool,
get_output_files_tool,
]
try:
react_agent = create_react_agent(llm, tools)
response = react_agent.invoke({"messages": [("user", message)]})
messages = response.get("messages", [])
if not messages:
return _fallback_router(message, csv_path=csv_path, output_dir=output_dir)
final_message = messages[-1]
content = getattr(final_message, "content", str(final_message))
return str(content)
except Exception:
return _fallback_router(message, csv_path=csv_path, output_dir=output_dir)