# final_bert / agent.py
# Last commit 61f06ab by advaitidalvi: "Upload 4 files" (verified). Size: 2.19 kB.
import os
import json
from typing import Generator
from mistralai import Mistral
from tools import (
TOOL_DEFINITIONS, execute_tool, build_comparison_csv, build_taxonomy_json, PAJAIS_THEMES
)
class AgentState:
    """Mutable bag of results accumulated while the pipeline runs.

    Only ``rows`` is supplied at construction; every other attribute starts
    empty and is filled in later (``run_agent`` assigns them phase by phase).
    """

    def __init__(self, rows):
        # The input records, stored as given (no copy).
        self.rows = rows
        # Topic lists and summary tables for the title and abstract passes;
        # each instance gets its own fresh list objects.
        self.title_topics, self.abstract_topics = [], []
        self.title_summary, self.abstract_summary = [], []
        # Result of the PAJAIS taxonomy mapping step.
        self.pajais_map = {}
        # Serialized artifacts produced at the end of the pipeline.
        self.comparison_csv = self.taxonomy_json = ""
        # Free-form log lines; not populated by run_agent in this file.
        self.logs = []
def run_agent(rows: list[dict], *, max_titles: int = 100) -> Generator[str, None, AgentState]:
    """Run the topic pipeline over ``rows``, yielding progress messages.

    This is a generator: every ``yield`` is a status string meant for
    display, and the populated ``AgentState`` is the generator's *return*
    value (available as ``StopIteration.value`` or via
    ``state = yield from run_agent(rows)``).

    Args:
        rows: Input records. This function itself reads only the ``"Title"``
            key; rows with a missing or empty title are tolerated. The full
            ``rows`` list is also passed to every tool call.
        max_titles: Keyword-only cap on how many non-empty titles are sent
            to topic extraction. Defaults to 100.

    Returns:
        The ``AgentState`` holding the discovered topics, summary tables,
        comparison CSV and taxonomy JSON.
    """
    state = AgentState(rows)

    # NOTE(review): no validation happens in this function; presumably the
    # caller validated the CSV before calling — confirm.
    yield "📋 Phase 1: CSV Validated."

    # Phase 2: discover topics from titles, then cluster and summarise.
    yield "🔬 Phase 2: Title Run..."
    titles = [r["Title"] for r in rows if r.get("Title")][:max_titles]
    res = execute_tool(
        "extract_topics_from_text", {"texts": titles, "text_type": "title"}, rows
    )
    state.title_topics = res["topics"]

    yield "➗ Clustering titles..."
    # sr_no is the 1-based position in ``rows`` so it lines up with
    # ``papers_metadata`` below. Use .get() so a row without a "Title" key
    # does not raise KeyError (topic extraction above skips such rows too).
    papers = [
        {"sr_no": sr_no, "text": r.get("Title", "")}
        for sr_no, r in enumerate(rows, start=1)
    ]
    c_res = execute_tool(
        "cluster_papers_by_topic",
        {"papers": papers, "topics": state.title_topics},
        rows,
    )
    s_res = execute_tool(
        "generate_topic_summary_table",
        {"clusters": c_res["clusters"], "papers_metadata": rows},
        rows,
    )
    state.title_summary = s_res["summary_table"]

    # Phase 3: placeholder ("simplified for flow"). NOTE(review): abstracts
    # are not analysed yet; this reuses the first three title topics and
    # summary rows, so Phase 4 compares title topics with a subset of
    # themselves.
    yield "🔬 Phase 3: Abstract Run..."
    state.abstract_topics = state.title_topics[:3]
    state.abstract_summary = state.title_summary[:3]

    # Phase 4: compare title vs. abstract themes and serialise as CSV.
    yield "📊 Phase 4: Comparing themes..."
    comp = execute_tool(
        "compare_title_vs_abstract_themes",
        {"title_topics": state.title_topics, "abstract_topics": state.abstract_topics},
        rows,
    )
    state.comparison_csv = build_comparison_csv(comp)

    # Phase 5: map discovered topics onto the PAJAIS theme list.
    yield "🗺️ Phase 5: Mapping to PAJAIS..."
    state.pajais_map = execute_tool(
        "map_to_pajais_taxonomy",
        {"discovered_topics": state.title_topics, "pajais_themes": PAJAIS_THEMES},
        rows,
    )
    state.taxonomy_json = build_taxonomy_json(state.pajais_map)

    yield "🎉 Pipeline complete!"
    return state