acfp / agent_utils.py
Mohammad Haghir
update
2f7b616
import os
from io import BytesIO
from urllib.parse import urlparse

import arxiv
import pandas as pd
import requests
from langchain.tools import tool
from langchain_community.document_loaders import WikipediaLoader
from PIL import Image
from tavily import TavilyClient
@tool
def wiki_ret(question: str) -> dict:
    """Retrieve documents from Wikipedia that match a question.

    Args:
        question: Free-text query handed to the Wikipedia loader.

    Returns:
        A dict with a single ``"context"`` key. Its value is the retrieved
        documents, each wrapped in a ``<Document>`` tag and separated by
        ``---`` lines; it is an empty string when no document is returned.
    """
    print("wiki")
    # load_max_docs=2 limits the search to at most two documents.
    search_docs = WikipediaLoader(query=question, load_max_docs=2).load()
    # Use .get() for both metadata fields so a document that lacks one of
    # them is rendered with an empty attribute instead of raising KeyError.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", "")}" '
        f'page="{doc.metadata.get("page", "")}"/>\n'
        f"{doc.page_content}\n</Document>"
        for doc in search_docs
    )
    return {"context": formatted_search_docs}
@tool
def arxiv_ret(query: str, max_results: int = 3) -> str:
    """Search arXiv for a given query and return a summary of top results.

    Args:
        query: Search expression sent to arXiv.
        max_results: Maximum number of papers to request.

    Returns:
        One text entry per paper (title, authors, summary, URL), with the
        entries separated by blank lines, or "No relevant papers found."
        when the search yields nothing.
    """
    print("arxiv")
    search = arxiv.Search(
        query=query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance,
    )
    results = []
    # Search.results() is deprecated by the arxiv package; results are
    # fetched through a Client instead.
    for paper in arxiv.Client().results(search):
        # Collapse line breaks so each summary stays on a single line.
        summary = paper.summary.replace("\n", " ")
        authors = ", ".join(author.name for author in paper.authors)
        results.append(
            f"Title: {paper.title}\n"
            f"Authors: {authors}\n"
            f"Summary: {summary}\n"
            f"URL: {paper.entry_id}\n"
        )
    if not results:
        return "No relevant papers found."
    return "\n\n".join(results)
@tool
def tavily_ret(query: str, max_results: int = 3) -> str:
    """Use Tavily to retrieve web-based information about a topic.

    Args:
        query: Search query forwarded to Tavily.
        max_results: Maximum number of hits to request.

    Returns:
        One text entry per hit (title, URL, snippet) separated by blank
        lines, "No relevant information found." when the result list is
        empty, or "Tavily API key not found." when the TAVILY_API_KEY
        environment variable is unset or empty.
    """
    print("tavily")
    tavily_key = os.getenv("TAVILY_API_KEY")
    if not tavily_key:
        return "Tavily API key not found."

    response = TavilyClient(api_key=tavily_key).search(
        query=query, search_depth="basic", max_results=max_results
    )
    hits = response["results"]
    if not hits:
        return "No relevant information found."

    return "\n\n".join(
        f"Title: {hit['title']}\nURL: {hit['url']}\nSnippet: {hit['content']}\n"
        for hit in hits
    )
# Upper bound, in seconds, on how long the file download may block.
_DOWNLOAD_TIMEOUT_SECONDS = 30


def _file_extension(url: str) -> str:
    """Return the lower-cased extension of the URL's path, without the dot."""
    # Parsing the URL first keeps the query string, the fragment and dots in
    # the host name out of the extension.
    return os.path.splitext(urlparse(url).path)[1].lstrip(".").lower()


@tool
def handle_file_tool(input: dict) -> str:
    """Read the file attached to a question and describe its content.

    Args:
        input: Mapping with an optional "question" string and a "file"
            entry holding the URL of the attachment.

    Returns:
        A text description that depends on the file extension: a
        placeholder for images and PDFs, a preview of the first rows for
        Excel sheets, the first 500 characters for Python sources, or a
        message explaining why the file could not be handled.
    """
    question = input.get("question", "")
    file_url = input.get("file", "")
    print("file_url: ", file_url)
    if not file_url:
        return "No file provided."

    # A timeout keeps an unresponsive host from blocking the call forever;
    # transport errors are reported the same way as HTTP errors.
    try:
        response = requests.get(file_url, timeout=_DOWNLOAD_TIMEOUT_SECONDS)
    except requests.RequestException:
        return f"Failed to download file: {file_url}"
    if not response.ok:
        return f"Failed to download file: {file_url}"

    file_ext = _file_extension(file_url)

    # IMAGE
    if file_ext in ("png", "jpg", "jpeg"):
        # The image is only opened, not processed; the context manager
        # releases the handle again.
        with Image.open(BytesIO(response.content)):
            pass
        return f"Image file received for question: '{question}'. (Insert vision model here.)"

    # PDF
    if file_ext == "pdf":
        # NOTE(review): the fixed file name in the working directory is kept
        # for compatibility; concurrent calls overwrite each other's file.
        with open("temp.pdf", "wb") as f:
            f.write(response.content)
        return "PDF file received. (Insert PDF parsing logic here.)"

    # EXCEL
    if file_ext in ("xlsx", "xls"):
        df = pd.read_excel(BytesIO(response.content))
        summary = f"Excel data preview:\n{df.head().to_string(index=False)}"
        return f"Question: {question}\n{summary}"

    # PYTHON CODE
    if file_ext == "py":
        # Undecodable bytes are replaced rather than aborting the preview.
        code = response.content.decode("utf-8", errors="replace")
        return f"Question: {question}\nPython code received:\n{code[:500]}..."  # Limit preview

    return f"Unsupported file type: .{file_ext}"
@tool
def add(a: float, b: float):
    """Calculate the sum of two numbers.

    Args:
        a: First addend.
        b: Second addend.

    Returns:
        The value of ``a + b``.
    """
    total = a + b
    return total
@tool
def subtract(a: float, b: float):
    """Calculate the difference of two numbers.

    Args:
        a: Number to subtract from.
        b: Number to subtract.

    Returns:
        The value of ``a - b``.
    """
    difference = a - b
    return difference
@tool
def multiplication(a: float, b: float):
    """Calculate the product of two numbers.

    Args:
        a: First factor.
        b: Second factor.

    Returns:
        The value of ``a * b``.
    """
    product = a * b
    return product
@tool
def division(a: float, b: float) -> float:
    """Calculate the quotient of two numbers.

    Args:
        a: Dividend.
        b: Divisor; must not be zero.

    Returns:
        The value of ``a / b``.

    Raises:
        ZeroDivisionError: If ``b`` is zero.
    """
    # Same exception type as the bare operator, but with a message that
    # names the offending operand.
    if b == 0:
        raise ZeroDivisionError(f"Cannot divide {a} by zero.")
    return a / b
@tool
def mode(a: float, b: float) -> float:
    """Calculate the remainder of dividing one number by another.

    Despite its name, this is the modulo operation ``a % b``, not the
    statistical mode.

    Args:
        a: Dividend.
        b: Divisor; must not be zero.

    Returns:
        The value of ``a % b``.

    Raises:
        ZeroDivisionError: If ``b`` is zero.
    """
    # Same exception type as the bare operator, but with a message that
    # names the offending operand.
    if b == 0:
        raise ZeroDivisionError(f"Cannot compute {a} modulo zero.")
    return a % b