Spaces:
Sleeping
Sleeping
| # ============================================================================== | |
| # Imports | |
| # ============================================================================== | |
| import os | |
| import requests | |
| import traceback | |
| import html2text # For HTML to text conversion | |
| import tempfile # For file handling tools | |
| import pandas as pd # For CSV/Excel analysis | |
| import openpyxl # For Excel analysis | |
| from PIL import Image # For image text extraction | |
| import pytesseract # For image text extraction | |
| from urllib.parse import urlparse # For download tool | |
| from typing import Annotated, List, TypedDict, Optional | |
| from dotenv import load_dotenv | |
| import time # For adding potential delays if needed later | |
| # LangChain and LangGraph Imports | |
| from langgraph.graph import StateGraph, START, END | |
| from langgraph.graph.message import add_messages | |
| from langgraph.prebuilt import ToolNode, tools_condition | |
| from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ToolMessage | |
| from langchain_core.tools import tool | |
| # LLM Import - Using Groq | |
| from langchain_groq import ChatGroq | |
| from langchain_community.tools.tavily_search import TavilySearchResults | |
| # ============================================================================== | |
| # Environment Setup & LLM | |
| # ============================================================================== | |
load_dotenv()

# Credentials are read from the process environment (a local .env file loaded
# above, or the Space secrets).
tavily_api_key = os.environ.get("TAVILY_API_KEY")
groq_api_key = os.environ.get("GROQ_API_KEY")

# --- Optional: Tesseract Path ---
# If the Tesseract OCR binary is not on the PATH, uncomment the lines below
# and point tesseract_cmd at the executable (the example path is for Windows).
# try:
#     pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'
# except NameError:
#     pass  # pytesseract might not be imported yet if PIL fails first
# except Exception as e:
#     print(f"Warning: Could not set tesseract_cmd path: {e}")

# --- Validate API Keys ---
# Fail at import time: nothing below can work without both keys.
if not tavily_api_key:
    raise ValueError("TAVILY_API_KEY not found in environment variables/Space secrets.")
if not groq_api_key:
    raise ValueError("GROQ_API_KEY not found in environment variables/Space secrets.")

# --- Initialize LLM (Using Groq) ---
try:
    llm = ChatGroq(
        # Alternative lighter model: "gemma2-9b-it"
        model="meta-llama/llama-4-maverick-17b-128e-instruct",
        api_key=groq_api_key,
        temperature=0.3,  # Low temperature for factual tasks
    )
    print(f"LLM Initialized: Groq - {llm.model_name}")
except Exception as init_error:
    # Log with a traceback, then re-raise so startup stops without an LLM.
    print(f"ERROR initializing Groq LLM: {init_error}")
    traceback.print_exc()
    raise
| # ============================================================================== | |
| # State Definition | |
| # ============================================================================== | |
class AgentState(TypedDict):
    """State dictionary passed between the nodes of the agent graph."""

    # The original question from the benchmark.
    input_question: str
    # History of interactions (Human, AI, Tool); node updates are combined
    # through the add_messages reducer named in the annotation.
    messages: Annotated[List[BaseMessage], add_messages]
    # Any error message encountered during the run, otherwise None.
    error: Optional[str]
    # Counter of agent steps, used to prevent endless loops.
    iterations: int
| # ============================================================================== | |
| # Tools Definitions | |
| # ============================================================================== | |
print("Defining tools...")

# --- Search Tool (Tavily) ---
# Returns at most three results per query; the name and description are
# overridden after construction so the agent sees it as "web_search".
search_tool = TavilySearchResults(max_results=3, api_key=tavily_api_key)
search_tool.name = "web_search"
search_tool.description = (
    "Performs a web search (using Tavily) to find relevant URLs/snippets for a query."
)
| # --- Web Browser Tool (html2text) --- | |
def web_browser(url: str) -> str:
    """Fetches text content from a webpage URL using html2text. Use after 'web_search'."""
    # NOTE(review): this function is passed to bind_tools/ToolNode, so the
    # docstring above is presumably what the LLM sees as the tool description;
    # its wording is intentionally left unchanged.
    #
    # Returns the page as plain text (links and images dropped), capped at
    # max_length characters, or a string starting with "Error:" on failure.
    print(f"--- [Tool] Browsing (html2text): {url} ---")
    max_length = 6000  # upper bound on the number of characters returned
    try:
        headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'}
        response = requests.get(url, headers=headers, timeout=20)
        response.raise_for_status()
        response.encoding = response.apparent_encoding or 'utf-8'

        # Configure html2text: no hard wrapping, no link or image markup.
        converter = html2text.HTML2Text(bodywidth=0)
        converter.ignore_links = True
        converter.ignore_images = True

        # Strip BEFORE measuring so leading/trailing padding neither eats into
        # the length budget nor lets a whitespace-only page bypass the
        # empty-content check below.
        page_text = converter.handle(response.text).strip()
        if not page_text:
            return f"Error: No meaningful content via html2text for {url}."
        if len(page_text) > max_length:
            return page_text[:max_length] + "\n\n... [Content Truncated]"
        return page_text
    except requests.exceptions.RequestException as e:
        return f"Error: Network request failed for URL: {url}. Reason: {e}"
    except Exception as e:
        return f"Error: Unexpected error processing URL with html2text: {url}. Reason: {str(e)}"
| # --- File Download Tool --- | |
def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
    """Downloads a file from a URL to a temporary directory. Input: file URL. Returns: path to downloaded file or error."""
    # NOTE(review): the docstring above is presumably used as the tool
    # description shown to the LLM, so its wording is left unchanged.
    #
    # `filename` is optional; when omitted the name is taken from the URL path,
    # and a random "downloaded_<hex>" name is used as the last resort.
    # All failures are reported as a string starting with "Error".
    print(f"--- [Tool] Downloading file from: {url} ---")
    try:
        # Reduce any supplied name to its final path component so a value such
        # as "../../etc/passwd" cannot place the file outside the temp dir.
        safe_name = ""
        if filename:
            safe_name = os.path.basename(filename.replace("\\", "/"))
        if safe_name in ("", ".", ".."):
            try:
                safe_name = os.path.basename(urlparse(url).path)
            except ValueError:
                safe_name = ""
        if safe_name in ("", ".", ".."):
            import uuid
            safe_name = f"downloaded_{uuid.uuid4().hex[:8]}"

        # Define save path
        filepath = os.path.join(tempfile.gettempdir(), safe_name)

        # Download file; the context manager releases the streamed connection
        # even when writing fails part-way through.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()
            with open(filepath, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip empty keep-alive chunks
                        f.write(chunk)
        print(f"--- [Tool] File downloaded to: {filepath} ---")
        return f"File downloaded to {filepath}. Use appropriate tools (e.g., analyze_csv_file) to process it."
    except requests.exceptions.RequestException as e:
        return f"Error downloading file: Network issue for {url}. Reason: {e}"
    except Exception as e:
        return f"Error downloading file: Unexpected error for {url}. Reason: {str(e)}"
| # --- CSV Analysis Tool --- | |
def analyze_csv_file(file_path: str) -> str:
    """Analyzes a CSV file at the given path using pandas. Returns a summary of content or error."""
    # NOTE(review): the docstring above is presumably used as the tool
    # description shown to the LLM, so its wording is left unchanged.
    #
    # The report lists the shape, the column names, the first five rows and,
    # when numeric columns exist, their describe() statistics.
    print(f"--- [Tool] Analyzing CSV: {file_path} ---")
    # GAIA might provide relative paths, ensure they work or adjust logic if needed
    if not os.path.exists(file_path):
        return f"Error: CSV file not found at path: {file_path}"
    try:
        frame = pd.read_csv(file_path)
        row_count, column_count = frame.shape
        report_parts = [
            f"CSV Analysis Report for {os.path.basename(file_path)}:\n",
            f"- Shape: {row_count} rows, {column_count} columns\n",
            f"- Columns: {', '.join(frame.columns)}\n",
            f"\nFirst 5 rows:\n{frame.head().to_string()}\n",
        ]
        numeric_part = frame.select_dtypes(include=['number'])
        if numeric_part.empty:
            report_parts.append("\nNo numeric columns for stats.")
        else:
            report_parts.append(f"\nBasic Stats (Numeric):\n{numeric_part.describe().to_string()}")
        return "".join(report_parts)
    except ImportError:
        return "Error: 'pandas' required but not installed."
    except Exception as exc:
        return f"Error analyzing CSV {file_path}: {str(exc)}"
| # --- Excel Analysis Tool --- | |
def analyze_excel_file(file_path: str) -> str:
    """Analyzes an Excel file (.xlsx, .xls) at the given path. Returns a summary of the first sheet or error."""
    # NOTE(review): the docstring above is presumably used as the tool
    # description shown to the LLM, so its wording is left unchanged.
    #
    # The report lists the shape, the column labels, the first five rows and,
    # when numeric columns exist, their describe() statistics.
    print(f"--- [Tool] Analyzing Excel: {file_path} ---")
    if not os.path.exists(file_path):
        return f"Error: Excel file not found at path: {file_path}"
    try:
        # openpyxl cannot read the legacy .xls format; for that extension let
        # pandas choose the engine instead of forcing openpyxl.
        extension = os.path.splitext(file_path)[1].lower()
        engine = None if extension == ".xls" else "openpyxl"
        df = pd.read_excel(file_path, engine=engine)

        # Header cells may be numbers or dates, so convert each label to text
        # before joining (str.join rejects non-string items).
        column_labels = ", ".join(str(label) for label in df.columns)

        # Generate summary string
        summary = f"Excel Analysis Report for {os.path.basename(file_path)} (First Sheet):\n"
        summary += f"- Shape: {df.shape[0]} rows, {df.shape[1]} columns\n"
        summary += f"- Columns: {column_labels}\n"
        summary += f"\nFirst 5 rows:\n{df.head().to_string()}\n"
        numeric_cols = df.select_dtypes(include=['number'])
        if not numeric_cols.empty:
            summary += f"\nBasic Stats (Numeric):\n{numeric_cols.describe().to_string()}"
        else:
            summary += "\nNo numeric columns for stats."
        return summary
    except ImportError as e:
        return f"Error: 'pandas' and an Excel engine ('openpyxl', or 'xlrd' for .xls) required but not installed. Reason: {e}"
    except Exception as e:
        return f"Error analyzing Excel {file_path}: {str(e)}"
| # --- Image Text Extraction Tool (OCR) --- | |
def extract_text_from_image(file_path: str) -> str:
    """Extracts text from an image file at the given path using Tesseract OCR. Returns extracted text or error."""
    # NOTE(review): the docstring above is presumably used as the tool
    # description shown to the LLM, so its wording is left unchanged.
    print(f"--- [Tool] Extracting text from image: {file_path} ---")
    if not os.path.exists(file_path):
        return f"Error: Image file not found at path: {file_path}"
    try:
        # Open the image in a context manager so its file handle is released
        # as soon as OCR has finished, instead of lingering until GC.
        with Image.open(file_path) as image:
            text = pytesseract.image_to_string(image)
        text_stripped = text.strip()
        # Return a clear message if no text found, otherwise the extracted text.
        if not text_stripped:
            return "No text found in image."
        return f"Extracted text from image '{os.path.basename(file_path)}':\n{text_stripped}"
    except ImportError:
        return "Error: 'Pillow' or 'pytesseract' required but not installed."
    except pytesseract.TesseractNotFoundError:
        return "Error: Tesseract OCR not installed or not in PATH."
    except Exception as e:
        return f"Error extracting text from image {file_path}: {str(e)}"
| # --- Basic Math Tools --- | |
def add(a: float, b: float) -> float:
    """Adds two numbers (a + b). Handles float inputs."""
    # Docstring kept verbatim: it is presumably surfaced to the LLM as the
    # tool description. The operation is logged before it is evaluated.
    print(f"--- [Tool] Calculating: {a} + {b} ---")
    total = a + b
    return total
def subtract(a: float, b: float) -> float:
    """Subtracts the second number from the first (a - b). Handles float inputs."""
    # Docstring kept verbatim: it is presumably surfaced to the LLM as the
    # tool description. The operation is logged before it is evaluated.
    print(f"--- [Tool] Calculating: {a} - {b} ---")
    difference = a - b
    return difference
def multiply(a: float, b: float) -> float:
    """Multiplies two numbers (a * b). Handles float inputs."""
    # Docstring kept verbatim: it is presumably surfaced to the LLM as the
    # tool description. The operation is logged before it is evaluated.
    print(f"--- [Tool] Calculating: {a} * {b} ---")
    product = a * b
    return product
| def divide(a: float, b: float) -> float | str: | |
| """Divides the first number by the second (a / b). Handles float inputs and division by zero.""" | |
| print(f"--- [Tool] Calculating: {a} / {b} ---") | |
| if b == 0: return "Error: Cannot divide by zero." | |
| return a / b | |
| # --- Compile list of all tools --- | |
tools = [
    search_tool,
    web_browser,
    download_file_from_url,
    analyze_csv_file,
    analyze_excel_file,
    extract_text_from_image,
    add,
    subtract,
    multiply,
    divide,
]

# --- Bind tools to the LLM ---
# Guard against the LLM never having been created before trying to bind.
if 'llm' not in globals():
    raise RuntimeError("LLM was not initialized successfully before binding tools.")
llm_with_tools = llm.bind_tools(tools)
print(f"Agent initialized with {len(tools)} tools.")
| # ============================================================================== | |
| # Node Definitions (With Logging Added) | |
| # ============================================================================== | |
print("Defining graph nodes...")


# --- Agent Node ---
def call_agent_node(state: AgentState) -> dict:
    """Invoke the tool-bound LLM on the message history to decide the next step.

    Args:
        state: Current agent state; reads 'messages' and 'iterations'.

    Returns:
        A partial state update:
        - {"error": ...} when the iteration cap is hit or tools are not bound;
        - {"messages": [response], "iterations": n + 1} on success;
        - an AIMessage describing the failure plus "error" and the bumped
          counter when the LLM call raises.
    """
    MAX_ITERATIONS = 15  # Max steps allowed for the task

    # Read the counter once, tolerating a missing or None value, and use that
    # same value for logging and for the cap check (indexing the key directly
    # would raise KeyError on a state without it).
    current_iterations = state.get('iterations') or 0

    # --- Logging: Node Entry ---
    print(f"\n>>> Entering Agent Node (Iteration {current_iterations})")
    if current_iterations >= MAX_ITERATIONS:
        print(f"!!! Agent Node: Max iterations ({MAX_ITERATIONS}) reached. Setting error.")
        return {"error": f"Max iterations ({MAX_ITERATIONS}) reached."}
    try:
        print(f"--- Agent Node: Invoking LLM ({llm.model_name})... ---")
        # Ensure LLM is bound with tools before invoking
        if 'llm_with_tools' not in globals():
            return {"error": "LLM tools not bound."}
        response = llm_with_tools.invoke(state['messages'])
        print("--- Agent Node: LLM Invocation Complete. ---")
        # --- Logging: Node Exit (Success) ---
        print(f"<<< Exiting Agent Node (Success, Iteration {current_iterations + 1})")
        return {"messages": [response], "iterations": current_iterations + 1}
    except Exception as e:
        error_message = f"LLM invocation failed: {str(e)}"
        print(f"!!! Agent Node ERROR: {error_message} !!!")
        traceback.print_exc()  # Full traceback for debugging LLM errors
        # --- Logging: Node Exit (Error) ---
        print(f"<<< Exiting Agent Node (LLM Error, Iteration {current_iterations})")
        # Still increment the counter so repeated failures cannot loop forever.
        return {
            "messages": [AIMessage(content=f"Error during LLM call: {error_message}")],
            "error": error_message,
            "iterations": current_iterations + 1,
        }
| # --- Tool Node Wrapper (for Logging) --- | |
| # We still use the prebuilt ToolNode, but wrap its call for logging | |
tool_executor = ToolNode(tools)  # Keep the instance


def logged_tool_node(state: AgentState) -> dict:
    """Run the requested tools through the shared ToolNode, logging before and after.

    Args:
        state: Current agent state; the last entry of 'messages' is inspected
            for tool calls.

    Returns:
        The ToolNode result (a dict whose 'messages' holds the ToolMessages),
        or, if the ToolNode itself raises, one error ToolMessage per requested
        tool call.
    """
    print(">>> Entering Tool Node")

    # Log requested tools
    last_message = state['messages'][-1]
    tool_calls = getattr(last_message, "tool_calls", None) or []
    if tool_calls:
        requested_tools_str = ", ".join(str(tc.get('name', 'unknown')) for tc in tool_calls)
    else:
        requested_tools_str = "None"
    print(f"--- Tool Node: Executing tools: {requested_tools_str} ---")
    if tool_calls:
        print(f"--- Tool Node: Tool Args: {[tc.get('args') for tc in tool_calls]} ---")

    try:
        # Call the actual ToolNode instance
        result = tool_executor.invoke(state)

        # Log truncated results
        print("--- Tool Node: Tool Execution Results ---")
        if isinstance(result.get("messages"), list):
            for msg in result["messages"]:
                if isinstance(msg, ToolMessage):
                    print(f" - Tool: {msg.name}, Result (start): {str(msg.content)[:200]}...")
        print("<<< Exiting Tool Node (Success)")
        return result  # Dictionary containing the ToolMessages
    except Exception as e:
        error_message = f"ToolNode invocation exception: {str(e)}"
        print(f"!!! Tool Node ERROR: {error_message} !!!")
        traceback.print_exc()
        print("<<< Exiting Tool Node (Error)")
        # Answer every pending tool call with its own id: a single message with
        # a made-up id would leave the real calls unanswered when the history
        # is sent back to the LLM.
        if tool_calls:
            return {
                "messages": [
                    ToolMessage(
                        content=error_message,
                        name=tc.get("name"),
                        tool_call_id=tc.get("id") or "tool_node_error",
                    )
                    for tc in tool_calls
                ]
            }
        # No tool calls to answer: keep the generic placeholder id.
        return {"messages": [ToolMessage(content=error_message, tool_call_id="tool_node_error")]}
| # ============================================================================== | |
| # Graph Construction (Non-conversational, using logged tool node) | |
| # ============================================================================== | |
print("Building agent graph...")
builder = StateGraph(AgentState)

# Two nodes: the LLM step and the logging wrapper around the tool executor.
builder.add_node("agent", call_agent_node)
builder.add_node("tools", logged_tool_node)

# Flow: START -> agent; after the agent, tools_condition selects either the
# "tools" node or END; tool results always go back to the agent.
builder.add_edge(START, "agent")
builder.add_conditional_edges("agent", tools_condition, {"tools": "tools", END: END})
builder.add_edge("tools", "agent")

# Compile the graph globally so it's ready for the function call
try:
    graph = builder.compile()
    print("GAIA agent graph compiled successfully.")
except Exception as compile_error:
    print(f"ERROR: Failed to compile LangGraph graph: {compile_error}")
    traceback.print_exc()
    graph = None  # Ensure graph is None if compilation fails
    raise  # Re-raise to make the startup failure clear
| # ============================================================================== | |
| # Main Execution Function for GAIA Benchmark <<<< WRAPPER FUNCTION >>>> | |
| # ============================================================================== | |
def answer_gaia_task(question: str, file_path: Optional[str] = None) -> str:
    """
    Run the compiled GAIA agent graph for one question and return the answer text.

    This is the main entry point expected by the benchmark runner.

    Args:
        question: The benchmark question, inserted verbatim into the prompt.
        file_path: Optional path of an associated file; when given it is
            mentioned in the prompt so the agent can pass it to its tools.

    Returns:
        The content of the final AI message (surrounding whitespace and one
        pair of wrapping quotes removed), or a string starting with "Error:"
        when the graph is unavailable, the run recorded an error, or
        execution raised.
    """
    # Check if graph compilation was successful
    if graph is None:
        return "Error: Agent graph was not compiled successfully during setup."
    print(f"\n{'='*20} Running Agent for GAIA Task {'='*20}")
    print(f"Question: {question}")
    file_context_info = f"An associated file is provided at path: '{file_path}'. Use this path if relevant." if file_path else ""
    # Initial prompt sent to the agent, incorporating strict formatting rules
    prompt_content = f"""Your task is to accurately answer the following question based *only* on information obtained using your tools (web search, web browser, file download, csv/excel analysis, image OCR, math).
{file_context_info}
Follow these steps methodically:
1. Analyze the question: {question}
2. Use tools ONLY if necessary to gather the specific information required. Assume local file paths mentioned (like 'data.csv') are accessible.
3. Synthesize the final answer from the gathered information.
**CRITICAL OUTPUT FORMATTING RULES:**
* Your final response MUST be ONLY the answer, without any other text/explanations.
* **Numbers:** No commas (1000). No units ($ , %) unless asked.
* **Strings:** No articles (a, an, the) unless proper noun. No abbreviations (Saint Petersburg) unless answer is abbreviation. Use numerals (5).
* **Lists:** Comma-separated (apple,banana,cherry). Apply number/string rules to elements.
* If answer not found, output only the exact phrase: Information not found
Provide ONLY the final answer according to these rules.
"""
    # Create the initial state for the graph run
    initial_state = AgentState(
        input_question=question,
        messages=[HumanMessage(content=prompt_content)],
        error=None,
        iterations=0,
    )
    final_answer = "Error: Agent execution did not complete successfully."  # Default fallback

    # One agent iteration can cost two graph steps (agent node + tools node).
    # The agent node caps itself at 15 iterations, so the graph needs room for
    # more than 2 * 15 steps; otherwise the recursion limit aborts the run
    # before that cap can ever produce its own error message.
    recursion_limit = 40
    try:
        # Invoke the compiled graph
        final_state = graph.invoke(initial_state, config={"recursion_limit": recursion_limit})
        # Process the final state to extract the answer
        if final_state:
            messages = final_state.get('messages')
            # Prioritize showing agent error if one occurred
            if final_state.get("error"):
                print(f"--- Agent stopped due to ERROR: {final_state['error']} ---")
                final_answer = f"Error: {final_state['error']}"
            # Otherwise, try to get the last AI message content
            elif messages and isinstance(messages[-1], AIMessage):
                potential_answer = messages[-1].content
                if isinstance(potential_answer, str):
                    # Trim first so quotes followed by a stray newline or
                    # space are still recognised as wrapping quotes.
                    potential_answer = potential_answer.strip()
                    if (
                        len(potential_answer) >= 2
                        and potential_answer[0] == potential_answer[-1]
                        and potential_answer[0] in ('"', "'")
                    ):
                        potential_answer = potential_answer[1:-1].strip()
                print(f"--- Final Answer (from AI): {potential_answer} ---")
                final_answer = potential_answer
            else:
                print("--- Could not determine final answer (last message not AI or missing). Check logs. ---")
            # Log final state details for debugging
            print(f"Final State: Error={final_state.get('error')}, Iterations={final_state.get('iterations')}")
    except Exception as e:
        print(f"--- Graph execution failed unexpectedly: {e} ---")
        traceback.print_exc()
        final_answer = f"Error: Graph execution failed - {str(e)}"
    print(f"{'='*20} Agent Run Finished {'='*20}")
    # Return the final answer string
    return str(final_answer)
| # ============================================================================== | |
| # Local Testing Block (Optional) | |
| # ============================================================================== | |
| # This block allows you to test the agent by running final_agent.py directly. | |
if __name__ == "__main__":

    def _create_dummy_inputs() -> bool:
        """Write data.csv and image.png to the working directory; True only if both succeed."""
        try:
            # Dummy CSV with 3 data rows + header
            with open("data.csv", "w") as csv_handle:
                csv_handle.write("Header1,Header2\nRow1Val1,Row1Val2\nRow2Val1,Row2Val2\nRow3Val1,Row3Val2")
            # Dummy image containing the required text
            try:
                canvas = Image.new('RGB', (300, 50), color=(255, 255, 255))  # White background
                from PIL import ImageDraw, ImageFont  # Drawing tools, imported locally
                pen = ImageDraw.Draw(canvas)
                # Fall back to the built-in font when arial.ttf is unavailable
                try:
                    label_font = ImageFont.truetype("arial.ttf", 15)
                except IOError:
                    label_font = ImageFont.load_default()
                pen.text((10, 10), "Some random info... total items: 7 ... more text", fill=(0, 0, 0), font=label_font)  # Black text
                canvas.save("image.png")
                print("Dummy data.csv and image.png created successfully.")
            except ImportError:
                print("Pillow/ImageDraw/ImageFont not installed. Cannot create dummy image file.")
                return False
            except Exception as img_e:
                print(f"Error creating dummy image: {img_e}")
                return False
        except Exception as file_e:
            print(f"Error creating dummy files: {file_e}")
            return False
        return True

    print("\n--- Running Local Test ---")
    # --- Define Test Question ---
    test_question = "What is the result of multiplying the number of rows (excluding the header) in 'data.csv' by the number found after the phrase 'total items:' in 'image.png'?"

    # --- Create Dummy Files for Local Test ---
    print("Creating dummy files for local test...")
    dummy_files_created = _create_dummy_inputs()

    # --- Run the Test ---
    if not dummy_files_created:
        print("Skipping test execution due to issues creating dummy files.")
    else:
        # Call the main function the way the benchmark runner would.
        result = answer_gaia_task(question=test_question, file_path=None)
        print("\n--- Local Test Result ---")
        print(f"Returned Answer: {result}")
        print("Expected Answer (for dummy files): 21")  # 3 data rows * 7 = 21

    # --- Clean up Dummy Files ---
    print("\nCleaning up dummy files...")
    for leftover in ("data.csv", "image.png"):
        if not os.path.exists(leftover):
            continue
        try:
            os.remove(leftover)
        except Exception as e:
            print(f"Could not remove {leftover}: {e}")
    print("Dummy file cleanup attempted.")