"""
GAIA Agent with Essential Tools for 30%+ Accuracy
Built with LangGraph and Groq LLM
"""

import ast
import io
import json
import math
import os
import re
from datetime import datetime, timedelta
from typing import Annotated

from langchain_core.tools import tool
from langchain_core.messages import SystemMessage
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langchain_groq import ChatGroq
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.checkpoint.memory import MemorySaver


# Initialize LLM
def get_llm(model: str = "llama-3.1-8b-instant", temperature: float = 0):
    """Get a Groq chat model instance.

    Args:
        model: Groq model name. Defaults to the model this agent was tuned with.
        temperature: Sampling temperature passed to ChatGroq.

    Returns:
        A configured ``ChatGroq`` instance.
    """
    return ChatGroq(
        model=model,
        temperature=temperature,
        max_tokens=8000,
        timeout=60,
        max_retries=2,
    )


# ============================================================================
# TOOL DEFINITIONS
# ============================================================================
# NOTE: each tool's docstring is sent to the LLM as the tool description, so
# it is written for the model; developer notes live in `#` comments instead.

@tool
def web_search(query: str) -> str:
    """
    Search the web for current information using Tavily.
    Use this for finding recent information, facts, statistics, or any data not in your training.

    Args:
        query: The search query string

    Returns:
        Search results as formatted text
    """
    try:
        tavily = TavilySearchResults(
            max_results=5,
            search_depth="advanced",
            include_answer=True,
            include_raw_content=False,
        )
        results = tavily.invoke(query)

        # NOTE(review): the Tavily tool may report a failure as a plain string
        # instead of a list of dicts; iterating that string would otherwise
        # fail with a confusing "'str' object has no attribute 'get'".
        if isinstance(results, str):
            return f"Error searching web: {results}"

        if not results:
            return "No results found."

        formatted = []
        for i, result in enumerate(results, 1):
            title = result.get('title', 'No title')
            content = result.get('content', 'No content')
            url = result.get('url', '')
            formatted.append(f"Result {i}:\nTitle: {title}\nContent: {content}\nURL: {url}\n")

        return "\n".join(formatted)

    except Exception as e:
        return f"Error searching web: {str(e)}"


@tool
def wikipedia_search(query: str) -> str:
    """
    Search Wikipedia for encyclopedic information.
    Use this for historical facts, biographies, scientific concepts, etc.

    Args:
        query: The Wikipedia search query

    Returns:
        Wikipedia article content
    """
    try:
        loader = WikipediaLoader(query=query, load_max_docs=2, doc_content_chars_max=4000)
        docs = loader.load()

        if not docs:
            return f"No Wikipedia article found for '{query}'"

        # Prefix each article with its title so the model can tell which
        # article a fact came from when two documents are returned.
        content = "\n\n---\n\n".join(
            f"Title: {doc.metadata.get('title', 'Unknown')}\n\n{doc.page_content}"
            for doc in docs
        )
        return f"Wikipedia results for '{query}':\n\n{content}"

    except Exception as e:
        return f"Error searching Wikipedia: {str(e)}"


# Names usable inside `calculate` expressions.
_CALC_NAMESPACE = {
    'abs': abs,
    'round': round,
    'min': min,
    'max': max,
    'sum': sum,
    'sqrt': math.sqrt,
    'pow': pow,
    'log': math.log,
    'log10': math.log10,
    'sin': math.sin,
    'cos': math.cos,
    'tan': math.tan,
    'pi': math.pi,
    'e': math.e,
    'ceil': math.ceil,
    'floor': math.floor,
}

# AST node types permitted in a `calculate` expression. Attribute access,
# subscripts, lambdas and comprehensions are deliberately absent: attribute
# access (e.g. `().__class__`) is how an empty-__builtins__ eval is escaped.
_CALC_ALLOWED_NODES = (
    ast.Expression,
    ast.Constant,
    ast.Name,
    ast.Load,
    ast.BinOp,
    ast.UnaryOp,
    ast.BoolOp,
    ast.Compare,
    ast.IfExp,
    ast.Call,
    ast.keyword,
    ast.List,
    ast.Tuple,
    ast.operator,
    ast.unaryop,
    ast.boolop,
    ast.cmpop,
)


@tool
def calculate(expression: str) -> str:
    """
    Evaluate a single mathematical expression.
    Supports arithmetic: +, -, *, /, //, %, **, parentheses and comparisons.
    Functions and constants: abs, round, min, max, sum, pow, sqrt, log, log10,
    sin, cos, tan, ceil, floor, pi, e.

    Args:
        expression: Mathematical expression as a string (e.g., "2 + 2", "sqrt(16)", "10 ** 2")

    Returns:
        The calculated result
    """
    try:
        expression = expression.strip()

        # Validate the parsed expression before evaluating it: only
        # allow-listed syntax and only names from _CALC_NAMESPACE. The input
        # is LLM-generated and may be influenced by untrusted web content.
        tree = ast.parse(expression, mode="eval")
        for node in ast.walk(tree):
            if not isinstance(node, _CALC_ALLOWED_NODES):
                raise ValueError(f"unsupported syntax: {type(node).__name__}")
            if isinstance(node, ast.Name) and node.id not in _CALC_NAMESPACE:
                raise ValueError(f"unknown name '{node.id}'")

        result = eval(compile(tree, "<calculate>", "eval"), {"__builtins__": {}}, dict(_CALC_NAMESPACE))
        return str(result)

    except Exception as e:
        return f"Error calculating '{expression}': {str(e)}"


# Modules that code run by `python_executor` may `import`.
_PY_EXEC_IMPORTABLE = frozenset({
    'math', 'json', 'datetime', 're', 'statistics', 'collections',
    'itertools', 'functools', 'fractions', 'decimal', 'string',
})


def _restricted_import(name, globals_=None, locals_=None, fromlist=(), level=0):
    """`__import__` replacement that only admits allow-listed top-level modules."""
    if level != 0 or name.partition('.')[0] not in _PY_EXEC_IMPORTABLE:
        raise ImportError(f"Import of '{name}' is not allowed")
    return __import__(name, globals_, locals_, fromlist, level)


@tool
def python_executor(code: str) -> str:
    """
    Execute Python code for data processing and calculations.
    Use this for complex calculations, data manipulation, or multi-step computations.
    The code should print its output.
    Pre-loaded names: math, json, datetime (the class), timedelta.
    Importable modules: math, json, datetime, re, statistics, collections,
    itertools, functools, fractions, decimal, string.

    Args:
        code: Python code to execute

    Returns:
        The output of the code execution
    """
    # NOTE(review): the restricted __builtins__ below limits convenience, not
    # capability -- exec() is NOT a security sandbox (object introspection can
    # still reach arbitrary modules). The code comes from the LLM, which can be
    # steered by untrusted web/file content; isolate the process if that matters.
    buffer = io.StringIO()

    def _captured_print(*args, **kwargs):
        # Write into this call's own buffer rather than swapping the global
        # sys.stdout: swapping is unsafe if tool calls overlap and leaks the
        # redirect if the code raises something the handler doesn't catch.
        kwargs['file'] = buffer
        print(*args, **kwargs)

    safe_globals = {
        '__builtins__': {
            'print': _captured_print,
            '__import__': _restricted_import,
            'len': len, 'range': range, 'str': str, 'int': int,
            'float': float, 'bool': bool, 'list': list, 'dict': dict,
            'set': set, 'tuple': tuple, 'sorted': sorted, 'reversed': reversed,
            'sum': sum, 'min': min, 'max': max, 'abs': abs, 'round': round,
            'pow': pow, 'divmod': divmod, 'any': any, 'all': all,
            'enumerate': enumerate, 'zip': zip, 'map': map, 'filter': filter,
            'isinstance': isinstance, 'chr': chr, 'ord': ord, 'repr': repr,
            'Exception': Exception, 'ValueError': ValueError,
            'TypeError': TypeError, 'KeyError': KeyError,
            'IndexError': IndexError, 'ZeroDivisionError': ZeroDivisionError,
        },
        'math': math,
        'json': json,
        'datetime': datetime,
        'timedelta': timedelta,
    }

    try:
        exec(code, safe_globals)
    except Exception as e:
        return f"Error executing code: {str(e)}"

    output = buffer.getvalue()
    return output if output else "Code executed successfully (no output)"


@tool
def read_file(filepath: str) -> str:
    """
    Read and return the contents of a file.
    Supports text files, CSV, Excel (.xlsx/.xls), JSON, and basic file formats.

    Args:
        filepath: Path to the file to read

    Returns:
        File contents as string
    """
    try:
        # isfile (not exists) so a directory path is reported as not found
        # instead of failing later in open().
        if not os.path.isfile(filepath):
            return f"File not found: {filepath}"

        # Compare extensions case-insensitively ("DATA.JSON" is still JSON).
        extension = os.path.splitext(filepath)[1].lower()

        if extension == '.json':
            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)
            return json.dumps(data, indent=2)

        if extension == '.csv':
            try:
                import pandas as pd
            except ImportError:
                # Fallback if pandas not available
                with open(filepath, 'r', encoding='utf-8') as f:
                    return f.read()
            df = pd.read_csv(filepath)
            return f"CSV file with {len(df)} rows and {len(df.columns)} columns:\n\n{df.to_string()}"

        if extension in ('.xlsx', '.xls'):
            # Requires pandas plus an Excel engine; a missing dependency is
            # reported through the generic error message below.
            import pandas as pd
            sheets = pd.read_excel(filepath, sheet_name=None)
            parts = [
                f"Sheet '{sheet}' ({len(df)} rows, {len(df.columns)} columns):\n{df.to_string()}"
                for sheet, df in sheets.items()
            ]
            return f"Excel file with {len(sheets)} sheet(s):\n\n" + "\n\n".join(parts)

        # Read as text
        with open(filepath, 'r', encoding='utf-8') as f:
            return f.read()

    except Exception as e:
        return f"Error reading file '{filepath}': {str(e)}"


# ============================================================================
# SYSTEM PROMPT - GAIA Specific Instructions
# ============================================================================

GAIA_SYSTEM_PROMPT = """You are a helpful AI assistant designed to answer questions from the GAIA benchmark.

CRITICAL ANSWER FORMAT RULES:
1. For numbers: NO commas, NO units (unless explicitly requested)
   - CORRECT: "1000" or "1000 meters" (if units requested)
   - WRONG: "1,000" or "1000 meters" (if units not requested)

2. For text answers: No articles (a, an, the), no abbreviations
   - CORRECT: "United States"
   - WRONG: "The United States" or "USA"

3. For lists: Comma-separated with one space after each comma
   - CORRECT: "apple, banana, orange"
   - WRONG: "apple,banana,orange" or "apple, banana, orange."

4. For dates: Use the format specified in the question
   - If not specified, use ISO format: YYYY-MM-DD

5. Be precise and concise - answer ONLY what is asked

APPROACH:
1. Read the question carefully and identify what information is needed
2. Use tools to gather information (web search, Wikipedia, calculations)
3. For multi-step questions, break down the problem and solve step by step
4. Verify your answer matches the format requirements above
5. Return ONLY the final answer in the correct format

AVAILABLE TOOLS:
- web_search: Search the internet for current information
- wikipedia_search: Search Wikipedia for encyclopedic knowledge
- calculate: Perform mathematical calculations
- python_executor: Execute Python code for complex computations
- read_file: Read files (CSV, Excel, JSON, text)

Remember: Your final response should be ONLY the answer in the correct format, nothing else.
"""


# ============================================================================
# AGENT GRAPH CONSTRUCTION
# ============================================================================

def build_graph():
    """Build and compile the LangGraph agent with tools.

    The graph runs assistant -> (tools -> assistant)* and ends when
    ``tools_condition`` finds no tool calls in the assistant's reply.

    Returns:
        The compiled graph, checkpointed with an in-memory ``MemorySaver``;
        invoke it with a ``thread_id`` under ``config["configurable"]`` (as
        the test harness at the bottom of this file does).
    """
    llm = get_llm()

    tools = [
        web_search,
        wikipedia_search,
        calculate,
        python_executor,
        read_file,
    ]

    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Assistant node: call the tool-bound LLM on the conversation so far."""
        messages = state["messages"]

        # The system prompt is prepended per call rather than stored in state,
        # so it is added only when the history doesn't already contain one.
        if not any(isinstance(msg, SystemMessage) for msg in messages):
            messages = [SystemMessage(content=GAIA_SYSTEM_PROMPT)] + messages

        response = llm_with_tools.invoke(messages)
        return {"messages": [response]}

    builder = StateGraph(MessagesState)

    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))

    builder.add_edge(START, "assistant")
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    memory = MemorySaver()
    graph = builder.compile(checkpointer=memory)

    return graph


# ============================================================================
# TESTING
# ============================================================================

if __name__ == "__main__":
    # Smoke-test the agent with a few sample questions.
    from langchain_core.messages import HumanMessage

    print("Building agent...")
    agent = build_graph()

    test_questions = [
        "What is 25 * 4 + 100?",
        "Who was the first president of the United States?",
        "Search for the population of Tokyo in 2024",
    ]

    for i, question in enumerate(test_questions, 1):
        print(f"\n{'='*60}")
        print(f"Test {i}: {question}")
        print('='*60)

        try:
            # A distinct thread_id per question keeps checkpointed histories
            # from bleeding into each other.
            config = {"configurable": {"thread_id": f"test_{i}"}}
            result = agent.invoke(
                {"messages": [HumanMessage(content=question)]},
                config=config,
            )

            answer = result['messages'][-1].content
            print(f"Answer: {answer}")

        except Exception as e:
            print(f"Error: {e}")