# ==============================================================================
# Imports
# ==============================================================================
import os
import requests
import traceback
import html2text  # For HTML to text conversion
import tempfile  # For file handling tools
import pandas as pd  # For CSV/Excel analysis
import openpyxl  # For Excel analysis
from PIL import Image  # For image text extraction
import pytesseract  # For image text extraction
from urllib.parse import urlparse  # For download tool
from typing import Annotated, List, TypedDict, Optional
from dotenv import load_dotenv
import time  # For adding potential delays if needed later

# LangChain and LangGraph Imports
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ToolMessage
from langchain_core.tools import tool

# LLM Import - Using Groq
from langchain_groq import ChatGroq
from langchain_community.tools.tavily_search import TavilySearchResults

# ==============================================================================
# Environment Setup & LLM
# ==============================================================================
load_dotenv()
tavily_api_key = os.getenv("TAVILY_API_KEY")
groq_api_key = os.getenv("GROQ_API_KEY")

# --- Optional: Tesseract Path ---
# If Tesseract OCR is not in your system's PATH environment variable,
# uncomment the following lines and set the correct path to tesseract.exe
# try:
#     pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'  # Example path for Windows
# except NameError: pass  # Handles case where pytesseract might not be imported yet if PIL fails first
# except Exception as e: print(f"Warning: Could not set tesseract_cmd path: {e}")

# --- Validate API Keys ---
# Fail fast at import time: nothing below can work without both keys.
if not tavily_api_key:
    raise ValueError("TAVILY_API_KEY not found in environment variables/Space secrets.")
if not groq_api_key:
    raise ValueError("GROQ_API_KEY not found in environment variables/Space secrets.")

# --- Initialize LLM (Using Groq) ---
try:
    llm = ChatGroq(
        model="meta-llama/llama-4-maverick-17b-128e-instruct",  # Powerful model available on Groq, good for reasoning
        # model="gemma2-9b-it",  # Alternative lighter model
        api_key=groq_api_key,
        temperature=0.3,  # Low temperature for factual tasks
    )
    print(f"LLM Initialized: Groq - {llm.model_name}")
except Exception as e:
    print(f"ERROR initializing Groq LLM: {e}")
    traceback.print_exc()
    raise  # Stop if LLM fails to init


# ==============================================================================
# State Definition
# ==============================================================================
class AgentState(TypedDict):
    """Defines the structure of the information the agent tracks during its run."""

    input_question: str  # The original question from the benchmark
    messages: Annotated[List[BaseMessage], add_messages]  # History of interactions (Human, AI, Tool)
    error: Optional[str]  # Stores any error message encountered
    iterations: int  # Counter for agent steps to prevent loops


# ==============================================================================
# Tools Definitions
# ==============================================================================
# NOTE: for every @tool function below, the docstring is the description the
# LLM is given for the tool, so it is kept short; maintainer notes are written
# as `#` comments instead.
print("Defining tools...")

# Maximum number of characters of page text handed back to the LLM.
WEB_CONTENT_MAX_CHARS = 6000

# --- Search Tool (Tavily) ---
search_tool = TavilySearchResults(max_results=3, api_key=tavily_api_key)
search_tool.name = "web_search"
search_tool.description = "Performs a web search (using Tavily) to find relevant URLs/snippets for a query."


# --- Web Browser Tool (html2text) ---
@tool
def web_browser(url: str) -> str:
    """Fetches text content from a webpage URL using html2text. Use after 'web_search'."""
    # Returns the page text (truncated to WEB_CONTENT_MAX_CHARS) or a string
    # starting with "Error:"; it never raises, so the agent can react to failures.
    print(f"--- [Tool] Browsing (html2text): {url} ---")
    try:
        headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'}
        response = requests.get(url, headers=headers, timeout=20)
        response.raise_for_status()
        response.encoding = response.apparent_encoding or 'utf-8'

        # Configure html2text: no line wrapping, drop links and images.
        converter = html2text.HTML2Text(bodywidth=0)
        converter.ignore_links = True
        converter.ignore_images = True

        # Strip BEFORE the length check so a whitespace-only page is reported
        # as empty instead of being returned as "truncated content".
        clean_text = converter.handle(response.text).strip()
        if not clean_text:
            return f"Error: No meaningful content via html2text for {url}."
        if len(clean_text) > WEB_CONTENT_MAX_CHARS:
            return clean_text[:WEB_CONTENT_MAX_CHARS] + "\n\n... [Content Truncated]"
        return clean_text
    except requests.exceptions.RequestException as e:
        return f"Error: Network request failed for URL: {url}. Reason: {e}"
    except Exception as e:
        return f"Error: Unexpected error processing URL with html2text: {url}. Reason: {str(e)}"


def _safe_download_name(url: str, filename: Optional[str]) -> str:
    """Return a bare file name (no directory parts) to store a download under."""
    candidate = filename
    if not candidate:
        try:
            candidate = urlparse(url).path
        except ValueError:
            candidate = ""
    # The name can come from the LLM or from a remote URL: keep only the last
    # path component so it cannot point outside the temporary directory.
    candidate = os.path.basename(candidate or "")
    if candidate in ("", ".", ".."):
        import uuid
        candidate = f"downloaded_{uuid.uuid4().hex[:8]}"
    return candidate


# --- File Download Tool ---
@tool
def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
    """Downloads a file from a URL to a temporary directory. Input: file URL. Returns: path to downloaded file or error."""
    print(f"--- [Tool] Downloading file from: {url} ---")
    try:
        filepath = os.path.join(tempfile.gettempdir(), _safe_download_name(url, filename))

        # Stream to disk in chunks; the context manager releases the
        # connection even when writing fails part-way.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()
            with open(filepath, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        print(f"--- [Tool] File downloaded to: {filepath} ---")
        return f"File downloaded to {filepath}. Use appropriate tools (e.g., analyze_csv_file) to process it."
    except requests.exceptions.RequestException as e:
        return f"Error downloading file: Network issue for {url}. Reason: {e}"
    except Exception as e:
        return f"Error downloading file: Unexpected error for {url}. Reason: {str(e)}"


def _summarize_dataframe(df: pd.DataFrame, title: str) -> str:
    """Build the text report (shape, columns, head, numeric stats) shared by the CSV and Excel tools."""
    summary = f"{title}\n"
    summary += f"- Shape: {df.shape[0]} rows, {df.shape[1]} columns\n"
    # Column labels are not always strings (e.g. integer headers), so convert
    # them explicitly; str.join raises TypeError on non-string items.
    summary += f"- Columns: {', '.join(map(str, df.columns))}\n"
    summary += f"\nFirst 5 rows:\n{df.head().to_string()}\n"
    numeric_cols = df.select_dtypes(include=['number'])
    if not numeric_cols.empty:
        summary += f"\nBasic Stats (Numeric):\n{numeric_cols.describe().to_string()}"
    else:
        summary += "\nNo numeric columns for stats."
    return summary


# --- CSV Analysis Tool ---
@tool
def analyze_csv_file(file_path: str) -> str:
    """Analyzes a CSV file at the given path using pandas. Returns a summary of content or error."""
    print(f"--- [Tool] Analyzing CSV: {file_path} ---")
    # GAIA might provide relative paths, ensure they work or adjust logic if needed
    if not os.path.exists(file_path):
        return f"Error: CSV file not found at path: {file_path}"
    try:
        df = pd.read_csv(file_path)
        return _summarize_dataframe(df, f"CSV Analysis Report for {os.path.basename(file_path)}:")
    except ImportError:
        return "Error: 'pandas' required but not installed."
    except Exception as e:
        return f"Error analyzing CSV {file_path}: {str(e)}"


# --- Excel Analysis Tool ---
@tool
def analyze_excel_file(file_path: str) -> str:
    """Analyzes an Excel file (.xlsx, .xls) at the given path. Returns a summary of the first sheet or error."""
    print(f"--- [Tool] Analyzing Excel: {file_path} ---")
    if not os.path.exists(file_path):
        return f"Error: Excel file not found at path: {file_path}"
    try:
        # openpyxl cannot read the legacy .xls format; for that extension let
        # pandas choose the engine itself. Everything else keeps openpyxl.
        extension = os.path.splitext(file_path)[1].lower()
        engine = None if extension == '.xls' else 'openpyxl'
        df = pd.read_excel(file_path, engine=engine)
        return _summarize_dataframe(
            df, f"Excel Analysis Report for {os.path.basename(file_path)} (First Sheet):"
        )
    except ImportError as e:
        # pandas raises ImportError when the engine package is missing.
        return f"Error: missing dependency for reading Excel files: {e}"
    except Exception as e:
        return f"Error analyzing Excel {file_path}: {str(e)}"


# --- Image Text Extraction Tool (OCR) ---
@tool
def extract_text_from_image(file_path: str) -> str:
    """Extracts text from an image file at the given path using Tesseract OCR. Returns extracted text or error."""
    print(f"--- [Tool] Extracting text from image: {file_path} ---")
    if not os.path.exists(file_path):
        return f"Error: Image file not found at path: {file_path}"
    try:
        # Open through a context manager so the file handle is always released.
        with Image.open(file_path) as image:
            text_stripped = pytesseract.image_to_string(image).strip()
        # Return a clear message if no text found, otherwise return extracted text
        if not text_stripped:
            return "No text found in image."
        return f"Extracted text from image '{os.path.basename(file_path)}':\n{text_stripped}"
    except ImportError:
        return "Error: 'Pillow' or 'pytesseract' required but not installed."
    except pytesseract.TesseractNotFoundError:
        return "Error: Tesseract OCR not installed or not in PATH."
    except Exception as e:
        return f"Error extracting text from image {file_path}: {str(e)}"


# --- Basic Math Tools ---
@tool
def add(a: float, b: float) -> float:
    """Adds two numbers (a + b). Handles float inputs."""
    print(f"--- [Tool] Calculating: {a} + {b} ---")
    return a + b


@tool
def subtract(a: float, b: float) -> float:
    """Subtracts the second number from the first (a - b). Handles float inputs."""
    print(f"--- [Tool] Calculating: {a} - {b} ---")
    return a - b


@tool
def multiply(a: float, b: float) -> float:
    """Multiplies two numbers (a * b). Handles float inputs."""
    print(f"--- [Tool] Calculating: {a} * {b} ---")
    return a * b


@tool
def divide(a: float, b: float) -> float | str:
    """Divides the first number by the second (a / b). Handles float inputs and division by zero."""
    print(f"--- [Tool] Calculating: {a} / {b} ---")
    if b == 0:
        return "Error: Cannot divide by zero."
    return a / b


# --- Compile list of all tools ---
tools = [
    search_tool,
    web_browser,
    download_file_from_url,
    analyze_csv_file,
    analyze_excel_file,
    extract_text_from_image,
    add,
    subtract,
    multiply,
    divide,
]

# --- Bind tools to the LLM ---
# `llm` is guaranteed to exist here: a failed initialization re-raises above.
llm_with_tools = llm.bind_tools(tools)
print(f"Agent initialized with {len(tools)} tools.")

# ==============================================================================
# Node Definitions (With Logging Added)
# ==============================================================================
print("Defining graph nodes...")

# Max LLM steps allowed for one task; guards against endless tool loops.
MAX_ITERATIONS = 15


# --- Agent Node ---
def call_agent_node(state: AgentState) -> dict:
    """Invokes the LLM with current state to decide the next step.

    Returns a partial state update: the LLM response appended to ``messages``
    and an incremented ``iterations`` counter, or an ``error`` entry when the
    iteration cap is reached or the LLM call fails.
    """
    # Read the counter defensively once and use it everywhere (including the
    # entry log) so a state without 'iterations' cannot raise KeyError.
    current_iterations = state.get('iterations', 0)

    # --- Logging: Node Entry ---
    print(f"\n>>> Entering Agent Node (Iteration {current_iterations})")

    if current_iterations >= MAX_ITERATIONS:
        print(f"!!! Agent Node: Max iterations ({MAX_ITERATIONS}) reached. Setting error.")
        return {"error": f"Max iterations ({MAX_ITERATIONS}) reached."}

    try:
        print(f"--- Agent Node: Invoking LLM ({llm.model_name})... ---")  # Log before LLM call
        response = llm_with_tools.invoke(state['messages'])
        print("--- Agent Node: LLM Invocation Complete. ---")  # Log after LLM call
        # response.pretty_print()  # Optional: Print formatted LLM response

        # --- Logging: Node Exit (Success) ---
        print(f"<<< Exiting Agent Node (Success, Iteration {current_iterations + 1})")
        return {"messages": [response], "iterations": current_iterations + 1}
    except Exception as e:
        error_message = f"LLM invocation failed: {str(e)}"
        print(f"!!! Agent Node ERROR: {error_message} !!!")
        traceback.print_exc()  # Print full traceback for debugging LLM errors
        # --- Logging: Node Exit (Error) ---
        print(f"<<< Exiting Agent Node (LLM Error, Iteration {current_iterations})")
        # Return an error message and set error state, still increment
        # iteration to prevent infinite error loops
        return {
            "messages": [AIMessage(content=f"Error during LLM call: {error_message}")],
            "error": error_message,
            "iterations": current_iterations + 1,
        }


# --- Tool Node Wrapper (for Logging) ---
# We still use the prebuilt ToolNode, but wrap its call for logging
tool_executor = ToolNode(tools)  # Keep the instance


def logged_tool_node(state: AgentState) -> dict:
    """Logs tool execution start/end and calls the actual ToolNode.

    Returns the ToolNode result (a dict holding ToolMessages). If the ToolNode
    itself raises, returns one error ToolMessage per requested tool call.
    """
    print(">>> Entering Tool Node")

    # Log requested tools
    last_message = state['messages'][-1]
    tool_calls = getattr(last_message, "tool_calls", None) or []
    requested_tools_str = ", ".join(tc.get('name', 'unknown') for tc in tool_calls) if tool_calls else "None"
    print(f"--- Tool Node: Executing tools: {requested_tools_str} ---")
    if tool_calls:
        print(f"--- Tool Node: Tool Args: {[tc.get('args') for tc in tool_calls]} ---")

    try:
        # Call the actual ToolNode instance
        result = tool_executor.invoke(state)

        # Log truncated results
        print("--- Tool Node: Tool Execution Results ---")
        if isinstance(result.get("messages"), list):
            for msg in result["messages"]:
                if isinstance(msg, ToolMessage):
                    print(f"  - Tool: {msg.name}, Result (start): {str(msg.content)[:200]}...")  # Slightly more context
        print("<<< Exiting Tool Node (Success)")
        return result  # Return the dictionary containing ToolMessages
    except Exception as e:
        error_message = f"ToolNode invocation exception: {str(e)}"
        print(f"!!! Tool Node ERROR: {error_message} !!!")
        traceback.print_exc()
        print("<<< Exiting Tool Node (Error)")
        # Answer every requested call with its real id: a tool call left
        # without a matching ToolMessage makes the follow-up LLM request
        # malformed. Fall back to a placeholder id only if none is available.
        call_ids = [tc.get('id') for tc in tool_calls if tc.get('id')] or ["tool_node_error"]
        return {
            "messages": [
                ToolMessage(content=error_message, tool_call_id=call_id)
                for call_id in call_ids
            ]
        }


# ==============================================================================
# Graph Construction (Non-conversational, using logged tool node)
# ==============================================================================
print("Building agent graph...")
builder = StateGraph(AgentState)
builder.add_node("agent", call_agent_node)
builder.add_node("tools", logged_tool_node)  # Use the logging wrapper node
builder.add_edge(START, "agent")
builder.add_conditional_edges("agent", tools_condition, {"tools": "tools", END: END})
builder.add_edge("tools", "agent")

# Compile the graph globally so it's ready for the function call
try:
    graph = builder.compile()
    print("GAIA agent graph compiled successfully.")
except Exception as e:
    print(f"ERROR: Failed to compile LangGraph graph: {e}")
    traceback.print_exc()
    graph = None  # Ensure graph is None if compilation fails
    raise  # Reraise exception to make startup failure clear

# ==============================================================================
# Main Execution Function for GAIA Benchmark <<<< WRAPPER FUNCTION >>>>
# ==============================================================================
# Upper bound on graph steps for one run. Every agent iteration costs two steps
# (agent node + tools node) and the agent node stops itself after 15
# iterations, so the limit must exceed 2 * 15 + 1; otherwise LangGraph aborts
# the run with a recursion error before that guard can ever trigger.
GRAPH_RECURSION_LIMIT = 50


def answer_gaia_task(question: str, file_path: Optional[str] = None) -> str:
    """
    Runs the compiled GAIA agent graph for a given question and optional file path.
    This is the main entry point expected by the benchmark runner.

    Args:
        question: The task question to answer.
        file_path: Optional path of a file associated with the task; when
            given, it is mentioned in the prompt so the agent can use it.

    Returns:
        The agent's final answer, or a string starting with "Error:" when the
        run failed. Exceptions from the graph are caught, never propagated.
    """
    # Check if graph compilation was successful
    if graph is None:
        return "Error: Agent graph was not compiled successfully during setup."

    print(f"\n{'='*20} Running Agent for GAIA Task {'='*20}")
    print(f"Question: {question}")
    file_context_info = f"An associated file is provided at path: '{file_path}'. Use this path if relevant." if file_path else ""

    # Define the initial prompt sent to the agent, incorporating strict formatting rules
    prompt_content = f"""Your task is to accurately answer the following question based *only* on information obtained using your tools (web search, web browser, file download, csv/excel analysis, image OCR, math).
{file_context_info}
Follow these steps methodically:
1. Analyze the question: {question}
2. Use tools ONLY if necessary to gather the specific information required. Assume local file paths mentioned (like 'data.csv') are accessible.
3. Synthesize the final answer from the gathered information.

**CRITICAL OUTPUT FORMATTING RULES:**
* Your final response MUST be ONLY the answer, without any other text/explanations.
* **Numbers:** No commas (1000). No units ($ , %) unless asked.
* **Strings:** No articles (a, an, the) unless proper noun. No abbreviations (Saint Petersburg) unless answer is abbreviation. Use numerals (5).
* **Lists:** Comma-separated (apple,banana,cherry). Apply number/string rules to elements.
* If answer not found, output only the exact phrase: Information not found

Provide ONLY the final answer according to these rules.
"""

    # Create the initial state for the graph run
    initial_state = AgentState(
        input_question=question,
        messages=[HumanMessage(content=prompt_content)],
        error=None,
        iterations=0,
    )

    final_answer = "Error: Agent execution did not complete successfully."  # Default fallback
    try:
        # Invoke the compiled graph
        final_state = graph.invoke(initial_state, config={"recursion_limit": GRAPH_RECURSION_LIMIT})

        # Process the final state to extract the answer
        if final_state:
            messages = final_state.get('messages')
            # Prioritize showing agent error if one occurred
            if final_state.get("error"):
                print(f"--- Agent stopped due to ERROR: {final_state['error']} ---")
                final_answer = f"Error: {final_state['error']}"
            # Otherwise, try to get the last AI message content
            elif messages and isinstance(messages[-1], AIMessage):
                potential_answer = messages[-1].content
                if isinstance(potential_answer, str):
                    # The benchmark compares answers literally, so drop
                    # surrounding whitespace and one pair of wrapping quotes.
                    potential_answer = potential_answer.strip()
                    # Require two characters so a lone quote character is not
                    # mistaken for an (empty) quoted answer.
                    if (
                        len(potential_answer) >= 2
                        and potential_answer[0] == potential_answer[-1]
                        and potential_answer[0] in ('"', "'")
                    ):
                        potential_answer = potential_answer[1:-1].strip()
                print(f"--- Final Answer (from AI): {potential_answer} ---")
                final_answer = potential_answer
            else:
                print("--- Could not determine final answer (last message not AI or missing). Check logs. ---")
            # Log final state details for debugging
            print(f"Final State: Error={final_state.get('error')}, Iterations={final_state.get('iterations')}")
    except Exception as e:
        print(f"--- Graph execution failed unexpectedly: {e} ---")
        traceback.print_exc()
        final_answer = f"Error: Graph execution failed - {str(e)}"

    print(f"{'='*20} Agent Run Finished {'='*20}")
    # Return the final answer string
    return str(final_answer)


# ==============================================================================
# Local Testing Block (Optional)
# ==============================================================================
# This block allows you to test the agent by running final_agent.py directly.
if __name__ == "__main__":
    print("\n--- Running Local Test ---")

    # --- Define Test Question ---
    test_question = "What is the result of multiplying the number of rows (excluding the header) in 'data.csv' by the number found after the phrase 'total items:' in 'image.png'?"

    try:
        # --- Create Dummy Files for Local Test ---
        print("Creating dummy files for local test...")
        dummy_files_created = True
        try:
            # Dummy CSV with 3 data rows + header
            with open("data.csv", "w") as f:
                f.write("Header1,Header2\nRow1Val1,Row1Val2\nRow2Val1,Row2Val2\nRow3Val1,Row3Val2")

            # Dummy Image containing the required text
            try:
                img = Image.new('RGB', (300, 50), color=(255, 255, 255))  # White background
                from PIL import ImageDraw, ImageFont  # Import drawing tools locally
                draw = ImageDraw.Draw(img)
                # Use a basic font if specific ones aren't found
                try:
                    font = ImageFont.truetype("arial.ttf", 15)
                except IOError:
                    font = ImageFont.load_default()
                draw.text((10, 10), "Some random info... total items: 7 ... more text", fill=(0, 0, 0), font=font)  # Black text
                img.save("image.png")
                print("Dummy data.csv and image.png created successfully.")
            except ImportError:
                print("Pillow/ImageDraw/ImageFont not installed. Cannot create dummy image file.")
                dummy_files_created = False
            except Exception as img_e:
                print(f"Error creating dummy image: {img_e}")
                dummy_files_created = False
        except Exception as file_e:
            print(f"Error creating dummy files: {file_e}")
            dummy_files_created = False
        # ---------------------------------------------

        # --- Run the Test ---
        if dummy_files_created:
            # Call the main function, simulating how the benchmark runner would call it.
            result = answer_gaia_task(question=test_question, file_path=None)
            print("\n--- Local Test Result ---")
            print(f"Returned Answer: {result}")
            print("Expected Answer (for dummy files): 21")  # 3 data rows * 7 = 21
        else:
            print("Skipping test execution due to issues creating dummy files.")
    finally:
        # --- Clean up Dummy Files ---
        # Runs in `finally` so the files are removed even if the test is
        # interrupted (e.g. Ctrl+C) while the agent is running.
        print("\nCleaning up dummy files...")
        for dummy_file in ["data.csv", "image.png"]:
            if os.path.exists(dummy_file):
                try:
                    os.remove(dummy_file)
                except Exception as e:
                    print(f"Could not remove {dummy_file}: {e}")
        print("Dummy file cleanup attempted.")