# gaia-agent / tools.py
# Last change by mrtom17: "Update tools.py" (commit 2423b7e, verified).
from pydantic import BaseModel, Field
from typing import Optional
import math
import requests
from langchain_core.tools import tool
import os
# --- Calculator Tool ---
class CalculatorInput(BaseModel):
    # Argument schema for calculator_tool: a single arithmetic expression.
    expression: str = Field(
        default=...,
        description="A mathematical expression to evaluate, e.g. '2 + 2 * 3'.",
    )
@tool(args_schema=CalculatorInput, return_direct=True)
def calculator_tool(expression: str) -> str:
    """Evaluate a mathematical expression, e.g. '2 + 2 * 3'."""
    # Returns str(value) on success, or "Error: <message>" on any failure.
    #
    # SECURITY: this still evaluates model-generated text with eval().
    # Emptying __builtins__ alone is not a sandbox (the classic escape is
    # ().__class__.__base__.__subclasses__()), so the expression is parsed
    # first and any name or attribute starting with "_" is rejected. Huge
    # computations such as 9**9**9 can still hang the process. Replace this
    # with a dedicated expression parser before exposing it to untrusted users.
    import ast

    # Previously only `math.<name>` resolved: bare calls like sqrt(2) or
    # abs(-3) failed because builtins were set to None. Expose every public
    # math name, a few pure numeric builtins, and `math` itself so existing
    # "math.sqrt(2)"-style inputs keep working.
    namespace = {name: getattr(math, name) for name in dir(math) if not name.startswith("_")}
    namespace.update(abs=abs, round=round, min=min, max=max, pow=pow, sum=sum, int=int, float=float)
    namespace["math"] = math
    namespace["__builtins__"] = {}
    try:
        # eval() itself tolerates leading blanks; ast.parse does not.
        tree = ast.parse(expression.strip(), mode="eval")
        for node in ast.walk(tree):
            if (isinstance(node, ast.Name) and node.id.startswith("_")) or (
                isinstance(node, ast.Attribute) and node.attr.startswith("_")
            ):
                raise ValueError("names and attributes starting with '_' are not allowed")
        # The namespace is passed as globals so lambdas inside the expression
        # can see the math names too.
        result = eval(compile(tree, "<calculator>", "eval"), namespace)
        return str(result)
    except Exception as e:
        return f"Error: {e}"
# --- Wikipedia Search Tool ---
class WikipediaSearchInput(BaseModel):
    # Argument schema for wikipedia_search_tool.
    query: str = Field(default=..., description="The search query for Wikipedia.")
    # Summary length; defaults to three sentences.
    sentences: Optional[int] = Field(
        default=3,
        description="Number of sentences to return from the summary.",
    )
# We'll use the wikipedia library for this tool
# Optional dependency: wikipedia_search_tool checks for None and returns an
# install hint instead of failing when the package is absent.
wikipedia = None
try:
    import wikipedia
except ImportError:
    pass
@tool(args_schema=WikipediaSearchInput, return_direct=True)
def wikipedia_search_tool(query: str, sentences: int = 3) -> str:
    """Search Wikipedia for a summary of a topic."""
    # Every failure (missing package, lookup error) is reported as a string
    # so the agent receives a readable message rather than an exception.
    if wikipedia is not None:
        try:
            return wikipedia.summary(query, sentences=sentences)
        except Exception as err:
            return f"Wikipedia search error: {err}"
    return "Wikipedia library not installed. Please install it with 'pip install wikipedia'."
# --- Python Interpreter Tool ---
class PythonInterpreterInput(BaseModel):
    # Argument schema for python_interpreter_tool: the source text to run.
    code: str = Field(
        default=...,
        description="Python code to execute. Should print or return the answer.",
    )
# Builtins made available to code run by python_interpreter_tool: enough for
# ordinary computation, printing, and defining functions and classes.
_INTERPRETER_SAFE_BUILTINS = (
    "abs", "all", "any", "ascii", "bin", "bool", "bytes", "callable", "chr",
    "classmethod", "complex", "dict", "divmod", "enumerate", "filter", "float",
    "format", "frozenset", "hasattr", "hash", "hex", "int", "isinstance",
    "issubclass", "iter", "len", "list", "map", "max", "min", "next", "object",
    "oct", "ord", "pow", "print", "property", "range", "repr", "reversed",
    "round", "set", "slice", "sorted", "staticmethod", "str", "sum", "super",
    "tuple", "zip", "__build_class__",
    "ArithmeticError", "AssertionError", "AttributeError", "Exception",
    "IndexError", "KeyError", "LookupError", "NameError", "OverflowError",
    "RuntimeError", "StopIteration", "TypeError", "ValueError",
    "ZeroDivisionError",
)
# Top-level standard-library packages that executed code may import.
_INTERPRETER_ALLOWED_MODULES = frozenset({
    "bisect", "calendar", "cmath", "collections", "datetime", "decimal",
    "fractions", "functools", "heapq", "itertools", "json", "math",
    "operator", "random", "re", "statistics", "string",
})


@tool(args_schema=PythonInterpreterInput, return_direct=True)
def python_interpreter_tool(code: str) -> str:
    """Execute Python code and return the result. Use variable 'result' or print output."""
    # Returns str(result) if the code binds a top-level name `result`;
    # otherwise its captured stdout (stripped), or "(No output)". Any
    # exception is returned as "Python execution error: ...".
    #
    # The previous version passed an empty builtins dict, so print, range,
    # len and every import failed. It also used separate globals/locals dicts,
    # so a function in the code could not call another top-level function
    # from the same code. Both problems are fixed below.
    #
    # SECURITY: a restricted builtins dict is NOT a sandbox; introspection such
    # as ().__class__.__base__.__subclasses__() can still reach arbitrary code.
    # Only run code you would be willing to run on this host.
    import builtins
    import contextlib
    import io

    def restricted_import(name, globals_=None, locals_=None, fromlist=(), level=0):
        # Allow only absolute imports of whitelisted top-level packages.
        if level != 0 or name.partition(".")[0] not in _INTERPRETER_ALLOWED_MODULES:
            raise ImportError(f"import of '{name}' is not allowed")
        return builtins.__import__(name, globals_, locals_, fromlist, level)

    safe_builtins = {name: getattr(builtins, name) for name in _INTERPRETER_SAFE_BUILTINS}
    safe_builtins["__import__"] = restricted_import
    # A single dict serves as both globals and locals, so the code runs with
    # module-level scoping rules. __name__ is required by class statements.
    namespace = {"__builtins__": safe_builtins, "__name__": "__main__"}
    output = io.StringIO()
    try:
        with contextlib.redirect_stdout(output):
            exec(code, namespace)
        if "result" in namespace:
            return str(namespace["result"])
        result_output = output.getvalue().strip()
        return result_output if result_output else "(No output)"
    except Exception as e:
        return f"Python execution error: {e}"
# --- Unit Conversion Tool ---
class UnitConversionInput(BaseModel):
    # Argument schema for unit_conversion_tool: a number plus source and
    # target unit names.
    value: float = Field(default=..., description="The numeric value to convert.")
    from_unit: str = Field(
        default=...,
        description="The unit to convert from, e.g. 'meters'.",
    )
    to_unit: str = Field(
        default=...,
        description="The unit to convert to, e.g. 'feet'.",
    )
# Simple conversion table for demonstration
def _celsius_to_fahrenheit(c):
    # Affine conversion; cannot be expressed as a single multiplier.
    return c * 9/5 + 32


def _fahrenheit_to_celsius(f):
    # Inverse of _celsius_to_fahrenheit.
    return (f - 32) * 5/9


# All keys are lower-case (from_unit, to_unit) pairs. A numeric value is a
# multiplier; a callable value is applied to the input directly.
CONVERSION_FACTORS = {
    # length
    ("meters", "feet"): 3.28084,
    ("feet", "meters"): 0.3048,
    # mass
    ("kilograms", "pounds"): 2.20462,
    ("pounds", "kilograms"): 0.453592,
    # temperature
    ("celsius", "fahrenheit"): _celsius_to_fahrenheit,
    ("fahrenheit", "celsius"): _fahrenheit_to_celsius,
}
@tool(args_schema=UnitConversionInput, return_direct=True)
def unit_conversion_tool(value: float, from_unit: str, to_unit: str) -> str:
    """Convert between units (e.g., meters to feet, celsius to fahrenheit)."""
    # Returns "<value> <from_unit> = <result> <to_unit>" (units echoed exactly
    # as given), or a "not supported" message for pairs that are absent from
    # CONVERSION_FACTORS.
    #
    # Common abbreviations and singular forms are mapped onto the lower-case
    # names used as CONVERSION_FACTORS keys. Any other name passes through
    # unchanged, so every pair that worked before still works.
    aliases = {
        "m": "meters", "meter": "meters", "metre": "meters", "metres": "meters",
        "ft": "feet", "foot": "feet",
        "kg": "kilograms", "kgs": "kilograms", "kilogram": "kilograms",
        "lb": "pounds", "lbs": "pounds", "pound": "pounds",
        "c": "celsius", "°c": "celsius", "degc": "celsius",
        "f": "fahrenheit", "°f": "fahrenheit", "degf": "fahrenheit",
    }

    def canonical(unit: str) -> str:
        # Normalise case and surrounding whitespace, then resolve aliases.
        name = unit.strip().lower()
        return aliases.get(name, name)

    key = (canonical(from_unit), canonical(to_unit))
    # Only a missing table entry means "unsupported"; the old blanket
    # `except Exception` also hid genuine bugs behind that message.
    try:
        factor = CONVERSION_FACTORS[key]
    except KeyError:
        return f"Conversion from {from_unit} to {to_unit} not supported."
    # Callables cover affine scales (temperatures); numbers are multipliers.
    result = factor(value) if callable(factor) else value * factor
    return f"{value} {from_unit} = {result} {to_unit}"
# --- Date/Time Calculation Tool ---
from datetime import datetime, timedelta
class DateTimeCalcInput(BaseModel):
    # Argument schema for date_time_calc_tool. base_date is required, but an
    # empty string is accepted there and treated as "now".
    base_date: str = Field(
        default=...,
        description="The starting date in YYYY-MM-DD format. If blank, use today.",
    )
    delta_days: int = Field(
        default=...,
        description="Number of days to add (positive) or subtract (negative).",
    )
@tool(args_schema=DateTimeCalcInput, return_direct=True)
def date_time_calc_tool(base_date: str, delta_days: int) -> str:
    """Add or subtract days from a date (YYYY-MM-DD)."""
    # An empty base_date starts from datetime.now() (local, naive time).
    # Parse or arithmetic failures are returned as an error string.
    try:
        if not base_date:
            start = datetime.now()
        else:
            start = datetime.strptime(base_date, "%Y-%m-%d")
        return (start + timedelta(days=delta_days)).strftime("%Y-%m-%d")
    except Exception as err:
        return f"Date calculation error: {err}"
# --- Tavily Search Tool ---
# Optional dependency: tavily_search_tool checks for None and returns an
# install hint instead of failing when the package is absent.
TavilyClient = None
try:
    from tavily import TavilyClient
except ImportError:
    pass
class TavilySearchInput(BaseModel):
    # Argument schema for tavily_search_tool.
    query: str = Field(default=..., description="The search query to look up on the web.")
    num_results: int = Field(default=3, description="Number of results to return.")
@tool(args_schema=TavilySearchInput, return_direct=True)
def tavily_search_tool(query: str, num_results: int = 3) -> str:
    """Search the web for up-to-date information using Tavily API (official client)."""
    # Configuration problems are reported as plain strings so the agent can
    # surface them instead of crashing.
    api_key = os.getenv("TAVILY_API_KEY")
    if not api_key:
        return "Tavily API key not set. Please set TAVILY_API_KEY in your environment."
    if TavilyClient is None:
        return "Tavily Python client not installed. Please install it with 'pip install tavily'."
    try:
        response = TavilyClient(api_key=api_key).search(query, max_results=num_results)
        # Non-dict responses and dicts without usable fields fall back to str().
        if not isinstance(response, dict):
            return str(response)
        # Prefer a non-empty "answer" field when the response carries one.
        if response.get("answer"):
            return response["answer"]
        if not response.get("results"):
            return str(response)
        # Each result's text is read from "content" (not "snippet").
        texts = [item.get("content", "") for item in response["results"][:num_results]]
        snippets = [text for text in texts if text.strip()]
        if snippets:
            return "\n".join(snippets)
        return str(response)
    except Exception as err:
        return f"Tavily search error: {err}"
# --- Audio Transcription Tool ---
class AudioTranscriptionInput(BaseModel):
    # Argument schema for audio_transcription_tool.
    file_path: str = Field(
        default=...,
        description="Path to the audio file to transcribe.",
    )
@tool(args_schema=AudioTranscriptionInput, return_direct=True)
def audio_transcription_tool(file_path: str) -> str:
    """Transcribe an audio file using OpenAI's new API (>=1.0.0, gpt-4o-transcribe)."""
    # openai is imported on first use. A missing package, like any other
    # failure inside the try, comes back as an "Audio transcription error: ..."
    # string rather than an exception.
    try:
        import openai

        client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        with open(file_path, "rb") as audio_file:
            # response_format="text" requests the transcript as text.
            return client.audio.transcriptions.create(
                model="gpt-4o-transcribe",
                file=audio_file,
                response_format="text",
            )
    except Exception as e:
        return f"Audio transcription error: {e}"
# --- Image Captioning Tool ---
class ImageCaptioningInput(BaseModel):
    # Argument schema for image_captioning_tool.
    file_path: str = Field(
        default=...,
        description="Path to the image file to caption.",
    )
import functools


@functools.lru_cache(maxsize=1)
def _load_blip_captioner():
    """Load the BLIP captioning processor and model once per process.

    Previously both were re-loaded on every tool call. lru_cache does not
    cache exceptions, so a failed load is retried on the next call.
    """
    from transformers import BlipProcessor, BlipForConditionalGeneration

    checkpoint = "Salesforce/blip-image-captioning-base"
    processor = BlipProcessor.from_pretrained(checkpoint)
    model = BlipForConditionalGeneration.from_pretrained(checkpoint)
    return processor, model


@tool(args_schema=ImageCaptioningInput, return_direct=True)
def image_captioning_tool(file_path: str) -> str:
    """Generate a caption for an image using BLIP from transformers (requires transformers and torch)."""
    # Returns the decoded caption, or "Image captioning error: ..." on any
    # failure (missing packages, unreadable file, model errors).
    try:
        from PIL import Image
        import torch

        processor, model = _load_blip_captioner()
        # Close the source file; convert() returns an independent in-memory image.
        with Image.open(file_path) as raw_image:
            image = raw_image.convert("RGB")
        inputs = processor(image, return_tensors="pt")
        with torch.no_grad():
            out = model.generate(**inputs)
        return processor.decode(out[0], skip_special_tokens=True)
    except Exception as e:
        return f"Image captioning error: {e}"
# --- Python File Reader Tool ---
class PythonFileReaderInput(BaseModel):
    # Argument schema for python_file_reader_tool. max_lines=None reads the
    # whole file.
    file_path: str = Field(default=..., description="Path to the Python file to read.")
    max_lines: Optional[int] = Field(
        default=None,
        description="Maximum number of lines to read from the file.",
    )
@tool(args_schema=PythonFileReaderInput, return_direct=True)
def python_file_reader_tool(file_path: str, max_lines: Optional[int] = None) -> str:
    """Read and return the content of a Python file (optionally limited to max_lines)."""
    # Returns the file text (UTF-8), at most max_lines lines of it when a limit
    # is given, or "Python file read error: ..." on failure.
    #
    # islice stops cleanly at end of file. The previous
    # [next(f) for _ in range(max_lines)] raised StopIteration whenever the
    # file was shorter than max_lines, which surfaced as an error with an
    # empty message instead of the file's contents.
    from itertools import islice

    try:
        with open(file_path, "r", encoding="utf-8") as f:
            if max_lines is None:
                return f.read()
            # Non-positive limits yield "" (as the old range()-based code did);
            # islice itself rejects negative counts.
            return "".join(islice(f, max(max_lines, 0)))
    except Exception as e:
        return f"Python file read error: {e}"
# --- Data Analysis Tool ---
class DataAnalysisInput(BaseModel):
    # Argument schema for data_analysis_tool.
    file_path: str = Field(
        default=...,
        description="Path to the Excel or CSV file to analyze.",
    )
    instruction: str = Field(
        default=...,
        description="Analysis instruction, e.g. 'summary', 'head', 'describe', or a column name.",
    )
@tool(args_schema=DataAnalysisInput, return_direct=True)
def data_analysis_tool(file_path: str, instruction: str) -> str:
    """Analyze an Excel or CSV file using pandas. Instruction can be 'summary', 'head', 'describe', or a column name."""
    # Returns the requested view as text. A missing file, unsupported
    # extension, unknown instruction, or any exception while loading or
    # analysing is returned as a plain-text message. pandas itself is imported
    # outside the try, so a missing pandas install still raises ImportError.
    import io
    import pandas as pd

    try:
        if not os.path.exists(file_path):
            return f"File not found: {file_path}"
        # Case-insensitive extension check, so "REPORT.CSV" is accepted too.
        lowered_path = file_path.lower()
        if lowered_path.endswith(".csv"):
            df = pd.read_csv(file_path)
        elif lowered_path.endswith((".xlsx", ".xls")):
            df = pd.read_excel(file_path)
        else:
            return "Unsupported file type. Only .csv, .xlsx, and .xls are supported."

        command = instruction.strip().lower()
        if command == "summary":
            # DataFrame.info() prints and returns None, so the old
            # str(df.info()) always returned "None". Capture the report instead.
            buffer = io.StringIO()
            df.info(buf=buffer)
            return buffer.getvalue()
        if command == "head":
            return df.head().to_string()
        if command == "describe":
            return df.describe().to_string()
        # Match a column exactly first, then with the instruction's
        # surrounding whitespace removed.
        for column in (instruction, instruction.strip()):
            if column in df.columns:
                return df[column].to_string()
        return f"Unknown instruction or column: {instruction}"
    except Exception as e:
        return f"Data analysis error: {e}"
# --- Tool List for LangGraph/LangChain ---
# All tool objects produced by the @tool-decorated functions defined above in
# this module, in a fixed order. The section header says this list is meant
# for LangGraph/LangChain; the consuming agent is not visible here — confirm
# against the caller before relying on the ordering.
TOOLS = [
    calculator_tool,
    tavily_search_tool,
    wikipedia_search_tool,
    python_interpreter_tool,
    unit_conversion_tool,
    date_time_calc_tool,
    audio_transcription_tool,
    image_captioning_tool,
    python_file_reader_tool,
    data_analysis_tool,
]