# Author: AheedTahir
# Last change (commit 88cb2f4): changed Llama model due to rate limit
"""
GAIA Agent with Essential Tools for 30%+ Accuracy
Built with LangGraph and Groq LLM
"""
import os
import re
import json
from typing import Annotated
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langchain_groq import ChatGroq
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.checkpoint.memory import MemorySaver
# Initialize LLM
def get_llm(model=None, temperature=0, max_tokens=8000, timeout=60, max_retries=2):
    """
    Create the Groq chat model used by the agent.

    All settings are keyword parameters whose defaults match the previous
    hard-coded values, so ``get_llm()`` behaves as before. The model can be
    swapped (e.g. when a model hits its rate limit) without editing code.

    Args:
        model: Groq model id. When omitted, the GROQ_MODEL environment
            variable is used if set, otherwise "llama-3.1-8b-instant".
        temperature: Sampling temperature passed to ChatGroq.
        max_tokens: Maximum number of tokens generated per response.
        timeout: Request timeout passed to ChatGroq.
        max_retries: Number of retries ChatGroq performs on failed requests.

    Returns:
        A configured ChatGroq instance.
    """
    if model is None:
        # Environment override lets deployments change models without a commit.
        model = os.environ.get("GROQ_MODEL", "llama-3.1-8b-instant")
    return ChatGroq(
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        timeout=timeout,
        max_retries=max_retries,
    )
# ============================================================================
# TOOL DEFINITIONS
# ============================================================================
@tool
def web_search(query: str) -> str:
    """
    Search the web for current information using Tavily.
    Use this for finding recent information, facts, statistics, or any data not in your training.
    Args:
        query: The search query string
    Returns:
        Search results as formatted text
    """
    # The docstring above is what the model sees as this tool's description;
    # implementation notes live in comments instead.
    try:
        tavily = TavilySearchResults(
            max_results=5,
            search_depth="advanced",
            # NOTE(review): invoking with a plain query string yields only the
            # result list, so Tavily's generated answer is not surfaced here.
            include_answer=True,
            include_raw_content=False,
        )
        results = tavily.invoke(query)
        if not results:
            return "No results found."
        # TavilySearchResults reports API failures by returning an error
        # string instead of raising. Iterating that string would fail with a
        # misleading "'str' object has no attribute 'get'", so surface it.
        if isinstance(results, str):
            return f"Error searching web: {results}"
        formatted = []
        for i, result in enumerate(results, 1):
            title = result.get('title', 'No title')
            content = result.get('content', 'No content')
            url = result.get('url', '')
            formatted.append(f"Result {i}:\nTitle: {title}\nContent: {content}\nURL: {url}\n")
        return "\n".join(formatted)
    except Exception as e:
        return f"Error searching web: {str(e)}"
@tool
def wikipedia_search(query: str) -> str:
    """
    Search Wikipedia for encyclopedic information.
    Use this for historical facts, biographies, scientific concepts, etc.
    Args:
        query: The Wikipedia search query
    Returns:
        Wikipedia article content, each article prefixed with its title
    """
    try:
        loader = WikipediaLoader(query=query, load_max_docs=2, doc_content_chars_max=4000)
        docs = loader.load()
        if not docs:
            return f"No Wikipedia article found for '{query}'"
        # Up to two articles are returned. Label each with its title so the
        # model can tell which article a fact came from (e.g. for ambiguous
        # names).
        sections = []
        for doc in docs:
            title = doc.metadata.get("title")
            header = f"Article: {title}\n" if title else ""
            sections.append(f"{header}{doc.page_content}")
        content = "\n\n---\n\n".join(sections)
        return f"Wikipedia results for '{query}':\n\n{content}"
    except Exception as e:
        return f"Error searching Wikipedia: {str(e)}"
@tool
def calculate(expression: str) -> str:
    """
    Evaluate a mathematical expression safely.
    Supports arithmetic: +, -, *, /, //, %, **, parentheses, and comparisons.
    Supports functions: abs, round, min, max, sum, sqrt, pow, log, log10,
    sin, cos, tan, ceil, floor; and the constants pi and e.
    Args:
        expression: Mathematical expression as a string (e.g., "2 + 2", "sqrt(16)", "10 ** 2")
    Returns:
        The calculated result
    """
    import ast
    import math

    # The only names an expression may reference.
    safe_dict = {
        'abs': abs, 'round': round, 'min': min, 'max': max, 'sum': sum,
        'sqrt': math.sqrt, 'pow': pow, 'log': math.log, 'log10': math.log10,
        'sin': math.sin, 'cos': math.cos, 'tan': math.tan,
        'pi': math.pi, 'e': math.e, 'ceil': math.ceil, 'floor': math.floor
    }
    # Syntax permitted in an expression. Attribute access, subscripts, lambdas
    # and comprehensions are deliberately excluded: emptying __builtins__
    # alone does not stop an expression like ().__class__.__bases__[0]...
    # from reaching arbitrary objects, and the input here is model-generated.
    allowed_nodes = (
        ast.Expression, ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare,
        ast.IfExp, ast.Call, ast.keyword, ast.Starred, ast.Constant,
        ast.Name, ast.Load, ast.Tuple, ast.List,
        ast.operator, ast.unaryop, ast.boolop, ast.cmpop,
    )
    try:
        expression = expression.strip()
        tree = ast.parse(expression, mode="eval")
        for node in ast.walk(tree):
            if not isinstance(node, allowed_nodes):
                raise ValueError(f"unsupported syntax: {type(node).__name__}")
            if isinstance(node, ast.Name) and node.id not in safe_dict:
                raise NameError(f"name '{node.id}' is not allowed")
            if isinstance(node, ast.Call) and not isinstance(node.func, ast.Name):
                raise ValueError("only the listed functions may be called")
        # The tree has been validated; evaluate it with no builtins available.
        result = eval(compile(tree, "<calculate>", "eval"), {"__builtins__": {}}, safe_dict)
        return str(result)
    except Exception as e:
        return f"Error calculating '{expression}': {str(e)}"
@tool
def python_executor(code: str) -> str:
    """
    Execute Python code for data processing and calculations.
    Use this for complex calculations, data manipulation, or multi-step computations.
    The code should print its output.
    Pre-loaded names: math, json, datetime (the datetime class), timedelta.
    Imports are limited to: math, json, datetime, re, statistics, collections,
    itertools, functools, fractions, decimal.
    Args:
        code: Python code to execute
    Returns:
        The output of the code execution
    """
    # NOTE(review): a trimmed __builtins__ is not a security boundary --
    # attribute chains (e.g. via __class__/__subclasses__) can still reach
    # arbitrary objects. Only use this where executing model-written code is
    # acceptable.
    import io
    import math
    import json
    from datetime import datetime, timedelta

    buffer = io.StringIO()

    def _print(*args, **kwargs):
        # Write into this call's private buffer instead of swapping the
        # process-wide sys.stdout. Swapping is not thread-safe: overlapping
        # calls can capture each other's output or leave stdout redirected.
        kwargs.setdefault("file", buffer)
        print(*args, **kwargs)

    allowed_modules = frozenset({
        "math", "json", "datetime", "re", "statistics", "collections",
        "itertools", "functools", "fractions", "decimal",
    })

    def _restricted_import(name, globals_=None, locals_=None, fromlist=(), level=0):
        # Called by `import` statements inside the executed code. Model code
        # routinely starts with e.g. `import math`, which previously failed
        # because __import__ was absent from the sandbox builtins.
        if level != 0 or name.partition(".")[0] not in allowed_modules:
            raise ImportError(f"import of '{name}' is not allowed")
        return __import__(name, globals_, locals_, fromlist, level)

    safe_globals = {
        '__builtins__': {
            'print': _print, '__import__': _restricted_import,
            'len': len, 'range': range, 'str': str, 'int': int,
            'float': float, 'bool': bool, 'list': list, 'dict': dict,
            'set': set, 'frozenset': frozenset, 'tuple': tuple,
            'sorted': sorted, 'reversed': reversed, 'sum': sum,
            'min': min, 'max': max, 'abs': abs, 'round': round,
            'pow': pow, 'divmod': divmod,
            'enumerate': enumerate, 'zip': zip, 'map': map, 'filter': filter,
            'any': any, 'all': all, 'isinstance': isinstance,
            'chr': chr, 'ord': ord, 'repr': repr, 'format': format,
            'iter': iter, 'next': next, 'bin': bin, 'hex': hex, 'oct': oct,
            'Exception': Exception, 'ValueError': ValueError,
            'TypeError': TypeError, 'KeyError': KeyError,
            'IndexError': IndexError, 'ZeroDivisionError': ZeroDivisionError,
            'StopIteration': StopIteration,
        },
        'math': math,
        'json': json,
        'datetime': datetime,
        'timedelta': timedelta,
    }
    try:
        exec(code, safe_globals)
    except Exception as e:
        # Keep whatever was printed before the failure; it often shows how
        # far the computation got.
        partial = buffer.getvalue()
        error = f"Error executing code: {type(e).__name__}: {e}"
        return f"{partial}\n{error}" if partial else error
    output = buffer.getvalue()
    return output if output else "Code executed successfully (no output)"
@tool
def read_file(filepath: str) -> str:
    """
    Read and return the contents of a file.
    Supports text files, CSV, JSON, Excel (.xlsx/.xls), and basic file formats.
    Args:
        filepath: Path to the file to read
    Returns:
        File contents as string
    """
    try:
        if not os.path.exists(filepath):
            return f"File not found: {filepath}"
        # Match extensions case-insensitively so e.g. "DATA.CSV" or
        # "report.Json" are parsed rather than dumped as raw text.
        lowered = filepath.lower()
        if lowered.endswith('.json'):
            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)
            return json.dumps(data, indent=2)
        elif lowered.endswith('.csv'):
            try:
                import pandas as pd
                df = pd.read_csv(filepath)
                return f"CSV file with {len(df)} rows and {len(df.columns)} columns:\n\n{df.to_string()}"
            except ImportError:
                # Fallback if pandas not available
                with open(filepath, 'r', encoding='utf-8') as f:
                    return f.read()
        elif lowered.endswith(('.xlsx', '.xls')):
            # Excel is binary, so a text read would only yield a decode
            # error. pandas raises ImportError when it or the Excel engine
            # it needs is missing.
            try:
                import pandas as pd
                sheets = pd.read_excel(filepath, sheet_name=None)
            except ImportError as e:
                return f"Error reading file '{filepath}': Excel support requires pandas and an Excel engine ({e})"
            # sheet_name=None returns every sheet as {sheet name: DataFrame}.
            parts = [
                f"Sheet '{name}' with {len(df)} rows and {len(df.columns)} columns:\n\n{df.to_string()}"
                for name, df in sheets.items()
            ]
            return "\n\n".join(parts)
        else:
            # Anything else is read as UTF-8 text.
            with open(filepath, 'r', encoding='utf-8') as f:
                return f.read()
    except Exception as e:
        return f"Error reading file '{filepath}': {str(e)}"
# ============================================================================
# SYSTEM PROMPT - GAIA Specific Instructions
# ============================================================================
# System prompt prepended by the assistant node. Rule 1's examples are worded
# so that no example appears as both CORRECT and WRONG without an explicit
# condition; a small model tends to follow examples over prose.
GAIA_SYSTEM_PROMPT = """You are a helpful AI assistant designed to answer questions from the GAIA benchmark.
CRITICAL ANSWER FORMAT RULES:
1. For numbers: digits only, NO thousands separators, NO units unless the question explicitly asks for them
   - CORRECT: "1000"
   - CORRECT (only when the question asks for units): "1000 meters"
   - WRONG: "1,000", or "1000 meters" when units were not requested
2. For text answers: No articles (a, an, the), no abbreviations
   - CORRECT: "United States"
   - WRONG: "The United States" or "USA"
3. For lists: Comma-separated with one space after each comma
   - CORRECT: "apple, banana, orange"
   - WRONG: "apple,banana,orange" or "apple, banana, orange."
4. For dates: Use the format specified in the question
   - If not specified, use ISO format: YYYY-MM-DD
5. Be precise and concise - answer ONLY what is asked
APPROACH:
1. Read the question carefully and identify what information is needed
2. Use tools to gather information (web search, Wikipedia, calculations)
3. For multi-step questions, break down the problem and solve step by step
4. Verify your answer matches the format requirements above
5. Return ONLY the final answer in the correct format
AVAILABLE TOOLS:
- web_search: Search the internet for current information
- wikipedia_search: Search Wikipedia for encyclopedic knowledge
- calculate: Perform mathematical calculations
- python_executor: Execute Python code for complex computations
- read_file: Read files (CSV, JSON, text)
Remember: Your final response should be ONLY the answer in the correct format, nothing else.
"""
# ============================================================================
# AGENT GRAPH CONSTRUCTION
# ============================================================================
def build_graph():
    """
    Assemble and compile the GAIA agent graph.

    Nodes:
        "assistant" -- the Groq LLM (from get_llm) with the five tools bound.
        "tools"     -- a ToolNode that executes those same tools.

    Flow: START -> "assistant". After each assistant turn, LangGraph's
    prebuilt ``tools_condition`` router chooses the next step. "tools"
    always returns to "assistant".

    Returns:
        The compiled graph with an in-memory MemorySaver checkpointer.
        Callers pass a ``thread_id`` under ``config["configurable"]``, as the
        __main__ block does.
    """
    agent_tools = [web_search, wikipedia_search, calculate, python_executor, read_file]
    model = get_llm().bind_tools(agent_tools)

    def assistant(state: MessagesState):
        """Run one LLM turn, prefixing the GAIA system prompt if the history lacks one."""
        history = state["messages"]
        has_system_prompt = any(isinstance(m, SystemMessage) for m in history)
        prompt = history if has_system_prompt else [SystemMessage(content=GAIA_SYSTEM_PROMPT), *history]
        return {"messages": [model.invoke(prompt)]}

    graph_builder = StateGraph(MessagesState)
    graph_builder.add_node("assistant", assistant)
    graph_builder.add_node("tools", ToolNode(agent_tools))

    graph_builder.add_edge(START, "assistant")
    graph_builder.add_conditional_edges("assistant", tools_condition)
    graph_builder.add_edge("tools", "assistant")

    return graph_builder.compile(checkpointer=MemorySaver())
# ============================================================================
# TESTING
# ============================================================================
if __name__ == "__main__":
    # Smoke test: build the agent and run a few sample questions, each in
    # its own checkpointer thread.
    from langchain_core.messages import HumanMessage

    print("Building agent...")
    agent = build_graph()

    sample_questions = (
        "What is 25 * 4 + 100?",
        "Who was the first president of the United States?",
        "Search for the population of Tokyo in 2024",
    )
    separator = "=" * 60

    for number, question in enumerate(sample_questions, start=1):
        print("\n" + separator)
        print(f"Test {number}: {question}")
        print(separator)
        try:
            run_config = {"configurable": {"thread_id": f"test_{number}"}}
            outcome = agent.invoke({"messages": [HumanMessage(content=question)]}, config=run_config)
            final_answer = outcome['messages'][-1].content
            print(f"Answer: {final_answer}")
        except Exception as err:
            print(f"Error: {err}")