File size: 11,322 Bytes
223e45d 88cb2f4 223e45d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 |
"""
GAIA Agent with Essential Tools for 30%+ Accuracy
Built with LangGraph and Groq LLM
"""
import os
import re
import json
from typing import Annotated
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langchain_groq import ChatGroq
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.checkpoint.memory import MemorySaver
# Initialize LLM
def get_llm():
    """Create the chat model client used by the agent.

    Returns:
        A ``ChatGroq`` instance for ``llama-3.1-8b-instant`` with
        temperature 0, at most 8000 generated tokens, a timeout of 60 and
        up to two retries.
    """
    client_settings = {
        "model": "llama-3.1-8b-instant",
        "temperature": 0,
        "max_tokens": 8000,
        "timeout": 60,
        "max_retries": 2,
    }
    return ChatGroq(**client_settings)
# ============================================================================
# TOOL DEFINITIONS
# ============================================================================
@tool
def web_search(query: str) -> str:
    """
    Search the web for current information using Tavily.
    Use this for finding recent information, facts, statistics, or any data not in your training.
    Args:
        query: The search query string
    Returns:
        Search results as formatted text
    """
    # NOTE: the docstring above doubles as the tool description sent to the
    # LLM, so implementation notes are kept in comments instead.
    try:
        tavily = TavilySearchResults(
            max_results=5,
            search_depth="advanced",
            include_answer=True,
            include_raw_content=False,
        )
        results = tavily.invoke(query)
        if not results:
            return "No results found."
        # The Tavily tool can report an API failure by returning an error
        # string instead of a list of result dicts. Iterating that string would
        # treat each character as a "result" and fail on `.get`, replacing the
        # real error with an unrelated AttributeError.
        if isinstance(results, str):
            return f"Error searching web: {results}"
        # Render each hit as a numbered Title/Content/URL entry.
        formatted = []
        for i, result in enumerate(results, 1):
            title = result.get('title', 'No title')
            content = result.get('content', 'No content')
            url = result.get('url', '')
            formatted.append(f"Result {i}:\nTitle: {title}\nContent: {content}\nURL: {url}\n")
        return "\n".join(formatted)
    except Exception as e:
        # Tools report failures as text so the agent loop keeps running.
        return f"Error searching web: {str(e)}"
@tool
def wikipedia_search(query: str) -> str:
    """
    Search Wikipedia for encyclopedic information.
    Use this for historical facts, biographies, scientific concepts, etc.
    Args:
        query: The Wikipedia search query
    Returns:
        Wikipedia article content
    """
    # The docstring above is the tool description shown to the LLM.
    # Limits: at most two articles, each truncated to 4000 characters.
    try:
        pages = WikipediaLoader(
            query=query, load_max_docs=2, doc_content_chars_max=4000
        ).load()
        if pages:
            # Separate the articles with a horizontal-rule style divider.
            divider = "\n\n---\n\n"
            body = divider.join(page.page_content for page in pages)
            return f"Wikipedia results for '{query}':\n\n{body}"
        return f"No Wikipedia article found for '{query}'"
    except Exception as e:
        return f"Error searching Wikipedia: {str(e)}"
@tool
def calculate(expression: str) -> str:
    """
    Evaluate a mathematical expression safely.
    Supports basic arithmetic: +, -, *, /, //, %, **, parentheses.
    Also supports common math functions: abs, round, min, max, sum, pow,
    sqrt, log, log10, sin, cos, tan, ceil, floor, and the constants pi and e.
    Args:
        expression: Mathematical expression as a string (e.g., "2 + 2", "sqrt(16)", "10 ** 2")
    Returns:
        The calculated result
    """
    # The docstring above is the tool description shown to the LLM; it now
    # lists every name actually exposed in `safe_dict` below.
    import ast
    import math

    # The only names visible to the expression.
    safe_dict = {
        'abs': abs, 'round': round, 'min': min, 'max': max, 'sum': sum,
        'sqrt': math.sqrt, 'pow': pow, 'log': math.log, 'log10': math.log10,
        'sin': math.sin, 'cos': math.cos, 'tan': math.tan,
        'pi': math.pi, 'e': math.e, 'ceil': math.ceil, 'floor': math.floor
    }
    try:
        expression = expression.strip()
        # filename="<string>" keeps SyntaxError messages identical to eval(str).
        tree = ast.parse(expression, filename="<string>", mode="eval")
        # An empty __builtins__ is not a sandbox on its own: attribute access
        # such as ().__class__.__bases__ reaches arbitrary objects. Arithmetic
        # never needs attributes or underscore names, so reject them up front.
        for node in ast.walk(tree):
            if isinstance(node, ast.Attribute) or (
                isinstance(node, ast.Name) and node.id.startswith("_")
            ):
                raise ValueError("attribute access and private names are not allowed")
        result = eval(compile(tree, "<string>", "eval"), {"__builtins__": {}}, safe_dict)
        return str(result)
    except Exception as e:
        # Tools report failures as text so the agent loop keeps running.
        return f"Error calculating '{expression}': {str(e)}"
@tool
def python_executor(code: str) -> str:
    """
    Execute Python code safely for data processing and calculations.
    Use this for complex calculations, data manipulation, or multi-step computations.
    The code should print its output.
    Args:
        code: Python code to execute
    Returns:
        The output of the code execution
    """
    # The docstring above is the tool description shown to the LLM.
    import builtins
    import contextlib
    import io
    import math
    import json
    from datetime import datetime, timedelta

    # NOTE(review): a restricted __builtins__ dict is a convenience, not a
    # security boundary; object introspection can still escape it.

    # Standard-library modules the executed code may import itself. With no
    # __import__ builtin at all, even a plain `import math` fails.
    allowed_modules = frozenset({
        "collections", "datetime", "decimal", "fractions", "functools",
        "itertools", "json", "math", "re", "statistics", "string",
    })

    def restricted_import(name, globals_=None, locals_=None, fromlist=(), level=0):
        """Import `name` only when it is absolute and its top-level package is allowed."""
        if level != 0 or name.partition(".")[0] not in allowed_modules:
            raise ImportError(f"import of '{name}' is not allowed")
        return builtins.__import__(name, globals_, locals_, fromlist, level)

    safe_builtins = {
        'print': print, 'len': len, 'range': range, 'str': str,
        'int': int, 'float': float, 'list': list, 'dict': dict,
        'set': set, 'tuple': tuple, 'sorted': sorted, 'sum': sum,
        'min': min, 'max': max, 'abs': abs, 'round': round,
        'enumerate': enumerate, 'zip': zip, 'map': map, 'filter': filter,
        # Additional side-effect-free helpers that computations commonly use.
        'bool': bool, 'any': any, 'all': all, 'reversed': reversed,
        'isinstance': isinstance, 'pow': pow, 'divmod': divmod,
        'chr': chr, 'ord': ord, 'repr': repr, 'format': format,
        'Exception': Exception, 'ValueError': ValueError,
        'TypeError': TypeError, 'KeyError': KeyError,
        'IndexError': IndexError, 'ZeroDivisionError': ZeroDivisionError,
        '__import__': restricted_import,
    }
    safe_globals = {
        '__builtins__': safe_builtins,
        'math': math,
        'json': json,
        'datetime': datetime,
        'timedelta': timedelta,
    }

    buffer = io.StringIO()
    try:
        # redirect_stdout restores sys.stdout however exec() exits, so a
        # failure can never leave the process's stdout captured.
        with contextlib.redirect_stdout(buffer):
            exec(code, safe_globals)
    except Exception as e:
        message = f"Error executing code: {str(e)}"
        partial_output = buffer.getvalue()
        if partial_output:
            # Report what the code printed before failing instead of
            # discarding it.
            message += f"\nOutput before error:\n{partial_output}"
        return message
    output = buffer.getvalue()
    return output if output else "Code executed successfully (no output)"
@tool
def read_file(filepath: str) -> str:
    """
    Read and return the contents of a file.
    Supports text files, CSV, JSON, and basic file formats.
    Args:
        filepath: Path to the file to read
    Returns:
        File contents as string
    """
    # The docstring above is the tool description shown to the LLM.
    try:
        if not os.path.exists(filepath):
            return f"File not found: {filepath}"
        # Match extensions case-insensitively so "DATA.CSV" or "x.Json" get
        # the same handling as their lowercase forms.
        lowered = filepath.lower()
        if lowered.endswith('.json'):
            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # ensure_ascii=False keeps non-ASCII text readable instead of
            # turning it into \uXXXX escapes.
            return json.dumps(data, indent=2, ensure_ascii=False)
        if lowered.endswith('.csv'):
            try:
                import pandas as pd
                df = pd.read_csv(filepath)
                return f"CSV file with {len(df)} rows and {len(df.columns)} columns:\n\n{df.to_string()}"
            except ImportError:
                pass  # pandas unavailable: fall through to the plain-text read
        # Plain text (also the CSV fallback when pandas is unavailable).
        with open(filepath, 'r', encoding='utf-8') as f:
            return f.read()
    except Exception as e:
        # Tools report failures as text so the agent loop keeps running.
        return f"Error reading file '{filepath}': {str(e)}"
# ============================================================================
# SYSTEM PROMPT - GAIA Specific Instructions
# ============================================================================
# System prompt prepended by the assistant node. The exact wording steers the
# answer format that GAIA scoring compares against, so edit with care.
GAIA_SYSTEM_PROMPT = """You are a helpful AI assistant designed to answer questions from the GAIA benchmark.

CRITICAL ANSWER FORMAT RULES:
1. For numbers: NO commas, NO units (unless explicitly requested)
   - CORRECT: "1000", or "1000 meters" only when the question asks for units
   - WRONG: "1,000", or "1000 meters" when the question did not ask for units
2. For text answers: No articles (a, an, the), no abbreviations
   - CORRECT: "United States"
   - WRONG: "The United States" or "USA"
3. For lists: Comma-separated with one space after each comma
   - CORRECT: "apple, banana, orange"
   - WRONG: "apple,banana,orange" or "apple, banana, orange."
4. For dates: Use the format specified in the question
   - If not specified, use ISO format: YYYY-MM-DD
5. Be precise and concise - answer ONLY what is asked

APPROACH:
1. Read the question carefully and identify what information is needed
2. Use tools to gather information (web search, Wikipedia, calculations)
3. For multi-step questions, break down the problem and solve step by step
4. Verify your answer matches the format requirements above
5. Return ONLY the final answer in the correct format

AVAILABLE TOOLS:
- web_search: Search the internet for current information
- wikipedia_search: Search Wikipedia for encyclopedic knowledge
- calculate: Perform mathematical calculations
- python_executor: Execute Python code for complex computations
- read_file: Read files (CSV, JSON, text)

Remember: Your final response should be ONLY the answer in the correct format, nothing else.
"""
# ============================================================================
# AGENT GRAPH CONSTRUCTION
# ============================================================================
def build_graph():
    """Assemble and compile the LangGraph agent.

    The graph alternates between two nodes:

    * ``assistant``: the Groq LLM with the five tools bound to it.
    * ``tools``: a ``ToolNode`` that executes the requested tool calls.

    ``tools_condition`` routes from ``assistant`` either to ``tools`` or to
    the end of the run, and ``tools`` always routes back to ``assistant``.
    The graph is compiled with an in-memory ``MemorySaver`` checkpointer;
    callers supply a ``thread_id`` in ``config["configurable"]`` (as the
    ``__main__`` smoke test does).

    Returns:
        The compiled graph.
    """
    toolset = [web_search, wikipedia_search, calculate, python_executor, read_file]
    model = get_llm().bind_tools(toolset)

    def assistant(state: MessagesState):
        """Run the tool-bound LLM over the conversation so far."""
        history = state["messages"]
        # Only the model's reply is written back to the state, so the system
        # prompt is prepended on every call unless the caller supplied one.
        caller_has_system = any(isinstance(entry, SystemMessage) for entry in history)
        prompt = history if caller_has_system else [SystemMessage(content=GAIA_SYSTEM_PROMPT), *history]
        return {"messages": [model.invoke(prompt)]}

    workflow = StateGraph(MessagesState)
    workflow.add_node("assistant", assistant)
    workflow.add_node("tools", ToolNode(toolset))
    workflow.add_edge(START, "assistant")
    workflow.add_conditional_edges("assistant", tools_condition)
    workflow.add_edge("tools", "assistant")
    return workflow.compile(checkpointer=MemorySaver())
# ============================================================================
# TESTING
# ============================================================================
if __name__ == "__main__":
    # Smoke-test the agent on a few sample questions.
    # NOTE(review): presumably needs Groq and Tavily API credentials in the
    # environment for the LLM and web_search tool; confirm before running.
    from langchain_core.messages import HumanMessage

    print("Building agent...")
    agent = build_graph()

    test_questions = [
        "What is 25 * 4 + 100?",
        "Who was the first president of the United States?",
        "Search for the population of Tokyo in 2024",
    ]

    for i, question in enumerate(test_questions, 1):
        print(f"\n{'='*60}")
        print(f"Test {i}: {question}")
        print('='*60)
        try:
            # A distinct thread_id per question gives each one its own
            # checkpointed conversation in the MemorySaver.
            config = {"configurable": {"thread_id": f"test_{i}"}}
            result = agent.invoke(
                {"messages": [HumanMessage(content=question)]},
                config=config,
            )
            # The last message in the final state is the model's answer.
            answer = result['messages'][-1].content
            print(f"Answer: {answer}")
        except Exception as e:
            print(f"Error: {e}")
|