# Source: Final_Assignment_Template / final_agent.py (Hugging Face Space)
# Last commit: 673dfa2 (verified) "Update final_agent.py" by Macmill — 23.3 kB
# ==============================================================================
# Imports
# ==============================================================================
import os
import requests
import traceback
import html2text # For HTML to text conversion
import tempfile # For file handling tools
import pandas as pd # For CSV/Excel analysis
import openpyxl # For Excel analysis
from PIL import Image # For image text extraction
import pytesseract # For image text extraction
from urllib.parse import urlparse # For download tool
from typing import Annotated, List, TypedDict, Optional
from dotenv import load_dotenv
import time # For adding potential delays if needed later
# LangChain and LangGraph Imports
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ToolMessage
from langchain_core.tools import tool
# LLM Import - Using Groq
from langchain_groq import ChatGroq
from langchain_community.tools.tavily_search import TavilySearchResults
# ==============================================================================
# Environment Setup & LLM
# ==============================================================================
# Load variables from a local .env file (if present) into the process environment.
load_dotenv()
tavily_api_key = os.getenv("TAVILY_API_KEY")
groq_api_key = os.getenv("GROQ_API_KEY")

# --- Optional: Tesseract Path ---
# If the Tesseract OCR binary is not on your PATH, set TESSERACT_CMD (environment
# variable, Space secret, or .env entry) to its full path, for example:
#   TESSERACT_CMD=C:\Program Files\Tesseract-OCR\tesseract.exe
# When TESSERACT_CMD is unset, pytesseract keeps its default PATH lookup.
tesseract_cmd = os.getenv("TESSERACT_CMD")
if tesseract_cmd:
    pytesseract.pytesseract.tesseract_cmd = tesseract_cmd

# --- Validate API Keys ---
# Fail fast at import time: the Tavily search tool and the Groq LLM need them.
if not tavily_api_key:
    raise ValueError("TAVILY_API_KEY not found in environment variables/Space secrets.")
if not groq_api_key:
    raise ValueError("GROQ_API_KEY not found in environment variables/Space secrets.")

# --- Initialize LLM (Using Groq) ---
# Lighter alternative model on Groq: "gemma2-9b-it".
try:
    llm = ChatGroq(
        model="meta-llama/llama-4-maverick-17b-128e-instruct",  # Powerful model available on Groq, good for reasoning
        api_key=groq_api_key,
        temperature=0.3,  # Low temperature for factual tasks
    )
    print(f"LLM Initialized: Groq - {llm.model_name}")
except Exception as e:
    print(f"ERROR initializing Groq LLM: {e}")
    traceback.print_exc()
    raise  # Stop if LLM fails to init
# ==============================================================================
# State Definition
# ==============================================================================
class AgentState(TypedDict):
    """Defines the structure of the information the agent tracks during its run.

    Nodes return partial dicts with a subset of these keys. LangGraph merges
    them into the running state.
    """
    input_question: str  # The original question from the benchmark
    # History of interactions (Human, AI, Tool). The add_messages reducer makes
    # node-returned messages get merged into this list rather than replacing it.
    messages: Annotated[List[BaseMessage], add_messages]
    # Stores any error message encountered (e.g. max iterations or LLM failure).
    # A non-None value is reported as the final answer by the wrapper function.
    error: Optional[str]
    iterations: int  # Counter for agent steps to prevent loops; incremented by the agent node
# ==============================================================================
# Tools Definitions
# ==============================================================================
print("Defining tools...")
# --- Search Tool (Tavily) ---
# Asks Tavily for at most 3 results per query. The tool is renamed to "web_search"
# so its name matches the one referenced in the web_browser tool description.
# NOTE(review): confirm this TavilySearchResults version honours the `api_key`
# kwarg. Some versions ignore unknown kwargs and read TAVILY_API_KEY from the
# environment instead (the key check above guarantees that variable is set).
search_tool = TavilySearchResults(max_results=3, api_key=tavily_api_key)
search_tool.name = "web_search"
search_tool.description = "Performs a web search (using Tavily) to find relevant URLs/snippets for a query."
# --- Web Browser Tool (html2text) ---
@tool
def web_browser(url: str) -> str:
    """Fetches text content from a webpage URL using html2text. Use after 'web_search'."""
    # The docstring above is the tool description shown to the LLM; keep it terse.
    # Behaviour: GET the page with a browser-like User-Agent, convert the HTML to
    # plain text (links and images dropped), strip surrounding whitespace, then
    # cap the result at 6000 characters. Failures are returned as "Error: ..."
    # strings (never raised) so the agent can read and react to them.
    print(f"--- [Tool] Browsing (html2text): {url} ---")
    try:
        headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'}
        response = requests.get(url, headers=headers, timeout=20)
        response.raise_for_status()
        # Prefer the detected encoding so response.text decodes non-UTF-8 pages correctly.
        response.encoding = response.apparent_encoding or 'utf-8'
        # Configure html2text (bodywidth=0: no hard line wrapping).
        h = html2text.HTML2Text(bodywidth=0)
        h.ignore_links = True
        h.ignore_images = True
        # Strip BEFORE truncating: otherwise leading/trailing whitespace eats into
        # the character budget and a long-but-blank page escapes the empty check.
        clean_text = h.handle(response.text).strip()
        if not clean_text:
            return f"Error: No meaningful content via html2text for {url}."
        max_length = 6000
        if len(clean_text) > max_length:
            return clean_text[:max_length] + "\n\n... [Content Truncated]"
        return clean_text
    except requests.exceptions.RequestException as e:
        return f"Error: Network request failed for URL: {url}. Reason: {e}"
    except Exception as e:
        return f"Error: Unexpected error processing URL with html2text: {url}. Reason: {str(e)}"
# --- File Download Tool ---
@tool
def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
    """Downloads a file from a URL to a temporary directory. Input: file URL. Returns: path to downloaded file or error."""
    # The docstring above is the tool description shown to the LLM.
    # The file is saved into tempfile.gettempdir(). A caller-supplied `filename`
    # is reduced to its base name so values like "../../x" cannot escape that
    # directory. Without one, the name comes from the (percent-decoded) URL path,
    # falling back to a random "downloaded_<hex>" name. Failures are returned as
    # strings, never raised.
    print(f"--- [Tool] Downloading file from: {url} ---")
    try:
        import uuid
        from urllib.parse import unquote

        if filename:
            # Never trust a supplied name to be a bare file name.
            filename = os.path.basename(filename)
        else:
            try:
                # e.g. ".../my%20data.csv" -> "my data.csv"
                filename = os.path.basename(unquote(urlparse(url).path))
            except ValueError:
                filename = ""
        if not filename or filename in (".", ".."):
            filename = f"downloaded_{uuid.uuid4().hex[:8]}"
        filepath = os.path.join(tempfile.gettempdir(), filename)
        # Same browser-like User-Agent as web_browser; some hosts reject the default one.
        headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'}
        # `with` releases the streamed connection even if writing the file fails.
        with requests.get(url, headers=headers, stream=True, timeout=30) as response:
            response.raise_for_status()
            with open(filepath, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        print(f"--- [Tool] File downloaded to: {filepath} ---")
        return f"File downloaded to {filepath}. Use appropriate tools (e.g., analyze_csv_file) to process it."
    except requests.exceptions.RequestException as e:
        return f"Error downloading file: Network issue for {url}. Reason: {e}"
    except Exception as e:
        return f"Error downloading file: Unexpected error for {url}. Reason: {str(e)}"
# --- CSV Analysis Tool ---
@tool
def analyze_csv_file(file_path: str) -> str:
    """Analyzes a CSV file at the given path using pandas. Returns a summary of content or error."""
    # The docstring above is the LLM-facing tool description.
    # Report layout: title, shape, column names, first 5 rows, then describe()
    # of the numeric columns (or a note that there are none).
    print(f"--- [Tool] Analyzing CSV: {file_path} ---")
    # Relative paths resolve against the current working directory.
    if not os.path.exists(file_path):
        return f"Error: CSV file not found at path: {file_path}"
    try:
        frame = pd.read_csv(file_path)
        row_count, col_count = frame.shape
        pieces = [
            f"CSV Analysis Report for {os.path.basename(file_path)}:\n",
            f"- Shape: {row_count} rows, {col_count} columns\n",
            f"- Columns: {', '.join(frame.columns)}\n",
            f"\nFirst 5 rows:\n{frame.head().to_string()}\n",
        ]
        numeric = frame.select_dtypes(include=['number'])
        if numeric.empty:
            pieces.append("\nNo numeric columns for stats.")
        else:
            pieces.append(f"\nBasic Stats (Numeric):\n{numeric.describe().to_string()}")
        return "".join(pieces)
    except ImportError:
        return "Error: 'pandas' required but not installed."
    except Exception as e:
        return f"Error analyzing CSV {file_path}: {str(e)}"
# --- Excel Analysis Tool ---
@tool
def analyze_excel_file(file_path: str) -> str:
    """Analyzes an Excel file (.xlsx, .xls) at the given path. Returns a summary of the first sheet or error."""
    # The docstring above is the LLM-facing tool description.
    # Only the first sheet is summarised. The names of all sheets are listed so
    # the agent can tell when relevant data may live on another sheet.
    print(f"--- [Tool] Analyzing Excel: {file_path} ---")
    if not os.path.exists(file_path):
        return f"Error: Excel file not found at path: {file_path}"
    try:
        # No explicit engine: pandas picks openpyxl for .xlsx and xlrd for legacy
        # .xls. Forcing engine='openpyxl' made every .xls file fail to open.
        with pd.ExcelFile(file_path) as workbook:
            sheet_names = workbook.sheet_names
            df = workbook.parse(sheet_names[0])
        summary = f"Excel Analysis Report for {os.path.basename(file_path)} (First Sheet):\n"
        summary += f"- Sheets: {', '.join(map(str, sheet_names))}\n"
        summary += f"- Shape: {df.shape[0]} rows, {df.shape[1]} columns\n"
        # Excel header cells may hold numbers or dates; str() them before joining.
        summary += f"- Columns: {', '.join(map(str, df.columns))}\n"
        summary += f"\nFirst 5 rows:\n{df.head().to_string()}\n"
        numeric_cols = df.select_dtypes(include=['number'])
        if not numeric_cols.empty:
            summary += f"\nBasic Stats (Numeric):\n{numeric_cols.describe().to_string()}"
        else:
            summary += "\nNo numeric columns for stats."
        return summary
    except ImportError:
        return "Error: 'pandas' and 'openpyxl' required but not installed."
    except Exception as e:
        return f"Error analyzing Excel {file_path}: {str(e)}"
# --- Image Text Extraction Tool (OCR) ---
@tool
def extract_text_from_image(file_path: str) -> str:
    """Extracts text from an image file at the given path using Tesseract OCR. Returns extracted text or error."""
    # The docstring above is the LLM-facing tool description.
    print(f"--- [Tool] Extracting text from image: {file_path} ---")
    if not os.path.exists(file_path):
        return f"Error: Image file not found at path: {file_path}"
    try:
        # Image.open is lazy and keeps the file open; the `with` block closes
        # the handle once OCR has consumed the pixel data.
        with Image.open(file_path) as img:
            text = pytesseract.image_to_string(img)
        text_stripped = text.strip()
        # Return a clear message if OCR produced only whitespace.
        if not text_stripped:
            return "No text found in image."
        return f"Extracted text from image '{os.path.basename(file_path)}':\n{text_stripped}"
    except ImportError:
        return "Error: 'Pillow' or 'pytesseract' required but not installed."
    except pytesseract.TesseractNotFoundError:
        return "Error: Tesseract OCR not installed or not in PATH."
    except Exception as e:
        return f"Error extracting text from image {file_path}: {str(e)}"
# --- Basic Math Tools ---
# Each math tool logs the operation it was asked to perform, then computes it.
@tool
def add(a: float, b: float) -> float:
    """Adds two numbers (a + b). Handles float inputs."""
    print(f"--- [Tool] Calculating: {a} + {b} ---")
    total = a + b
    return total
@tool
def subtract(a: float, b: float) -> float:
    """Subtracts the second number from the first (a - b). Handles float inputs."""
    # Echo the operation to the run log before computing it.
    print(f"--- [Tool] Calculating: {a} - {b} ---")
    difference = a - b
    return difference
@tool
def multiply(a: float, b: float) -> float:
    """Multiplies two numbers (a * b). Handles float inputs."""
    # Echo the operation to the run log before computing it.
    print(f"--- [Tool] Calculating: {a} * {b} ---")
    product = a * b
    return product
@tool
def divide(a: float, b: float) -> float | str:
    """Divides the first number by the second (a / b). Handles float inputs and division by zero."""
    print(f"--- [Tool] Calculating: {a} / {b} ---")
    # A zero divisor is reported back to the LLM as text instead of raising.
    return a / b if b != 0 else "Error: Cannot divide by zero."
# --- Compile list of all tools ---
# Every tool the agent may call; this list is bound to the LLM and also
# handed to the ToolNode that executes the calls.
tools = [
    search_tool,
    web_browser,
    download_file_from_url,
    analyze_csv_file,
    analyze_excel_file,
    extract_text_from_image,
    add,
    subtract,
    multiply,
    divide,
]
# --- Bind tools to the LLM ---
# Defensive guard: `llm` must exist (created in the setup section) before binding.
if "llm" not in globals():
    raise RuntimeError("LLM was not initialized successfully before binding tools.")
llm_with_tools = llm.bind_tools(tools)
print(f"Agent initialized with {len(tools)} tools.")
# ==============================================================================
# Node Definitions (With Logging Added)
# ==============================================================================
print("Defining graph nodes...")
# --- Agent Node ---
def call_agent_node(state: AgentState) -> dict:
    """Invokes the LLM with current state to decide the next step.

    Args:
        state: Current graph state. Reads ``messages`` and ``iterations``;
            a missing ``iterations`` key is treated as 0.

    Returns:
        A partial state update:
        * ``{"error": ...}`` once MAX_ITERATIONS agent steps have already run;
        * ``{"messages": [response], "iterations": n + 1}`` on success;
        * ``{"messages": [AIMessage(error text)], "error": ..., "iterations": n + 1}``
          when the LLM call raises.
    """
    # NOTE: the graph's recursion_limit must allow about 2 supersteps per
    # iteration (agent + tools), otherwise this cap is never reached.
    MAX_ITERATIONS = 15  # Max steps allowed for the task
    # Read the counter once, tolerantly. The entry log used to index
    # state['iterations'] directly and raised KeyError when it was absent.
    current_iterations = state.get('iterations', 0)
    print(f"\n>>> Entering Agent Node (Iteration {current_iterations})")
    if current_iterations >= MAX_ITERATIONS:
        print(f"!!! Agent Node: Max iterations ({MAX_ITERATIONS}) reached. Setting error.")
        return {"error": f"Max iterations ({MAX_ITERATIONS}) reached."}
    try:
        print(f"--- Agent Node: Invoking LLM ({llm.model_name})... ---")
        # Defensive: tools are bound at module import time.
        if 'llm_with_tools' not in globals():
            return {"error": "LLM tools not bound."}
        response = llm_with_tools.invoke(state['messages'])
        print("--- Agent Node: LLM Invocation Complete. ---")
        print(f"<<< Exiting Agent Node (Success, Iteration {current_iterations + 1})")
        return {"messages": [response], "iterations": current_iterations + 1}
    except Exception as e:
        error_message = f"LLM invocation failed: {str(e)}"
        print(f"!!! Agent Node ERROR: {error_message} !!!")
        traceback.print_exc()  # Full traceback for debugging LLM errors
        print(f"<<< Exiting Agent Node (LLM Error, Iteration {current_iterations})")
        # Still increment the counter so repeated failures cannot loop forever.
        return {
            "messages": [AIMessage(content=f"Error during LLM call: {error_message}")],
            "error": error_message,
            "iterations": current_iterations + 1,
        }
# --- Tool Node Wrapper (for Logging) ---
# We still use the prebuilt ToolNode, but wrap its call for logging
tool_executor = ToolNode(tools)  # Keep the instance
def logged_tool_node(state: AgentState) -> dict:
    """Logs tool execution start/end and calls the actual ToolNode.

    Args:
        state: Graph state whose last message is expected to be the AIMessage
            carrying ``tool_calls`` (the graph routes here via tools_condition).

    Returns:
        ToolNode's result (``{"messages": [ToolMessage, ...]}``). If ToolNode
        itself raises, one error ToolMessage is returned per requested tool call,
        each reusing that call's id.
    """
    print(">>> Entering Tool Node")
    last_message = state['messages'][-1]
    tool_calls = getattr(last_message, "tool_calls", None) or []
    requested_tools_str = ", ".join(tc.get('name', 'unknown') for tc in tool_calls) if tool_calls else "None"
    print(f"--- Tool Node: Executing tools: {requested_tools_str} ---")
    if tool_calls:
        print(f"--- Tool Node: Tool Args: {[tc.get('args') for tc in tool_calls]} ---")
    try:
        result = tool_executor.invoke(state)
        # Log the start of each tool result.
        print("--- Tool Node: Tool Execution Results ---")
        if isinstance(result.get("messages"), list):
            for msg in result["messages"]:
                if isinstance(msg, ToolMessage):
                    print(f" - Tool: {msg.name}, Result (start): {str(msg.content)[:200]}...")
        print("<<< Exiting Tool Node (Success)")
        return result  # The dictionary containing ToolMessages
    except Exception as e:
        error_message = f"ToolNode invocation exception: {str(e)}"
        print(f"!!! Tool Node ERROR: {error_message} !!!")
        traceback.print_exc()
        print("<<< Exiting Tool Node (Error)")
        # Each tool call in the preceding AIMessage must be answered by a
        # ToolMessage carrying the same tool_call_id. A made-up id leaves the
        # calls unanswered and makes the next LLM request fail, so echo the
        # real ids with the error text.
        if tool_calls:
            return {"messages": [
                ToolMessage(
                    content=error_message,
                    tool_call_id=tc.get('id') or "tool_node_error",
                    name=tc.get('name'),
                )
                for tc in tool_calls
            ]}
        return {"messages": [ToolMessage(content=error_message, tool_call_id="tool_node_error")]}
# ==============================================================================
# Graph Construction (Non-conversational, using logged tool node)
# ==============================================================================
# Topology: START -> agent; agent -> tools or END (decided by tools_condition);
# tools -> agent, forming the reason/act loop.
print("Building agent graph...")
builder = StateGraph(AgentState)
builder.add_node("agent", call_agent_node)
builder.add_node("tools", logged_tool_node)  # Use the logging wrapper node
builder.add_edge(START, "agent")
# LangGraph's prebuilt tools_condition routes to "tools" when the last AI message
# requests tool calls, and to END otherwise.
builder.add_conditional_edges("agent", tools_condition, {"tools": "tools", END: END})
builder.add_edge("tools", "agent")
# Compile the graph globally so it's ready for the function call
try:
    graph = builder.compile()
    print("GAIA agent graph compiled successfully.")
except Exception as e:
    print(f"ERROR: Failed to compile LangGraph graph: {e}")
    traceback.print_exc()
    graph = None  # Set before re-raising; the raise below still aborts module import
    raise  # Reraise exception to make startup failure clear
# ==============================================================================
# Main Execution Function for GAIA Benchmark <<<< WRAPPER FUNCTION >>>>
# ==============================================================================
def answer_gaia_task(question: str, file_path: Optional[str] = None) -> str:
    """
    Runs the compiled GAIA agent graph for a given question and optional file path.
    This is the main entry point expected by the benchmark runner.

    Args:
        question: The benchmark question; embedded verbatim in the prompt.
        file_path: Optional path to an associated file; mentioned in the prompt
            so the agent can pass it to the file tools.

    Returns:
        The agent's final answer (surrounding whitespace and one pair of wrapping
        quotes removed), or a string starting with "Error:" when the agent
        reported an error or the graph failed.
    """
    # Check if graph compilation was successful
    if graph is None:
        return "Error: Agent graph was not compiled successfully during setup."
    print(f"\n{'='*20} Running Agent for GAIA Task {'='*20}")
    print(f"Question: {question}")
    file_context_info = f"An associated file is provided at path: '{file_path}'. Use this path if relevant." if file_path else ""
    # Initial prompt sent to the agent, including strict answer-formatting rules
    prompt_content = f"""Your task is to accurately answer the following question based *only* on information obtained using your tools (web search, web browser, file download, csv/excel analysis, image OCR, math).
{file_context_info}
Follow these steps methodically:
1. Analyze the question: {question}
2. Use tools ONLY if necessary to gather the specific information required. Assume local file paths mentioned (like 'data.csv') are accessible.
3. Synthesize the final answer from the gathered information.
**CRITICAL OUTPUT FORMATTING RULES:**
* Your final response MUST be ONLY the answer, without any other text/explanations.
* **Numbers:** No commas (1000). No units ($ , %) unless asked.
* **Strings:** No articles (a, an, the) unless proper noun. No abbreviations (Saint Petersburg) unless answer is abbreviation. Use numerals (5).
* **Lists:** Comma-separated (apple,banana,cherry). Apply number/string rules to elements.
* If answer not found, output only the exact phrase: Information not found
Provide ONLY the final answer according to these rules.
"""
    initial_state = AgentState(
        input_question=question,
        messages=[HumanMessage(content=prompt_content)],
        error=None,
        iterations=0,
    )
    # Each agent iteration can use two supersteps (agent + tools), and the agent
    # node stops itself after 15 iterations (about 31 supersteps). The old limit
    # of 20 made LangGraph abort with a recursion error before that cap, and its
    # clean error, could ever be reached.
    recursion_limit = 2 * 15 + 5
    final_answer = "Error: Agent execution did not complete successfully."  # Default fallback
    try:
        final_state = graph.invoke(initial_state, config={"recursion_limit": recursion_limit})
        if final_state:
            # Prioritize reporting an agent error if one occurred
            if final_state.get("error"):
                print(f"--- Agent stopped due to ERROR: {final_state['error']} ---")
                final_answer = f"Error: {final_state['error']}"
            # Otherwise use the last AI message content
            elif final_state.get('messages') and isinstance(final_state['messages'][-1], AIMessage):
                potential_answer = final_state['messages'][-1].content
                if isinstance(potential_answer, str):
                    # Answers are graded by exact match: drop surrounding whitespace,
                    # then one pair of matching wrapping quotes (needs >= 2 chars, so
                    # a lone quote character is left intact).
                    potential_answer = potential_answer.strip()
                    if (len(potential_answer) >= 2
                            and potential_answer[0] == potential_answer[-1]
                            and potential_answer[0] in ('"', "'")):
                        potential_answer = potential_answer[1:-1].strip()
                print(f"--- Final Answer (from AI): {potential_answer} ---")
                final_answer = potential_answer
            else:
                print("--- Could not determine final answer (last message not AI or missing). Check logs. ---")
            # Log final state details for debugging
            print(f"Final State: Error={final_state.get('error')}, Iterations={final_state.get('iterations')}")
    except Exception as e:
        print(f"--- Graph execution failed unexpectedly: {e} ---")
        traceback.print_exc()
        final_answer = f"Error: Graph execution failed - {str(e)}"
    print(f"{'='*20} Agent Run Finished {'='*20}")
    return str(final_answer)
# ==============================================================================
# Local Testing Block (Optional)
# ==============================================================================
# This block allows you to test the agent by running final_agent.py directly.
if __name__ == "__main__":
    print("\n--- Running Local Test ---")
    # --- Define Test Question ---
    test_question = "What is the result of multiplying the number of rows (excluding the header) in 'data.csv' by the number found after the phrase 'total items:' in 'image.png'?"
    # Run inside a throw-away directory. The question uses the relative names
    # 'data.csv' and 'image.png', and creating then deleting them in the caller's
    # working directory could overwrite real files with those names.
    original_cwd = os.getcwd()
    with tempfile.TemporaryDirectory() as sandbox_dir:
        os.chdir(sandbox_dir)
        try:
            # --- Create Dummy Files for Local Test ---
            print("Creating dummy files for local test...")
            dummy_files_created = True
            try:
                # Dummy CSV with 3 data rows + header
                with open("data.csv", "w") as f:
                    f.write("Header1,Header2\nRow1Val1,Row1Val2\nRow2Val1,Row2Val2\nRow3Val1,Row3Val2")
                # Dummy image containing the required text
                try:
                    img = Image.new('RGB', (300, 50), color=(255, 255, 255))  # White background
                    from PIL import ImageDraw, ImageFont  # Import drawing tools locally
                    draw = ImageDraw.Draw(img)
                    # Fall back to Pillow's built-in font if Arial is unavailable
                    try:
                        font = ImageFont.truetype("arial.ttf", 15)
                    except IOError:
                        font = ImageFont.load_default()
                    draw.text((10, 10), "Some random info... total items: 7 ... more text", fill=(0, 0, 0), font=font)  # Black text
                    img.save("image.png")
                    print("Dummy data.csv and image.png created successfully.")
                except ImportError:
                    print("Pillow/ImageDraw/ImageFont not installed. Cannot create dummy image file.")
                    dummy_files_created = False
                except Exception as img_e:
                    print(f"Error creating dummy image: {img_e}")
                    dummy_files_created = False
            except Exception as file_e:
                print(f"Error creating dummy files: {file_e}")
                dummy_files_created = False
            # --- Run the Test ---
            if dummy_files_created:
                # Call the main function the way the benchmark runner would.
                result = answer_gaia_task(question=test_question, file_path=None)
                print("\n--- Local Test Result ---")
                print(f"Returned Answer: {result}")
                print("Expected Answer (for dummy files): 21")  # 3 data rows * 7 = 21
            else:
                print("Skipping test execution due to issues creating dummy files.")
            # --- Clean up Dummy Files ---
            # (TemporaryDirectory also removes anything left behind.)
            print("\nCleaning up dummy files...")
            for dummy_file in ["data.csv", "image.png"]:
                if os.path.exists(dummy_file):
                    try:
                        os.remove(dummy_file)
                    except Exception as e:
                        print(f"Could not remove {dummy_file}: {e}")
            print("Dummy file cleanup attempted.")
        finally:
            # Leave the sandbox before it is deleted (removing the current
            # working directory fails on Windows).
            os.chdir(original_cwd)