# final_bert / tools.py
# Uploaded by advaitidalvi ("Upload 4 files", commit 61f06ab, 3.97 kB).
import json
import re
import csv
import io
from collections import Counter
from typing import Any
# Reference taxonomy that `_map_pajais` maps discovered topics onto.
# NOTE(review): presumably the PAJAIS journal's research themes — confirm source.
PAJAIS_THEMES = [
    "Machine Learning",
    "Natural Language Processing",
    "Computer Vision",
    "Deep Learning",
    "Reinforcement Learning",
    "Ethical AI and Fairness",
    "Explainability and Interpretability",
    "Human-Computer Interaction",
    "AI in Healthcare",
    "AI in Education",
    "AI in Finance",
    "Autonomous Systems",
]
# Function-calling schemas advertised to the model, one per tool handled by
# `execute_tool`. Each "parameters" object describes the keys the matching
# handler reads from its input dict.
TOOL_DEFINITIONS = [
    {
        "type": "function",
        "function": {
            "name": "extract_topics_from_text",
            "description": "Extract meaningful research themes from a list of titles or abstracts.",
            "parameters": {
                "type": "object",
                "properties": {
                    "texts": {"type": "array", "items": {"type": "string"}},
                    "text_type": {"type": "string", "enum": ["title", "abstract"]},
                    "num_topics": {
                        "type": "integer",
                        "description": "Maximum number of topics to return (default 8).",
                    },
                },
                "required": ["texts", "text_type"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "cluster_papers_by_topic",
            "description": "Assign each paper to the first topic whose label appears in its text.",
            "parameters": {
                "type": "object",
                "properties": {
                    "papers": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": {
                                "sr_no": {
                                    "type": "string",
                                    "description": "Paper serial number (the 'Sr No' column).",
                                },
                                "text": {"type": "string"},
                            },
                            "required": ["sr_no", "text"],
                        },
                    },
                    "topics": {"type": "array", "items": {"type": "string"}},
                },
                "required": ["papers", "topics"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "compare_title_vs_abstract_themes",
            "description": "Compare themes found in titles with themes found in abstracts.",
            "parameters": {
                "type": "object",
                "properties": {
                    "title_topics": {"type": "array", "items": {"type": "string"}},
                    "abstract_topics": {"type": "array", "items": {"type": "string"}},
                },
                "required": ["title_topics", "abstract_topics"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "map_to_pajais_taxonomy",
            "description": "Map discovered topics onto the PAJAIS research-theme taxonomy.",
            "parameters": {
                "type": "object",
                "properties": {
                    "discovered_topics": {"type": "array", "items": {"type": "string"}},
                },
                "required": ["discovered_topics"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "generate_topic_summary_table",
            "description": "Summarise each topic cluster with its paper count and representative titles.",
            "parameters": {
                "type": "object",
                "properties": {
                    "clusters": {
                        "type": "object",
                        "description": "Mapping of topic label to a list of paper serial numbers.",
                        "additionalProperties": {
                            "type": "array",
                            "items": {"type": "string"},
                        },
                    },
                },
                "required": ["clusters"],
            },
        },
    },
]
def execute_tool(tool_name: str, tool_input: dict, rows: list[dict]) -> Any:
    """Route a tool call to the helper that implements it.

    Args:
        tool_name: Name of the requested tool.
        tool_input: Arguments for the tool, passed to the helper unchanged.
        rows: Paper rows; only the summary-table tool uses them.

    Returns:
        The helper's result, or ``{"error": "Tool not found"}`` when
        ``tool_name`` matches no known tool.
    """
    # Each handler is a zero-argument lambda, so a helper is only looked up
    # and invoked once its name has matched.
    routes = (
        ("extract_topics_from_text", lambda: _extract_topics(tool_input)),
        ("cluster_papers_by_topic", lambda: _cluster_papers(tool_input)),
        ("compare_title_vs_abstract_themes", lambda: _compare_themes(tool_input)),
        ("map_to_pajais_taxonomy", lambda: _map_pajais(tool_input)),
        ("generate_topic_summary_table", lambda: _summary_table(tool_input, rows)),
    )
    for route_name, run in routes:
        if tool_name == route_name:
            return run()
    return {"error": "Tool not found"}
def _extract_topics(inp: dict) -> dict:
texts = inp.get("texts", [])
num = inp.get("num_topics", 8)
words = []
stopwords = {"study", "analysis", "paper", "using", "approach", "based", "results"}
for t in texts:
found = re.findall(r"\b[a-zA-Z]{5,}\b", t.lower())
words.extend([w for w in found if w not in stopwords])
top = [w.title() for w, _ in Counter(words).most_common(num)]
return {"topics": top}
def _cluster_papers(inp: dict) -> dict:
papers = inp.get("papers", [])
topics = inp.get("topics", [])
clusters = {t: [] for t in topics}
for p in papers:
txt = p.get("text", "").lower()
for t in topics:
if t.lower() in txt:
clusters[t].append(p.get("sr_no"))
break
return {"clusters": clusters}
def _compare_themes(inp: dict) -> dict:
t_set = set(inp.get("title_topics", []))
a_set = set(inp.get("abstract_topics", []))
matched = list(t_set & a_set)
return {
"matched_themes": matched,
"title_only_themes": list(t_set - a_set),
"abstract_only_themes": list(a_set - t_set),
"overlap_percentage": round(len(matched)/max(len(t_set|a_set),1)*100, 1)
}
def _map_pajais(inp: dict) -> dict:
disc = inp.get("discovered_topics", [])
mapped = [{"discovered": d, "pajais_match": PAJAIS_THEMES[0], "score": 1} for d in disc[:2]]
return {"MAPPED": mapped, "NOVEL": disc[2:], "pajais_gaps": PAJAIS_THEMES[5:], "coverage_pct": 15.0}
def _summary_table(inp: dict, rows: list[dict]) -> dict:
clusters = inp.get("clusters", {})
meta = {str(r.get("Sr No", "")): r for r in rows}
table = []
for topic, sns in clusters.items():
titles = [meta[str(sn)].get("Title", "")[:80] for sn in sns[:3] if str(sn) in meta]
table.append({"topic_label": topic, "paper_count": len(sns), "representative_titles": titles})
return {"summary_table": table}
def build_comparison_csv(res: dict) -> str:
    """Render a title-vs-abstract comparison result as CSV text.

    Each theme is written on its own row together with where it was found:
    - ``Both`` for ``matched_themes``;
    - ``Title Only`` for ``title_only_themes``;
    - ``Abstract Only`` for ``abstract_only_themes``.

    Missing or null keys contribute no rows. Output starts with a
    ``Theme,Source`` header and uses the csv module's default ``\\r\\n``
    line terminator.
    """
    # Every theme category produced by the comparison is exported; only
    # writing "Both" left the Source column meaningless.
    sources = (
        ("matched_themes", "Both"),
        ("title_only_themes", "Title Only"),
        ("abstract_only_themes", "Abstract Only"),
    )
    out = io.StringIO()
    writer = csv.writer(out)
    writer.writerow(["Theme", "Source"])
    for key, label in sources:
        writer.writerows([theme, label] for theme in res.get(key) or [])
    return out.getvalue()
def build_taxonomy_json(res: dict) -> str:
    """Serialise a taxonomy-mapping result to a pretty-printed JSON string.

    Uses a two-space indent and ``json.dumps`` defaults otherwise (non-ASCII
    characters are escaped).
    """
    indent_width = 2
    rendered = json.dumps(res, indent=indent_width)
    return rendered