# jarvis / agent_enhanced.py
# jebaponselvasingh
# first commit
# 0b90c85
"""
Enhanced GAIA Agent with LangGraph
Separate module for cleaner architecture and easier customization
"""
import os
import re
import json
import requests
import tempfile
from typing import TypedDict, Annotated, Sequence, Literal, Any
import operator
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langchain_community.tools import DuckDuckGoSearchResults
from langchain_experimental.utilities import PythonREPL
import pandas as pd
# ============ STATE DEFINITION ============
class AgentState(TypedDict):
    """State maintained throughout the agent's execution.

    One instance of this mapping is created by ``GAIAAgent.run`` and passed
    through every node of the graph; nodes return partial dicts containing
    only the keys they change.
    """
    # Conversation history. Annotated with operator.add, which LangGraph uses
    # as the reducer for this key: the list a node returns is concatenated
    # onto the existing history instead of replacing it.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Identifier supplied by the caller of run(); not read by any node in this file.
    task_id: str
    # Path of the file attached to the question, or None when there is none.
    file_path: str | None
    # Initialized to None by run(); no node in this file writes it.
    file_content: str | None
    # Number of model calls made so far; incremented by the agent node.
    iteration_count: int
    # Cleaned answer, written by the extract_answer node; None until then.
    final_answer: str | None
# ============ TOOL DEFINITIONS ============
@tool
def web_search(query: str) -> str:
    """
    Search the web using DuckDuckGo for current information.
    Use this for questions about recent events, facts, statistics, or any information
    that might have changed or that you're uncertain about.

    Args:
        query: The search query string

    Returns:
        Search results with relevant snippets
    """
    # NOTE: the docstring above is what the @tool decorator exposes to the
    # model as the tool description, so its wording is runtime behavior.
    #
    # Everything stays inside the try: a tool must hand the model a string it
    # can react to rather than raise into the graph.
    try:
        search = DuckDuckGoSearchResults(max_results=5, output_format="list")
        results = search.run(query)

        # An empty result used to come back as "", which gives the model no
        # hint that it should rephrase the query.
        if not results:
            return f"No search results found for '{query}'. Try a different query or approach."

        if not isinstance(results, list):
            return str(results)

        formatted = [
            (
                f"Title: {r.get('title', 'N/A')}\n"
                f"Snippet: {r.get('snippet', 'N/A')}\n"
                f"Link: {r.get('link', 'N/A')}"
            )
            if isinstance(r, dict)
            else str(r)
            for r in results
        ]
        return "\n\n---\n\n".join(formatted)
    except Exception as e:
        return f"Search failed: {str(e)}. Try a different query or approach."
@tool
def python_executor(code: str) -> str:
    """
    Execute Python code for calculations, data analysis, or any computational task.
    You have access to standard libraries: math, statistics, datetime, json, re, collections.

    Args:
        code: Python code to execute. Print statements will show in output.

    Returns:
        The output/result of the code execution
    """
    # SECURITY: this runs model-generated code in the current process with the
    # privileges of the agent. Only use it where that is acceptable (e.g. an
    # isolated container); nothing here sandboxes the code.

    # Imports promised by the docstring, prepended to the submitted code.
    # Built with join() so the lines always start at column 0 regardless of
    # how this function is indented.
    preamble = "\n".join((
        "",
        "import math",
        "import statistics",
        "import datetime",
        "import json",
        "import re",
        "from collections import Counter, defaultdict",
        "",
    ))

    try:
        output = PythonREPL().run(preamble + code)
    except Exception as e:
        return f"Execution error: {str(e)}. Please fix the code and try again."

    # Strip *before* testing for emptiness: output that is only whitespace
    # (e.g. a bare print()) is as uninformative as no output at all.
    output = output.strip() if output else ""
    return output or "Code executed successfully with no output. Add print() to see results."
@tool
def read_file(file_path: str) -> str:
    """
    Read and extract content from various file types.
    Supports: PDF, TXT, MD, CSV, JSON, XLSX, XLS, PY, and other text files.

    Args:
        file_path: Path to the file to read

    Returns:
        The content of the file as a string
    """
    # Any failure is reported as a string so the model can react to it;
    # this tool never raises.
    try:
        if not os.path.exists(file_path):
            return f"Error: File not found at {file_path}"

        # The reader is chosen from the (case-insensitive) file extension.
        file_lower = file_path.lower()

        if file_lower.endswith('.pdf'):
            # Imported here so the PDF loader is only required when a PDF is read.
            from langchain_community.document_loaders import PyPDFLoader
            pages = PyPDFLoader(file_path).load()
            content = "\n\n--- Page Break ---\n\n".join(p.page_content for p in pages)
            return f"PDF Content ({len(pages)} pages):\n{content}"

        if file_lower.endswith(('.xlsx', '.xls')):
            # sheet_name=None makes pandas return {sheet name: DataFrame} for all sheets.
            sheets = pd.read_excel(file_path, sheet_name=None)
            return "\n\n".join(
                f"=== Sheet: {sheet_name} ===\n{sheet_df.to_string()}"
                for sheet_name, sheet_df in sheets.items()
            )

        if file_lower.endswith('.csv'):
            df = pd.read_csv(file_path)
            return f"CSV Data ({len(df)} rows):\n{df.to_string()}"

        if file_lower.endswith('.json'):
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # ensure_ascii=False keeps non-ASCII text readable ("é", not
            # "\u00e9"); answers are scored by exact string match, so names
            # must reach the model exactly as written in the file.
            return f"JSON Content:\n{json.dumps(data, indent=2, ensure_ascii=False)}"

        # Default: treat as text; undecodable bytes are dropped rather than failing.
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
            content = f.read()
        return f"File Content:\n{content}"
    except Exception as e:
        return f"Error reading file: {str(e)}"
@tool
def calculator(expression: str) -> str:
    """
    Evaluate a mathematical expression safely.
    Supports: basic arithmetic, trigonometry, logarithms, exponents, etc.

    Args:
        expression: Mathematical expression (e.g., "sqrt(16) + log(100, 10)")

    Returns:
        The numerical result as a string
    """
    import math

    # The only names an expression may reference.
    safe_dict = {
        'abs': abs, 'round': round, 'min': min, 'max': max,
        'sum': sum, 'pow': pow, 'len': len,
        'sqrt': math.sqrt, 'log': math.log, 'log10': math.log10,
        'log2': math.log2, 'exp': math.exp,
        'sin': math.sin, 'cos': math.cos, 'tan': math.tan,
        'asin': math.asin, 'acos': math.acos, 'atan': math.atan,
        'sinh': math.sinh, 'cosh': math.cosh, 'tanh': math.tanh,
        'ceil': math.ceil, 'floor': math.floor,
        'pi': math.pi, 'e': math.e,
        'factorial': math.factorial, 'gcd': math.gcd,
        'degrees': math.degrees, 'radians': math.radians,
    }

    # The whole body stays inside the try (formatting a very large int can
    # itself raise) so the tool always returns a string.
    try:
        # SECURITY: eval() with an empty __builtins__ is NOT a sandbox --
        # attribute chains such as ().__class__.__base__ can still reach
        # arbitrary objects. Rejecting double underscores closes that route,
        # but expressions can still consume unbounded CPU or memory
        # (e.g. 9**9**9). Do not expose this tool to untrusted callers
        # without replacing eval() by a real expression parser.
        if "__" in expression:
            return (
                "Calculation error: double-underscore names are not allowed. "
                "Check your expression syntax."
            )

        result = eval(expression, {"__builtins__": {}}, safe_dict)

        if isinstance(result, float):
            # Whole-valued floats print as integers ("4", not "4.0").
            if result.is_integer():
                return str(int(result))
            # 10 significant digits: hides binary float noise such as
            # 0.30000000000000004 and drops trailing zeros.
            return f"{result:.10g}"
        return str(result)
    except Exception as e:
        return f"Calculation error: {str(e)}. Check your expression syntax."
@tool
def wikipedia_search(query: str) -> str:
    """
    Search Wikipedia for factual information about a specific topic.
    Best for: historical facts, biographies, scientific concepts, definitions.

    Args:
        query: The topic to search for on Wikipedia

    Returns:
        Summary and key information from relevant Wikipedia articles
    """
    api_url = "https://en.wikipedia.org/w/api.php"
    # Wikimedia asks API clients to identify themselves; requests carrying a
    # generic library User-Agent may be throttled or rejected.
    headers = {"User-Agent": "gaia-agent/1.0 (LangGraph research agent; python-requests)"}

    # Any failure is reported as a string so the model can fall back to
    # another tool; this tool never raises.
    try:
        # Step 1: full-text search for candidate articles. Passing `params`
        # lets requests handle URL encoding of the query.
        search_response = requests.get(
            api_url,
            params={
                "action": "query",
                "list": "search",
                "srsearch": query,
                "format": "json",
                "srlimit": 3,
            },
            headers=headers,
            timeout=10,
        )
        # Surface HTTP errors (403, 429, 5xx) as such instead of as a
        # confusing JSON decoding failure on an HTML error page.
        search_response.raise_for_status()
        hits = search_response.json().get("query", {}).get("search") or []
        if not hits:
            return f"No Wikipedia articles found for '{query}'"

        # Step 2: fetch the plain-text introduction of the top hit.
        top_title = hits[0]["title"]
        content_response = requests.get(
            api_url,
            params={
                "action": "query",
                "prop": "extracts",
                "exintro": "true",
                "explaintext": "true",
                "titles": top_title,
                "format": "json",
            },
            headers=headers,
            timeout=10,
        )
        content_response.raise_for_status()
        pages = content_response.json().get("query", {}).get("pages", {})

        for page_id, page_data in pages.items():
            # The API reports a missing page under the id "-1".
            if page_id != '-1':
                title = page_data.get('title', '')
                extract = page_data.get('extract', 'No content available')
                # Cap the extract so one article cannot flood the context.
                return f"Wikipedia: {title}\n\n{extract[:2000]}"
        return "Could not retrieve article content."
    except Exception as e:
        return f"Wikipedia search failed: {str(e)}"
@tool
def analyze_image(image_path: str, question: str) -> str:
    """
    Analyze an image file and answer questions about it.
    Note: This is a placeholder - implement with vision model if needed.

    Args:
        image_path: Path to the image file
        question: What to analyze or find in the image

    Returns:
        Description or analysis of the image
    """
    # Stub: no image is opened or inspected. The reply only echoes the two
    # arguments so a caller can see that image analysis is unavailable.
    # A vision-capable model could be wired in here.
    placeholder_reply = f"Image analysis not implemented. File: {image_path}, Question: {question}"
    return placeholder_reply
# Collect all tools: this list is both bound to the model (so it can request
# them) and handed to the graph's ToolNode (so the requests get executed).
# analyze_image, defined above, is not included; it is still a placeholder.
TOOLS = [web_search, python_executor, read_file, calculator, wikipedia_search]
# ============ SYSTEM PROMPT ============
# System prompt sent as the first message of every run (see GAIAAgent.run).
# It describes the answer-format rules for exact-match scoring and lists the
# tools by the same names used in TOOLS. The text is runtime behavior: any
# edit changes what the model is told.
SYSTEM_PROMPT = """You are an expert AI assistant designed to solve GAIA benchmark questions with maximum accuracy.
## Your Mission
Provide PRECISE, EXACT answers. The benchmark uses EXACT STRING MATCHING, so your final answer must match the ground truth character-for-character.
## Critical Answer Formatting Rules (MUST FOLLOW)
**DO NOT include "FINAL ANSWER:" or any prefix - just the answer itself.**
1. **Numbers**: Give just the number.
- ✅ CORRECT: "42"
- ❌ WRONG: "The answer is 42", "42 units", "Answer: 42"
2. **Names**: Exact spelling as found in sources. Check Wikipedia/official sources for correct spelling, capitalization, and punctuation.
- ✅ CORRECT: "John Smith"
- ❌ WRONG: "john smith", "John smith"
3. **Lists**: Comma-separated, NO spaces after commas.
- ✅ CORRECT: "apple,banana,cherry"
- ❌ WRONG: "apple, banana, cherry", "apple,banana, cherry"
4. **Dates**: Use the format specified in the question, or YYYY-MM-DD if not specified.
- ✅ CORRECT: "2024-01-15" or "January 15, 2024" (if question asks for that format)
- ❌ WRONG: "1/15/2024" (unless question asks for it)
5. **Yes/No**: Just "Yes" or "No" (capitalized, no period).
- ✅ CORRECT: "Yes"
- ❌ WRONG: "yes", "Yes.", "The answer is Yes"
6. **Counts**: Just the number.
- ✅ CORRECT: "5"
- ❌ WRONG: "5 items", "five", "There are 5"
7. **No explanations**: Your final response must contain ONLY the answer, nothing else.
- ✅ CORRECT: "Paris"
- ❌ WRONG: "The answer is Paris because..."
## Problem-Solving Strategy
1. **Understand**: Read the question carefully. What exactly is being asked? Note any specific format requirements.
2. **Check for File**: If a file is mentioned or available, ALWAYS read it FIRST - the answer is likely there.
3. **Plan**: What information do I need? Which tools should I use?
4. **Execute**: Use tools systematically. Verify information from multiple sources when possible.
5. **Verify**: Double-check your answer format. Does it match the question's requirements? Is spelling correct?
6. **Respond**: Give ONLY the final answer, no prefixes, no explanations.
## Available Tools
- `read_file`: Read PDFs, spreadsheets, text files - USE THIS FIRST if a file is available
- `web_search`: Current information, recent events, facts
- `wikipedia_search`: Historical facts, biographies, definitions
- `python_executor`: Calculations, data processing, analysis
- `calculator`: Quick mathematical calculations
## Tool Usage Priority
1. **If file available**: Read file FIRST before doing anything else
2. **For calculations**: Use python_executor for complex math, calculator for simple expressions
3. **For facts**: Use wikipedia_search for established facts, web_search for current/recent information
4. **Cross-reference**: When possible, verify important facts from multiple sources
## Critical Reminders
- NEVER include "FINAL ANSWER:" or any prefix in your response
- NEVER add explanations or context to your final answer
- ALWAYS verify spelling, capitalization, and formatting
- ALWAYS read files first if they are available
- If uncertain about format, look for clues in the question itself
- Never guess - use tools to find accurate information
Remember: Your final message must contain ONLY the answer, nothing else. The scoring system uses exact string matching."""
# ============ LANGGRAPH AGENT ============
class GAIAAgent:
"""LangGraph-based agent for GAIA benchmark."""
def __init__(
self,
model_name: str = "gpt-4o",
api_key: str = None,
temperature: float = 0,
max_iterations: int = 15
):
"""
Initialize the GAIA agent.
Args:
model_name: OpenAI model to use
api_key: OpenAI API key (or set OPENAI_API_KEY env var)
temperature: Model temperature (0 for deterministic)
max_iterations: Maximum tool-use iterations
"""
self.model_name = model_name
self.max_iterations = max_iterations
self.llm = ChatOpenAI(
model=model_name,
temperature=temperature,
api_key=api_key or os.environ.get("OPENAI_API_KEY")
)
self.llm_with_tools = self.llm.bind_tools(TOOLS)
self.graph = self._build_graph()
def _build_graph(self) -> StateGraph:
"""Construct the LangGraph workflow."""
workflow = StateGraph(AgentState)
# Define nodes
workflow.add_node("agent", self._agent_node)
workflow.add_node("tools", ToolNode(TOOLS))
workflow.add_node("extract_answer", self._extract_answer_node)
# Set entry point
workflow.set_entry_point("agent")
# Define edges
workflow.add_conditional_edges(
"agent",
self._route_agent_output,
{
"tools": "tools",
"end": "extract_answer"
}
)
workflow.add_edge("tools", "agent")
workflow.add_edge("extract_answer", END)
return workflow.compile()
def _agent_node(self, state: AgentState) -> dict:
"""Process messages and decide on next action."""
messages = state["messages"]
iteration = state.get("iteration_count", 0)
# Add iteration warnings earlier to give agent more time to finish
if iteration >= self.max_iterations - 3:
warning_msg = "WARNING: Approaching iteration limit. Please provide your final answer now. Remember: just the answer, no prefix."
messages = list(messages) + [SystemMessage(content=warning_msg)]
elif iteration >= self.max_iterations - 5:
reminder_msg = "Reminder: When you're ready to answer, provide ONLY the final answer with no prefix like 'FINAL ANSWER:' or 'The answer is:'"
messages = list(messages) + [SystemMessage(content=reminder_msg)]
try:
response = self.llm_with_tools.invoke(messages)
except Exception as e:
# Graceful error handling
error_msg = AIMessage(content=f"Error during reasoning: {str(e)}. Please try a different approach or provide your best answer.")
return {
"messages": [error_msg],
"iteration_count": iteration + 1
}
return {
"messages": [response],
"iteration_count": iteration + 1
}
def _route_agent_output(self, state: AgentState) -> Literal["tools", "end"]:
"""Determine whether to use tools or finish."""
last_message = state["messages"][-1]
iteration = state.get("iteration_count", 0)
# Force end if max iterations reached
if iteration >= self.max_iterations:
return "end"
# Check if agent wants to use tools
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
return "tools"
return "end"
def _extract_answer_node(self, state: AgentState) -> dict:
"""Extract and clean the final answer."""
last_message = state["messages"][-1]
content = last_message.content if hasattr(last_message, "content") else str(last_message)
answer = self._clean_answer(content)
return {"final_answer": answer}
def _clean_answer(self, raw_answer: str) -> str:
"""Clean and format the final answer for exact matching."""
answer = raw_answer.strip()
# Remove common prefixes (case-insensitive, with variations)
prefixes = [
"the answer is:", "the answer is", "answer is:",
"answer:", "answer", "answer:",
"final answer:", "final answer", "FINAL ANSWER:", "FINAL ANSWER",
"the final answer is:", "the final answer is",
"result:", "result", "result is:",
"solution:", "solution", "solution is:",
"the solution is:", "the solution is",
"it is", "it's", "that is", "that's",
]
answer_lower = answer.lower()
for prefix in prefixes:
if answer_lower.startswith(prefix):
answer = answer[len(prefix):].strip()
# Remove any leading colon or dash
answer = answer.lstrip(':').lstrip('-').strip()
answer_lower = answer.lower()
# Remove quotes if they wrap the entire answer
if (answer.startswith('"') and answer.endswith('"')) or \
(answer.startswith("'") and answer.endswith("'")):
answer = answer[1:-1].strip()
# Remove trailing periods, commas, or semicolons for single-word/number answers
if answer and ' ' not in answer:
answer = answer.rstrip('.,;:')
# Remove leading/trailing whitespace and normalize internal whitespace
answer = ' '.join(answer.split())
# Remove markdown formatting if present
if answer.startswith('**') and answer.endswith('**'):
answer = answer[2:-2]
if answer.startswith('*') and answer.endswith('*'):
answer = answer[1:-1]
return answer.strip()
def run(self, question: str, task_id: str = "", file_path: str = None) -> str:
"""
Run the agent on a question.
Args:
question: The GAIA question to answer
task_id: Optional task identifier
file_path: Optional path to associated file
Returns:
The agent's final answer
"""
# Prepare the user message with file priority
user_content = question
if file_path and os.path.exists(file_path):
# Strongly emphasize reading the file first
user_content = f"[IMPORTANT: A file is available at {file_path}]\n\nYou MUST read this file FIRST using the read_file tool before attempting to answer. The answer is very likely contained in this file.\n\nQuestion: {question}"
# Initialize state
initial_state: AgentState = {
"messages": [
SystemMessage(content=SYSTEM_PROMPT),
HumanMessage(content=user_content)
],
"task_id": task_id,
"file_path": file_path,
"file_content": None,
"iteration_count": 0,
"final_answer": None
}
# Execute the graph
try:
final_state = self.graph.invoke(
initial_state,
{"recursion_limit": self.max_iterations * 2 + 5}
)
answer = final_state.get("final_answer", "Unable to determine answer")
# Final validation - ensure answer is not empty or error message
if not answer or answer.startswith("Agent error:") or answer.startswith("Unable to determine"):
# Try to extract from last message if available
if final_state.get("messages"):
last_msg = final_state["messages"][-1]
if hasattr(last_msg, "content") and last_msg.content:
answer = self._clean_answer(last_msg.content)
return answer if answer else "Unable to determine answer"
except Exception as e:
# Log error for debugging but return a clean error message
import logging
logging.error(f"Agent execution error: {str(e)}")
return f"Agent error: {str(e)}"
# ============ UTILITY FUNCTIONS ============
def create_agent(api_key: str = None, model: str = "gpt-4o") -> GAIAAgent:
    """Build a GAIAAgent with the module's standard settings.

    Args:
        api_key: OpenAI API key, forwarded to GAIAAgent unchanged.
        model: Name of the OpenAI model to use.

    Returns:
        A GAIAAgent using temperature 0 and at most 15 iterations.
    """
    settings = {
        "model_name": model,
        "api_key": api_key,
        "temperature": 0,
        "max_iterations": 15,
    }
    return GAIAAgent(**settings)