|
|
""" |
|
|
GAIA Agent with Essential Tools for 30%+ Accuracy |
|
|
Built with LangGraph and Groq LLM |
|
|
""" |
|
|
import os |
|
|
import re |
|
|
import json |
|
|
from typing import Annotated |
|
|
from langchain_core.tools import tool |
|
|
from langchain_core.messages import SystemMessage |
|
|
from langchain_community.tools.tavily_search import TavilySearchResults |
|
|
from langchain_community.document_loaders import WikipediaLoader |
|
|
from langchain_groq import ChatGroq |
|
|
from langgraph.graph import StateGraph, MessagesState, START, END |
|
|
from langgraph.prebuilt import ToolNode, tools_condition |
|
|
from langgraph.checkpoint.memory import MemorySaver |
|
|
|
|
|
|
|
|
def get_llm():
    """Create the Groq chat model used by the agent.

    Returns:
        A ``ChatGroq`` instance for the ``llama-3.1-8b-instant`` model,
        constructed with ``temperature=0``, ``max_tokens=8000``,
        ``timeout=60`` and ``max_retries=2``.
    """
    # All constructor settings live in one mapping so they can be read
    # (and adjusted) at a glance.
    settings = {
        "model": "llama-3.1-8b-instant",
        "temperature": 0,
        "max_tokens": 8000,
        "timeout": 60,
        "max_retries": 2,
    }
    return ChatGroq(**settings)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@tool
def web_search(query: str) -> str:
    """
    Search the web for current information using Tavily.
    Use this for finding recent information, facts, statistics, or any data not in your training.

    Args:
        query: The search query string

    Returns:
        Search results as formatted text
    """
    # NOTE: the docstring above is what ``@tool`` exposes to the model as the
    # tool description, so its wording is intentionally left unchanged.
    try:
        tavily = TavilySearchResults(
            max_results=5,
            search_depth="advanced",
            include_answer=True,
            include_raw_content=False,
        )
        results = tavily.invoke(query)

        if not results:
            return "No results found."

        # The Tavily tool may hand back a plain string (for example an error
        # description) instead of a list of result dicts. Iterating a string
        # and calling ``.get`` on its characters would mask the real message
        # behind an AttributeError, so surface the string directly.
        if isinstance(results, str):
            return f"Error searching web: {results}"

        formatted = []
        for i, result in enumerate(results, 1):
            if not isinstance(result, dict):
                # Unexpected entry shape: include it verbatim rather than
                # failing the whole search.
                formatted.append(f"Result {i}:\n{result}\n")
                continue
            title = result.get('title', 'No title')
            content = result.get('content', 'No content')
            url = result.get('url', '')
            formatted.append(f"Result {i}:\nTitle: {title}\nContent: {content}\nURL: {url}\n")

        return "\n".join(formatted)
    except Exception as e:
        # Tool boundary: report the failure as text so the agent loop can
        # continue instead of crashing.
        return f"Error searching web: {str(e)}"
|
|
|
|
|
|
|
|
@tool
def wikipedia_search(query: str) -> str:
    """
    Search Wikipedia for encyclopedic information.
    Use this for historical facts, biographies, scientific concepts, etc.

    Args:
        query: The Wikipedia search query

    Returns:
        Wikipedia article content
    """
    # The docstring doubles as the tool description handed to the model by
    # ``@tool``; it is kept verbatim.
    try:
        # Load at most two articles, each capped at 4000 characters.
        pages = WikipediaLoader(
            query=query,
            load_max_docs=2,
            doc_content_chars_max=4000,
        ).load()

        if pages:
            # Articles are separated by a horizontal-rule style divider.
            body = "\n\n---\n\n".join(page.page_content for page in pages)
            return f"Wikipedia results for '{query}':\n\n{body}"

        return f"No Wikipedia article found for '{query}'"
    except Exception as err:
        # Tool boundary: failures are reported as text, never raised.
        return f"Error searching Wikipedia: {str(err)}"
|
|
|
|
|
|
|
|
@tool
def calculate(expression: str) -> str:
    """
    Evaluate a mathematical expression safely.
    Supports basic arithmetic: +, -, *, /, //, %, **, parentheses.
    Also supports common math functions: abs, round, min, max, sum.

    Args:
        expression: Mathematical expression as a string (e.g., "2 + 2", "sqrt(16)", "10 ** 2")

    Returns:
        The calculated result
    """
    # NOTE: the docstring above is the tool description shown to the model
    # by ``@tool``; its wording is intentionally left unchanged.
    try:
        import ast
        import math

        # Names an expression may reference. Python builtins are disabled
        # below, so anything not listed here raises NameError.
        safe_dict = {
            'abs': abs, 'round': round, 'min': min, 'max': max, 'sum': sum,
            'sqrt': math.sqrt, 'pow': pow, 'log': math.log, 'log10': math.log10,
            'sin': math.sin, 'cos': math.cos, 'tan': math.tan,
            'pi': math.pi, 'e': math.e, 'ceil': math.ceil, 'floor': math.floor
        }

        expression = expression.strip()

        # SECURITY: ``eval`` runs model-supplied text. Emptying
        # ``__builtins__`` alone does not prevent escapes through attribute
        # traversal (e.g. ``().__class__``), so the expression is parsed
        # first and any attribute access is rejected. This is hardening, not
        # a sandbox: very expensive expressions such as huge exponents are
        # still not guarded against.
        tree = ast.parse(expression, filename="<string>", mode="eval")
        for node in ast.walk(tree):
            if isinstance(node, ast.Attribute):
                raise ValueError("attribute access is not allowed in expressions")

        # The helper names are supplied as *globals* (one namespace) rather
        # than as separate locals: nested scopes such as generator
        # expressions and lambdas cannot see eval's locals, which made
        # inputs like "sum(sqrt(x) for x in [1, 4])" fail with NameError.
        namespace = {"__builtins__": {}, **safe_dict}
        result = eval(compile(tree, "<string>", "eval"), namespace)
        return str(result)
    except Exception as e:
        # Tool boundary: report the problem as text so the agent can retry.
        return f"Error calculating '{expression}': {str(e)}"
|
|
|
|
|
|
|
|
@tool
def python_executor(code: str) -> str:
    """
    Execute Python code safely for data processing and calculations.
    Use this for complex calculations, data manipulation, or multi-step computations.
    The code should print its output.

    Args:
        code: Python code to execute

    Returns:
        The output of the code execution
    """
    # NOTE: the docstring above is the tool description shown to the model
    # by ``@tool``; its wording is intentionally left unchanged.
    import io
    import json
    import math
    from contextlib import redirect_stdout
    from datetime import datetime, timedelta

    # SECURITY: ``exec`` runs model-supplied code. The reduced builtins
    # below limit what that code can name directly, but this is not a
    # security boundary; code from an untrusted source should be run in an
    # isolated process or container instead.
    safe_globals = {
        '__builtins__': {
            'print': print, 'len': len, 'range': range, 'str': str,
            'int': int, 'float': float, 'list': list, 'dict': dict,
            'set': set, 'tuple': tuple, 'sorted': sorted, 'sum': sum,
            'min': min, 'max': max, 'abs': abs, 'round': round,
            'enumerate': enumerate, 'zip': zip, 'map': map, 'filter': filter,
        },
        'math': math,
        'json': json,
        'datetime': datetime,
        'timedelta': timedelta,
    }

    buffer = io.StringIO()
    try:
        # ``redirect_stdout`` restores ``sys.stdout`` on every exit path,
        # including exceptions that ``except Exception`` does not catch
        # (e.g. KeyboardInterrupt). The previous manual swap could leave the
        # process's stdout pointing at the capture buffer in that case.
        with redirect_stdout(buffer):
            exec(code, safe_globals)
    except Exception as e:
        # Tool boundary: report the failure as text; anything printed before
        # the error is discarded, as before.
        return f"Error executing code: {str(e)}"

    output = buffer.getvalue()
    return output if output else "Code executed successfully (no output)"
|
|
|
|
|
|
|
|
@tool
def read_file(filepath: str) -> str:
    """
    Read and return the contents of a file.
    Supports text files, CSV, JSON, and basic file formats.

    Args:
        filepath: Path to the file to read

    Returns:
        File contents as string
    """
    # The docstring doubles as the tool description handed to the model by
    # ``@tool``; it is kept verbatim.

    def _read_as_text() -> str:
        # Return the whole file decoded as UTF-8.
        with open(filepath, 'r', encoding='utf-8') as handle:
            return handle.read()

    try:
        if not os.path.exists(filepath):
            return f"File not found: {filepath}"

        if filepath.endswith('.json'):
            # Parse and re-serialise so the model sees consistently
            # indented JSON.
            with open(filepath, 'r', encoding='utf-8') as handle:
                parsed = json.load(handle)
            return json.dumps(parsed, indent=2)

        if filepath.endswith('.csv'):
            try:
                import pandas as pd
                frame = pd.read_csv(filepath)
                row_count = len(frame)
                column_count = len(frame.columns)
                return f"CSV file with {row_count} rows and {column_count} columns:\n\n{frame.to_string()}"
            except ImportError:
                # pandas is optional: fall back to the raw file text.
                return _read_as_text()

        # Any other extension is treated as plain text.
        return _read_as_text()
    except Exception as err:
        # Tool boundary: failures are reported as text, never raised.
        return f"Error reading file '{filepath}': {str(err)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# System prompt for the agent. The assistant node in build_graph() prepends it
# as a SystemMessage whenever the conversation does not already contain one.
# It spells out the answer-formatting rules, the expected working approach and
# the names of the tools bound to the model.
GAIA_SYSTEM_PROMPT = """You are a helpful AI assistant designed to answer questions from the GAIA benchmark.

CRITICAL ANSWER FORMAT RULES:
1. For numbers: NO commas, NO units (unless explicitly requested)
- CORRECT: "1000" or "1000 meters" (if units requested)
- WRONG: "1,000" or "1000 meters" (if units not requested)

2. For text answers: No articles (a, an, the), no abbreviations
- CORRECT: "United States"
- WRONG: "The United States" or "USA"

3. For lists: Comma-separated with one space after each comma
- CORRECT: "apple, banana, orange"
- WRONG: "apple,banana,orange" or "apple, banana, orange."

4. For dates: Use the format specified in the question
- If not specified, use ISO format: YYYY-MM-DD

5. Be precise and concise - answer ONLY what is asked

APPROACH:
1. Read the question carefully and identify what information is needed
2. Use tools to gather information (web search, Wikipedia, calculations)
3. For multi-step questions, break down the problem and solve step by step
4. Verify your answer matches the format requirements above
5. Return ONLY the final answer in the correct format

AVAILABLE TOOLS:
- web_search: Search the internet for current information
- wikipedia_search: Search Wikipedia for encyclopedic knowledge
- calculate: Perform mathematical calculations
- python_executor: Execute Python code for complex computations
- read_file: Read files (CSV, JSON, text)

Remember: Your final response should be ONLY the answer in the correct format, nothing else.
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_graph():
    """Assemble and compile the tool-using agent graph.

    The graph has two nodes: ``assistant`` (calls the LLM with the tools
    bound) and ``tools`` (a ``ToolNode`` over the same tools). Execution
    starts at ``assistant``; ``tools_condition`` decides where to go after
    each assistant turn, and every ``tools`` step feeds back into
    ``assistant``.

    Returns:
        The compiled graph, using a ``MemorySaver`` checkpointer.
    """
    model = get_llm()

    # The same tool list is bound to the model and served by the ToolNode.
    toolbox = [web_search, wikipedia_search, calculate, python_executor, read_file]
    model_with_tools = model.bind_tools(toolbox)

    def assistant(state: MessagesState):
        """Run one LLM turn over the conversation held in ``state``."""
        history = state["messages"]

        # Prepend the system prompt only when the conversation does not
        # already carry a SystemMessage.
        for entry in history:
            if isinstance(entry, SystemMessage):
                break
        else:
            history = [SystemMessage(content=GAIA_SYSTEM_PROMPT)] + history

        reply = model_with_tools.invoke(history)
        return {"messages": [reply]}

    workflow = StateGraph(MessagesState)

    # Nodes
    workflow.add_node("assistant", assistant)
    workflow.add_node("tools", ToolNode(toolbox))

    # Edges
    workflow.add_edge(START, "assistant")
    workflow.add_conditional_edges("assistant", tools_condition)
    workflow.add_edge("tools", "assistant")

    return workflow.compile(checkpointer=MemorySaver())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: build the agent and run a few sample questions,
    # printing either the final answer or the error for each one.
    from langchain_core.messages import HumanMessage

    print("Building agent...")
    agent = build_graph()

    sample_questions = (
        "What is 25 * 4 + 100?",
        "Who was the first president of the United States?",
        "Search for the population of Tokyo in 2024",
    )

    divider = '=' * 60
    for number, prompt in enumerate(sample_questions, start=1):
        print(f"\n{divider}")
        print(f"Test {number}: {prompt}")
        print(divider)

        # Each question runs under its own thread id in the checkpointer.
        run_config = {"configurable": {"thread_id": f"test_{number}"}}
        try:
            final_state = agent.invoke(
                {"messages": [HumanMessage(content=prompt)]},
                config=run_config,
            )
            # The last message in the returned state is the agent's answer.
            print(f"Answer: {final_state['messages'][-1].content}")
        except Exception as err:
            # Keep going with the remaining questions if one fails.
            print(f"Error: {err}")
|
|
|
|
|
|