Spaces:
No application file
No application file
| """ | |
| Enhanced GAIA Agent with LangGraph | |
| Separate module for cleaner architecture and easier customization | |
| """ | |
| import os | |
| import re | |
| import json | |
| import requests | |
| import tempfile | |
| from typing import TypedDict, Annotated, Sequence, Literal, Any | |
| import operator | |
| from langgraph.graph import StateGraph, END | |
| from langgraph.prebuilt import ToolNode | |
| from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage | |
| from langchain_core.tools import tool | |
| from langchain_openai import ChatOpenAI | |
| from langchain_community.tools import DuckDuckGoSearchResults | |
| from langchain_experimental.utilities import PythonREPL | |
| import pandas as pd | |
| # ============ STATE DEFINITION ============ | |
class AgentState(TypedDict):
    """State maintained throughout the agent's execution (the LangGraph state schema)."""
    # Conversation history. operator.add is the reducer: the message list a
    # node returns is concatenated onto the existing history, not substituted.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Caller-supplied task identifier (set by GAIAAgent.run); not read by any
    # graph node in the code shown here.
    task_id: str
    # Optional path of the file attached to the question.
    file_path: str | None
    # Initialised to None by GAIAAgent.run; nothing in the code shown here
    # fills it in — TODO confirm whether it is still needed.
    file_content: str | None
    # Number of agent (LLM) steps taken so far; compared with max_iterations.
    iteration_count: int
    # Cleaned answer written by the extract_answer node.
    final_answer: str | None
| # ============ TOOL DEFINITIONS ============ | |
def web_search(query: str) -> str:
    """
    Search the web using DuckDuckGo for current information.
    Use this for questions about recent events, facts, statistics, or any information
    that might have changed or that you're uncertain about.
    Args:
        query: The search query string
    Returns:
        Search results with relevant snippets
    """
    # The docstring above doubles as the tool description LangChain sends to
    # the model, so its wording is kept as-is. Every failure (network, rate
    # limiting, odd payloads) becomes a message the model can react to rather
    # than an exception that would abort the graph.
    try:
        raw_results = DuckDuckGoSearchResults(max_results=5, output_format="list").run(query)
        if not isinstance(raw_results, list):
            return str(raw_results)
        entries = [
            f"Title: {hit.get('title', 'N/A')}\nSnippet: {hit.get('snippet', 'N/A')}\nLink: {hit.get('link', 'N/A')}"
            if isinstance(hit, dict)
            else str(hit)
            for hit in raw_results
        ]
        return "\n\n---\n\n".join(entries)
    except Exception as e:
        return f"Search failed: {str(e)}. Try a different query or approach."
def python_executor(code: str) -> str:
    """
    Execute Python code for calculations, data analysis, or any computational task.
    You have access to standard libraries: math, statistics, datetime, json, re, collections.
    Args:
        code: Python code to execute. Print statements will show in output.
    Returns:
        The output/result of the code execution
    """
    # The docstring above is the tool description the model sees, so it is
    # kept verbatim.
    # SECURITY: `code` is model-generated (and may be steered by web content).
    # It runs in this process and no timeout is passed to repl.run, so a
    # non-terminating snippet presumably blocks the agent. Only run this agent
    # in an isolated environment.
    try:
        repl = PythonREPL()
        # NOTE(review): PythonREPL appears to exec() with separate globals and
        # locals dicts. In that case, functions defined by the snippet cannot
        # see its own top-level imports (e.g. NameError on `math` inside a def).
        # Sharing one namespace avoids that. This is best effort and is skipped
        # if the attribute layout differs.
        try:
            if isinstance(getattr(repl, "globals", None), dict):
                repl.locals = repl.globals
        except (AttributeError, TypeError, ValueError):
            pass  # keep PythonREPL's default namespaces
        # Prelude of the modules promised in the docstring. `collections` is
        # imported as a module too (not only Counter/defaultdict), so code such
        # as `collections.deque(...)` works as advertised.
        prelude = (
            "import math\n"
            "import statistics\n"
            "import datetime\n"
            "import json\n"
            "import re\n"
            "import collections\n"
            "from collections import Counter, defaultdict\n"
        )
        result = repl.run(prelude + code)
        return result.strip() if result else "Code executed successfully with no output. Add print() to see results."
    except Exception as e:
        return f"Execution error: {str(e)}. Please fix the code and try again."
def read_file(file_path: str) -> str:
    """
    Read and extract content from various file types.
    Supports: PDF, TXT, MD, CSV, JSON, XLSX, XLS, PY, and other text files.
    Args:
        file_path: Path to the file to read
    Returns:
        The content of the file as a string
    """
    # The docstring above is also the tool description the model sees, so it
    # is kept verbatim. Every failure is reported as an "Error..." string so
    # the model can recover instead of the graph raising.
    try:
        if not os.path.exists(file_path):
            return f"Error: File not found at {file_path}"
        file_lower = file_path.lower()
        if file_lower.endswith('.pdf'):
            # Imported lazily: the PDF loader is only needed for PDFs.
            from langchain_community.document_loaders import PyPDFLoader
            loader = PyPDFLoader(file_path)
            pages = loader.load()
            content = "\n\n--- Page Break ---\n\n".join([p.page_content for p in pages])
            return f"PDF Content ({len(pages)} pages):\n{content}"
        if file_lower.endswith(('.xlsx', '.xls')):
            sheets = pd.read_excel(file_path, sheet_name=None)  # all sheets
            return "\n\n".join(
                f"=== Sheet: {sheet_name} ===\n{sheet_df.to_string()}"
                for sheet_name, sheet_df in sheets.items()
            )
        if file_lower.endswith('.csv'):
            df = pd.read_csv(file_path)
            return f"CSV Data ({len(df)} rows):\n{df.to_string()}"
        if file_lower.endswith('.json'):
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # ensure_ascii=False keeps non-ASCII text (accented names, etc.)
            # verbatim instead of \uXXXX escapes. This matters for exact-match
            # answers.
            return f"JSON Content:\n{json.dumps(data, indent=2, ensure_ascii=False)}"

        # Default: treat as text. First sniff the leading bytes so that binary
        # files are not decoded into garbage and dumped into the model context.
        with open(file_path, 'rb') as f:
            head = f.read(8192)
        if head.startswith((b'\xff\xfe', b'\xfe\xff')):
            encoding = 'utf-16'  # BOM-marked UTF-16 text; the codec consumes the BOM
        elif b'\x00' in head:
            # NUL bytes essentially never occur in UTF-8/ASCII text.
            return (
                f"Error: {file_path} appears to be a binary file "
                "(e.g. image, audio or archive) and cannot be read as text."
            )
        else:
            encoding = 'utf-8'
        with open(file_path, 'r', encoding=encoding, errors='ignore') as f:
            content = f.read()
        return f"File Content:\n{content}"
    except Exception as e:
        return f"Error reading file: {str(e)}"
def calculator(expression: str) -> str:
    """
    Evaluate a mathematical expression safely.
    Supports: basic arithmetic, trigonometry, logarithms, exponents, etc.
    Args:
        expression: Mathematical expression (e.g., "sqrt(16) + log(100, 10)")
    Returns:
        The numerical result as a string
    """
    # The docstring above is the tool description the model sees, so it is
    # kept verbatim. The expression comes from the model (and indirectly from
    # web content), so it is treated as untrusted input.
    try:
        import ast
        import math
        # Allowed functions and constants; the only names visible to eval.
        safe_dict = {
            'abs': abs, 'round': round, 'min': min, 'max': max,
            'sum': sum, 'pow': pow, 'len': len,
            'sqrt': math.sqrt, 'log': math.log, 'log10': math.log10,
            'log2': math.log2, 'exp': math.exp,
            'sin': math.sin, 'cos': math.cos, 'tan': math.tan,
            'asin': math.asin, 'acos': math.acos, 'atan': math.atan,
            'sinh': math.sinh, 'cosh': math.cosh, 'tanh': math.tanh,
            'ceil': math.ceil, 'floor': math.floor,
            'pi': math.pi, 'e': math.e,
            'factorial': math.factorial, 'gcd': math.gcd,
            'degrees': math.degrees, 'radians': math.radians,
        }
        # Emptying __builtins__ alone is not a sandbox: attribute chains like
        # ().__class__.__base__.__subclasses__() reach arbitrary classes.
        # No legitimate arithmetic needs attribute access, so reject it.
        tree = ast.parse(expression, mode="eval")
        if any(isinstance(node, ast.Attribute) for node in ast.walk(tree)):
            raise ValueError("attribute access is not allowed")
        # NOTE: huge results (e.g. 9**9**9, factorial(10**7)) are not bounded
        # and can take very long to compute.
        result = eval(compile(tree, "<calculator>", "eval"), {"__builtins__": {}}, safe_dict)
        if isinstance(result, float):
            if result.is_integer():
                return str(int(result))
            # 15 significant digits hides binary noise (0.1 + 0.2 -> 0.3)
            # without switching mid-size values such as 12345678901.5 to
            # scientific notation, as the previous 10-digit format did.
            return f"{result:.15g}"
        return str(result)
    except Exception as e:
        return f"Calculation error: {str(e)}. Check your expression syntax."
def wikipedia_search(query: str) -> str:
    """
    Search Wikipedia for factual information about a specific topic.
    Best for: historical facts, biographies, scientific concepts, definitions.
    Args:
        query: The topic to search for on Wikipedia
    Returns:
        Summary and key information from relevant Wikipedia articles
    """
    # The docstring above is the tool description the model sees, so it is
    # kept verbatim.
    api_url = "https://en.wikipedia.org/w/api.php"
    # Wikimedia's API etiquette asks clients to send a descriptive User-Agent;
    # generic library defaults may be throttled or blocked.
    # TODO: add contact information to this string.
    headers = {"User-Agent": "GAIA-Agent/1.0 (LangGraph research agent)"}
    try:
        # Step 1: full-text search; the top 3 titles are requested.
        search_response = requests.get(
            api_url,
            params={
                "action": "query",
                "list": "search",
                "srsearch": query,
                "format": "json",
                "srlimit": 3,
            },
            headers=headers,
            timeout=10,
        )
        search_response.raise_for_status()
        hits = search_response.json().get('query', {}).get('search', [])
        if not hits:
            return f"No Wikipedia articles found for '{query}'"

        # Step 2: plain-text extract of the top hit. exintro=true limits this
        # to the lead section (text before the first heading), not the whole
        # article, and the result is further truncated to 2000 characters.
        top_title = hits[0]['title']
        content_response = requests.get(
            api_url,
            params={
                "action": "query",
                "prop": "extracts",
                "exintro": "true",
                "explaintext": "true",
                "titles": top_title,
                "format": "json",
            },
            headers=headers,
            timeout=10,
        )
        content_response.raise_for_status()
        pages = content_response.json().get('query', {}).get('pages', {})
        for page_id, page_data in pages.items():
            if page_id != '-1':  # "-1" marks a missing page
                title = page_data.get('title', '')
                extract = page_data.get('extract', 'No content available')
                result = f"Wikipedia: {title}\n\n{extract[:2000]}"
                # Surface the other search hits so the model can refine its
                # query when the top hit is the wrong article.
                other_titles = [hit['title'] for hit in hits[1:]]
                if other_titles:
                    result += "\n\nOther matching articles: " + ", ".join(other_titles)
                return result
        return "Could not retrieve article content."
    except Exception as e:
        return f"Wikipedia search failed: {str(e)}"
def analyze_image(image_path: str, question: str) -> str:
    """
    Placeholder for image analysis (no vision model is wired in).

    Args:
        image_path: Path to the image file
        question: What to analyze or find in the image

    Returns:
        A fixed "not implemented" message that echoes both arguments.
    """
    # Stub only. It is not part of the TOOLS list, so the model never sees it.
    # Integrate a vision-capable model here before registering it as a tool.
    message = f"Image analysis not implemented. File: {image_path}, Question: {question}"
    return message
# Tools exposed to the model. GAIAAgent binds this list to the LLM
# (llm.bind_tools) and executes calls through ToolNode. LangChain derives each
# tool's description from the function's docstring. analyze_image is a
# placeholder and is not included.
TOOLS = [
    web_search,
    python_executor,
    read_file,
    calculator,
    wikipedia_search,
]
# ============ SYSTEM PROMPT ============
# Sent as the first (system) message of every run by GAIAAgent.run. Replies
# are additionally post-processed by GAIAAgent._clean_answer, so keep the
# formatting rules below and that cleanup in sync.
# NOTE(review): this copy of the file appears to have lost the literal's
# original indentation/blank lines; confirm the exact prompt text against the
# source of truth before editing it.
SYSTEM_PROMPT = """You are an expert AI assistant designed to solve GAIA benchmark questions with maximum accuracy.
## Your Mission
Provide PRECISE, EXACT answers. The benchmark uses EXACT STRING MATCHING, so your final answer must match the ground truth character-for-character.
## Critical Answer Formatting Rules (MUST FOLLOW)
**DO NOT include "FINAL ANSWER:" or any prefix - just the answer itself.**
1. **Numbers**: Give just the number.
- ✅ CORRECT: "42"
- ❌ WRONG: "The answer is 42", "42 units", "Answer: 42"
2. **Names**: Exact spelling as found in sources. Check Wikipedia/official sources for correct spelling, capitalization, and punctuation.
- ✅ CORRECT: "John Smith"
- ❌ WRONG: "john smith", "John smith"
3. **Lists**: Comma-separated, NO spaces after commas.
- ✅ CORRECT: "apple,banana,cherry"
- ❌ WRONG: "apple, banana, cherry", "apple,banana, cherry"
4. **Dates**: Use the format specified in the question, or YYYY-MM-DD if not specified.
- ✅ CORRECT: "2024-01-15" or "January 15, 2024" (if question asks for that format)
- ❌ WRONG: "1/15/2024" (unless question asks for it)
5. **Yes/No**: Just "Yes" or "No" (capitalized, no period).
- ✅ CORRECT: "Yes"
- ❌ WRONG: "yes", "Yes.", "The answer is Yes"
6. **Counts**: Just the number.
- ✅ CORRECT: "5"
- ❌ WRONG: "5 items", "five", "There are 5"
7. **No explanations**: Your final response must contain ONLY the answer, nothing else.
- ✅ CORRECT: "Paris"
- ❌ WRONG: "The answer is Paris because..."
## Problem-Solving Strategy
1. **Understand**: Read the question carefully. What exactly is being asked? Note any specific format requirements.
2. **Check for File**: If a file is mentioned or available, ALWAYS read it FIRST - the answer is likely there.
3. **Plan**: What information do I need? Which tools should I use?
4. **Execute**: Use tools systematically. Verify information from multiple sources when possible.
5. **Verify**: Double-check your answer format. Does it match the question's requirements? Is spelling correct?
6. **Respond**: Give ONLY the final answer, no prefixes, no explanations.
## Available Tools
- `read_file`: Read PDFs, spreadsheets, text files - USE THIS FIRST if a file is available
- `web_search`: Current information, recent events, facts
- `wikipedia_search`: Historical facts, biographies, definitions
- `python_executor`: Calculations, data processing, analysis
- `calculator`: Quick mathematical calculations
## Tool Usage Priority
1. **If file available**: Read file FIRST before doing anything else
2. **For calculations**: Use python_executor for complex math, calculator for simple expressions
3. **For facts**: Use wikipedia_search for established facts, web_search for current/recent information
4. **Cross-reference**: When possible, verify important facts from multiple sources
## Critical Reminders
- NEVER include "FINAL ANSWER:" or any prefix in your response
- NEVER add explanations or context to your final answer
- ALWAYS verify spelling, capitalization, and formatting
- ALWAYS read files first if they are available
- If uncertain about format, look for clues in the question itself
- Never guess - use tools to find accurate information
Remember: Your final message must contain ONLY the answer, nothing else. The scoring system uses exact string matching."""
| # ============ LANGGRAPH AGENT ============ | |
| class GAIAAgent: | |
| """LangGraph-based agent for GAIA benchmark.""" | |
| def __init__( | |
| self, | |
| model_name: str = "gpt-4o", | |
| api_key: str = None, | |
| temperature: float = 0, | |
| max_iterations: int = 15 | |
| ): | |
| """ | |
| Initialize the GAIA agent. | |
| Args: | |
| model_name: OpenAI model to use | |
| api_key: OpenAI API key (or set OPENAI_API_KEY env var) | |
| temperature: Model temperature (0 for deterministic) | |
| max_iterations: Maximum tool-use iterations | |
| """ | |
| self.model_name = model_name | |
| self.max_iterations = max_iterations | |
| self.llm = ChatOpenAI( | |
| model=model_name, | |
| temperature=temperature, | |
| api_key=api_key or os.environ.get("OPENAI_API_KEY") | |
| ) | |
| self.llm_with_tools = self.llm.bind_tools(TOOLS) | |
| self.graph = self._build_graph() | |
| def _build_graph(self) -> StateGraph: | |
| """Construct the LangGraph workflow.""" | |
| workflow = StateGraph(AgentState) | |
| # Define nodes | |
| workflow.add_node("agent", self._agent_node) | |
| workflow.add_node("tools", ToolNode(TOOLS)) | |
| workflow.add_node("extract_answer", self._extract_answer_node) | |
| # Set entry point | |
| workflow.set_entry_point("agent") | |
| # Define edges | |
| workflow.add_conditional_edges( | |
| "agent", | |
| self._route_agent_output, | |
| { | |
| "tools": "tools", | |
| "end": "extract_answer" | |
| } | |
| ) | |
| workflow.add_edge("tools", "agent") | |
| workflow.add_edge("extract_answer", END) | |
| return workflow.compile() | |
| def _agent_node(self, state: AgentState) -> dict: | |
| """Process messages and decide on next action.""" | |
| messages = state["messages"] | |
| iteration = state.get("iteration_count", 0) | |
| # Add iteration warnings earlier to give agent more time to finish | |
| if iteration >= self.max_iterations - 3: | |
| warning_msg = "WARNING: Approaching iteration limit. Please provide your final answer now. Remember: just the answer, no prefix." | |
| messages = list(messages) + [SystemMessage(content=warning_msg)] | |
| elif iteration >= self.max_iterations - 5: | |
| reminder_msg = "Reminder: When you're ready to answer, provide ONLY the final answer with no prefix like 'FINAL ANSWER:' or 'The answer is:'" | |
| messages = list(messages) + [SystemMessage(content=reminder_msg)] | |
| try: | |
| response = self.llm_with_tools.invoke(messages) | |
| except Exception as e: | |
| # Graceful error handling | |
| error_msg = AIMessage(content=f"Error during reasoning: {str(e)}. Please try a different approach or provide your best answer.") | |
| return { | |
| "messages": [error_msg], | |
| "iteration_count": iteration + 1 | |
| } | |
| return { | |
| "messages": [response], | |
| "iteration_count": iteration + 1 | |
| } | |
| def _route_agent_output(self, state: AgentState) -> Literal["tools", "end"]: | |
| """Determine whether to use tools or finish.""" | |
| last_message = state["messages"][-1] | |
| iteration = state.get("iteration_count", 0) | |
| # Force end if max iterations reached | |
| if iteration >= self.max_iterations: | |
| return "end" | |
| # Check if agent wants to use tools | |
| if hasattr(last_message, "tool_calls") and last_message.tool_calls: | |
| return "tools" | |
| return "end" | |
| def _extract_answer_node(self, state: AgentState) -> dict: | |
| """Extract and clean the final answer.""" | |
| last_message = state["messages"][-1] | |
| content = last_message.content if hasattr(last_message, "content") else str(last_message) | |
| answer = self._clean_answer(content) | |
| return {"final_answer": answer} | |
| def _clean_answer(self, raw_answer: str) -> str: | |
| """Clean and format the final answer for exact matching.""" | |
| answer = raw_answer.strip() | |
| # Remove common prefixes (case-insensitive, with variations) | |
| prefixes = [ | |
| "the answer is:", "the answer is", "answer is:", | |
| "answer:", "answer", "answer:", | |
| "final answer:", "final answer", "FINAL ANSWER:", "FINAL ANSWER", | |
| "the final answer is:", "the final answer is", | |
| "result:", "result", "result is:", | |
| "solution:", "solution", "solution is:", | |
| "the solution is:", "the solution is", | |
| "it is", "it's", "that is", "that's", | |
| ] | |
| answer_lower = answer.lower() | |
| for prefix in prefixes: | |
| if answer_lower.startswith(prefix): | |
| answer = answer[len(prefix):].strip() | |
| # Remove any leading colon or dash | |
| answer = answer.lstrip(':').lstrip('-').strip() | |
| answer_lower = answer.lower() | |
| # Remove quotes if they wrap the entire answer | |
| if (answer.startswith('"') and answer.endswith('"')) or \ | |
| (answer.startswith("'") and answer.endswith("'")): | |
| answer = answer[1:-1].strip() | |
| # Remove trailing periods, commas, or semicolons for single-word/number answers | |
| if answer and ' ' not in answer: | |
| answer = answer.rstrip('.,;:') | |
| # Remove leading/trailing whitespace and normalize internal whitespace | |
| answer = ' '.join(answer.split()) | |
| # Remove markdown formatting if present | |
| if answer.startswith('**') and answer.endswith('**'): | |
| answer = answer[2:-2] | |
| if answer.startswith('*') and answer.endswith('*'): | |
| answer = answer[1:-1] | |
| return answer.strip() | |
| def run(self, question: str, task_id: str = "", file_path: str = None) -> str: | |
| """ | |
| Run the agent on a question. | |
| Args: | |
| question: The GAIA question to answer | |
| task_id: Optional task identifier | |
| file_path: Optional path to associated file | |
| Returns: | |
| The agent's final answer | |
| """ | |
| # Prepare the user message with file priority | |
| user_content = question | |
| if file_path and os.path.exists(file_path): | |
| # Strongly emphasize reading the file first | |
| user_content = f"[IMPORTANT: A file is available at {file_path}]\n\nYou MUST read this file FIRST using the read_file tool before attempting to answer. The answer is very likely contained in this file.\n\nQuestion: {question}" | |
| # Initialize state | |
| initial_state: AgentState = { | |
| "messages": [ | |
| SystemMessage(content=SYSTEM_PROMPT), | |
| HumanMessage(content=user_content) | |
| ], | |
| "task_id": task_id, | |
| "file_path": file_path, | |
| "file_content": None, | |
| "iteration_count": 0, | |
| "final_answer": None | |
| } | |
| # Execute the graph | |
| try: | |
| final_state = self.graph.invoke( | |
| initial_state, | |
| {"recursion_limit": self.max_iterations * 2 + 5} | |
| ) | |
| answer = final_state.get("final_answer", "Unable to determine answer") | |
| # Final validation - ensure answer is not empty or error message | |
| if not answer or answer.startswith("Agent error:") or answer.startswith("Unable to determine"): | |
| # Try to extract from last message if available | |
| if final_state.get("messages"): | |
| last_msg = final_state["messages"][-1] | |
| if hasattr(last_msg, "content") and last_msg.content: | |
| answer = self._clean_answer(last_msg.content) | |
| return answer if answer else "Unable to determine answer" | |
| except Exception as e: | |
| # Log error for debugging but return a clean error message | |
| import logging | |
| logging.error(f"Agent execution error: {str(e)}") | |
| return f"Agent error: {str(e)}" | |
| # ============ UTILITY FUNCTIONS ============ | |
def create_agent(
    api_key: str | None = None,
    model: str = "gpt-4o",
    *,
    temperature: float = 0,
    max_iterations: int = 15,
) -> GAIAAgent:
    """Factory function to create a configured agent.

    Args:
        api_key: OpenAI API key; None falls back to the OPENAI_API_KEY env var
            (handled by GAIAAgent).
        model: OpenAI model name.
        temperature: Sampling temperature (keyword-only; default 0 as before).
        max_iterations: Agent step limit (keyword-only; default 15 as before).

    Returns:
        A ready-to-use GAIAAgent.
    """
    return GAIAAgent(
        model_name=model,
        api_key=api_key,
        temperature=temperature,
        max_iterations=max_iterations,
    )