Spaces:
Sleeping
Sleeping
Update final_agent.py
Browse files- final_agent.py +202 -90
final_agent.py
CHANGED
|
@@ -14,6 +14,7 @@ from urllib.parse import urlparse # For download tool
|
|
| 14 |
from typing import Annotated, List, TypedDict, Optional
|
| 15 |
from dotenv import load_dotenv
|
| 16 |
|
|
|
|
| 17 |
from langgraph.graph import StateGraph, START, END
|
| 18 |
from langgraph.graph.message import add_messages
|
| 19 |
from langgraph.prebuilt import ToolNode, tools_condition
|
|
@@ -29,19 +30,27 @@ load_dotenv()
|
|
| 29 |
gemini_api_key = os.getenv("GEMINI_API_KEY")
|
| 30 |
tavily_api_key = os.getenv("TAVILY_API_KEY")
|
| 31 |
|
| 32 |
-
# --- Optional: Tesseract Path
|
| 33 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
|
|
|
|
|
|
| 35 |
if not gemini_api_key:
|
| 36 |
raise ValueError("GEMINI_API_KEY not found in environment variables.")
|
| 37 |
if not tavily_api_key:
|
| 38 |
-
raise ValueError("TAVILY_API_KEY not found. Required for search.")
|
| 39 |
|
| 40 |
-
#
|
|
|
|
| 41 |
llm = ChatGoogleGenerativeAI(
|
| 42 |
-
model="gemini-
|
| 43 |
google_api_key=gemini_api_key,
|
| 44 |
-
temperature=0.1 #
|
| 45 |
)
|
| 46 |
print(f"LLM Initialized: {llm.model}")
|
| 47 |
|
|
@@ -49,15 +58,16 @@ print(f"LLM Initialized: {llm.model}")
|
|
| 49 |
# State Definition
|
| 50 |
# ==============================================================================
|
| 51 |
class AgentState(TypedDict):
|
| 52 |
-
"""
|
| 53 |
-
input_question: str #
|
| 54 |
-
messages: Annotated[List[BaseMessage], add_messages]
|
| 55 |
-
error: Optional[str]
|
| 56 |
-
iterations: int
|
| 57 |
|
| 58 |
# ==============================================================================
|
| 59 |
-
# Tools
|
| 60 |
# ==============================================================================
|
|
|
|
| 61 |
|
| 62 |
# --- Search Tool (Tavily) ---
|
| 63 |
search_tool = TavilySearchResults(max_results=3, api_key=tavily_api_key)
|
|
@@ -70,16 +80,20 @@ def web_browser(url: str) -> str:
|
|
| 70 |
"""Fetches text content from a webpage URL using html2text. Use after 'web_search'."""
|
| 71 |
print(f"--- [Tool] Browsing (html2text): {url} ---")
|
| 72 |
try:
|
| 73 |
-
headers = {'User-Agent': 'Mozilla/5.0'}
|
| 74 |
response = requests.get(url, headers=headers, timeout=20)
|
| 75 |
response.raise_for_status()
|
| 76 |
response.encoding = response.apparent_encoding or 'utf-8'
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
clean_text = h.handle(response.text)
|
|
|
|
| 79 |
max_length = 6000
|
| 80 |
if len(clean_text) > max_length:
|
| 81 |
return clean_text[:max_length] + "\n\n... [Content Truncated]"
|
| 82 |
-
# Ensure we return error string if empty after strip
|
| 83 |
cleaned_and_stripped = clean_text.strip()
|
| 84 |
return cleaned_and_stripped if cleaned_and_stripped else f"Error: No meaningful content via html2text for {url}."
|
| 85 |
except requests.exceptions.RequestException as e:
|
|
@@ -93,11 +107,14 @@ def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
|
|
| 93 |
"""Downloads a file from a URL to a temporary directory. Input: file URL. Returns: path to downloaded file or error."""
|
| 94 |
print(f"--- [Tool] Downloading file from: {url} ---")
|
| 95 |
try:
|
|
|
|
| 96 |
if not filename:
|
| 97 |
try: path = urlparse(url).path; filename = os.path.basename(path) if path else None
|
| 98 |
except Exception: filename = None
|
| 99 |
if not filename: import uuid; filename = f"downloaded_{uuid.uuid4().hex[:8]}"
|
|
|
|
| 100 |
temp_dir = tempfile.gettempdir(); filepath = os.path.join(temp_dir, filename)
|
|
|
|
| 101 |
response = requests.get(url, stream=True, timeout=30); response.raise_for_status()
|
| 102 |
with open(filepath, 'wb') as f:
|
| 103 |
for chunk in response.iter_content(chunk_size=8192): f.write(chunk)
|
|
@@ -113,9 +130,11 @@ def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
|
|
| 113 |
def analyze_csv_file(file_path: str) -> str:
|
| 114 |
"""Analyzes a CSV file at the given path using pandas. Returns a summary of content or error."""
|
| 115 |
print(f"--- [Tool] Analyzing CSV: {file_path} ---")
|
| 116 |
-
|
|
|
|
| 117 |
try:
|
| 118 |
df = pd.read_csv(file_path)
|
|
|
|
| 119 |
summary = f"CSV Analysis Report for {os.path.basename(file_path)}:\n"
|
| 120 |
summary += f"- Shape: {df.shape[0]} rows, {df.shape[1]} columns\n"
|
| 121 |
summary += f"- Columns: {', '.join(df.columns)}\n"
|
|
@@ -134,9 +153,10 @@ def analyze_csv_file(file_path: str) -> str:
|
|
| 134 |
def analyze_excel_file(file_path: str) -> str:
|
| 135 |
"""Analyzes an Excel file (.xlsx, .xls) at the given path. Returns a summary of the first sheet or error."""
|
| 136 |
print(f"--- [Tool] Analyzing Excel: {file_path} ---")
|
| 137 |
-
if not os.path.exists(file_path): return f"Error: Excel file not found: {file_path}"
|
| 138 |
try:
|
| 139 |
df = pd.read_excel(file_path, engine='openpyxl')
|
|
|
|
| 140 |
summary = f"Excel Analysis Report for {os.path.basename(file_path)} (First Sheet):\n"
|
| 141 |
summary += f"- Shape: {df.shape[0]} rows, {df.shape[1]} columns\n"
|
| 142 |
summary += f"- Columns: {', '.join(df.columns)}\n"
|
|
@@ -155,11 +175,12 @@ def analyze_excel_file(file_path: str) -> str:
|
|
| 155 |
def extract_text_from_image(file_path: str) -> str:
|
| 156 |
"""Extracts text from an image file at the given path using Tesseract OCR. Returns extracted text or error."""
|
| 157 |
print(f"--- [Tool] Extracting text from image: {file_path} ---")
|
| 158 |
-
if not os.path.exists(file_path): return f"Error: Image file not found: {file_path}"
|
| 159 |
try:
|
| 160 |
# Need to explicitly handle potential empty string from pytesseract
|
| 161 |
text = pytesseract.image_to_string(Image.open(file_path))
|
| 162 |
text_stripped = text.strip()
|
|
|
|
| 163 |
return f"Extracted text from image '{os.path.basename(file_path)}':\n{text_stripped}" if text_stripped else "No text found in image."
|
| 164 |
except ImportError: return "Error: 'Pillow' or 'pytesseract' required but not installed."
|
| 165 |
except pytesseract.TesseractNotFoundError: return "Error: Tesseract OCR not installed or not in PATH."
|
|
@@ -188,133 +209,224 @@ def divide(a: float, b: float) -> float | str:
|
|
| 188 |
if b == 0: return "Error: Cannot divide by zero."
|
| 189 |
return a / b
|
| 190 |
|
| 191 |
-
# ---
|
| 192 |
tools = [ search_tool, web_browser, download_file_from_url, analyze_csv_file,
|
| 193 |
analyze_excel_file, extract_text_from_image, add, subtract, multiply, divide ]
|
|
|
|
|
|
|
| 194 |
llm_with_tools = llm.bind_tools(tools)
|
| 195 |
print(f"Agent initialized with {len(tools)} tools.")
|
| 196 |
|
| 197 |
# ==============================================================================
|
| 198 |
# Node Definitions
|
| 199 |
# ==============================================================================
|
|
|
|
| 200 |
|
| 201 |
-
# --- Agent Node
|
| 202 |
def call_agent_node(state: AgentState) -> dict:
|
| 203 |
-
"""
|
| 204 |
print(f"\n--- [Node] Agent thinking... (Iteration {state['iterations']}) ---")
|
| 205 |
-
MAX_ITERATIONS = 10 # Max steps for the
|
| 206 |
current_iterations = state.get('iterations', 0)
|
| 207 |
if current_iterations >= MAX_ITERATIONS:
|
| 208 |
print(f"Warning: Reached max iterations ({MAX_ITERATIONS}). Stopping.")
|
| 209 |
-
# Return error message in state
|
| 210 |
return {"error": f"Max iterations ({MAX_ITERATIONS}) reached."}
|
| 211 |
try:
|
|
|
|
| 212 |
response = llm_with_tools.invoke(state['messages'])
|
| 213 |
print("--- [Node] AI Response/Action ---")
|
| 214 |
-
response.pretty_print()
|
| 215 |
-
#
|
| 216 |
return {"messages": [response], "iterations": current_iterations + 1}
|
| 217 |
except Exception as e:
|
| 218 |
error_message = f"LLM invocation failed: {str(e)}"
|
| 219 |
print(f"--- [Node] ERROR: {error_message} ---")
|
| 220 |
-
#
|
| 221 |
-
|
|
|
|
| 222 |
|
| 223 |
-
# --- Tool Node
|
|
|
|
| 224 |
tool_node = ToolNode(tools)
|
| 225 |
|
| 226 |
# ==============================================================================
|
| 227 |
# Graph Construction (Non-conversational)
|
| 228 |
# ==============================================================================
|
|
|
|
| 229 |
builder = StateGraph(AgentState)
|
| 230 |
|
| 231 |
-
# Add nodes
|
| 232 |
builder.add_node("agent", call_agent_node)
|
| 233 |
builder.add_node("tools", tool_node)
|
| 234 |
|
| 235 |
-
#
|
| 236 |
builder.add_edge(START, "agent")
|
| 237 |
|
| 238 |
-
#
|
| 239 |
builder.add_conditional_edges(
|
| 240 |
"agent",
|
| 241 |
-
tools_condition, #
|
| 242 |
{
|
| 243 |
-
"tools": "tools", # If
|
| 244 |
-
END: END # If no
|
| 245 |
}
|
| 246 |
)
|
| 247 |
|
| 248 |
-
#
|
| 249 |
-
builder.add_edge("tools", "agent") #
|
| 250 |
|
| 251 |
-
# Compile the graph
|
| 252 |
-
|
| 253 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
|
| 255 |
# ==============================================================================
|
| 256 |
-
# Execution
|
| 257 |
# ==============================================================================
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
|
| 272 |
Follow these steps methodically:
|
| 273 |
1. Analyze the question to understand required information and tools needed.
|
| 274 |
-
2. If external files are mentioned (e.g.,
|
| 275 |
-
3. If a URL is given for a file, use 'download_file_from_url' first, then analyze the downloaded file using its path.
|
| 276 |
4. If web information is needed, use 'web_search' then 'web_browser' on relevant URLs.
|
| 277 |
5. If calculations are needed, use the math tools.
|
| 278 |
6. Synthesize the information gathered from tools to arrive at the final answer.
|
| 279 |
7. **CRITICAL:** Your final output MUST contain ONLY the precise numerical or text answer requested by the question. Do NOT include explanations, reasoning steps, units unless explicitly asked for, context, apologies, or any introductory phrases like "The final answer is...". Just the required answer string or number itself.
|
| 280 |
|
| 281 |
-
Question: {
|
| 282 |
-
"""
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
|
| 287 |
-
try:
|
| 288 |
-
# Run the graph from start to end for the single task
|
| 289 |
-
final_state = graph.invoke(initial_state, config={"recursion_limit": 15})
|
| 290 |
-
except Exception as e:
|
| 291 |
-
print(f"--- Graph execution failed unexpectedly: {e} ---")
|
| 292 |
-
traceback.print_exc()
|
| 293 |
-
final_state = None
|
| 294 |
|
| 295 |
# ==============================================================================
|
| 296 |
-
#
|
| 297 |
# ==============================================================================
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
else:
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
else:
|
| 320 |
-
print("Execution failed, no final state.")
|
|
|
|
| 14 |
from typing import Annotated, List, TypedDict, Optional
|
| 15 |
from dotenv import load_dotenv
|
| 16 |
|
| 17 |
+
# LangChain and LangGraph Imports
|
| 18 |
from langgraph.graph import StateGraph, START, END
|
| 19 |
from langgraph.graph.message import add_messages
|
| 20 |
from langgraph.prebuilt import ToolNode, tools_condition
|
|
|
|
| 30 |
gemini_api_key = os.getenv("GEMINI_API_KEY")
|
| 31 |
tavily_api_key = os.getenv("TAVILY_API_KEY")
|
| 32 |
|
| 33 |
+
# --- Optional: Tesseract Path ---
|
| 34 |
+
# If Tesseract OCR is not in your system's PATH environment variable,
|
| 35 |
+
# uncomment the following line and set the correct path to tesseract.exe
|
| 36 |
+
# try:
|
| 37 |
+
# pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe' # Example path for Windows
|
| 38 |
+
# except NameError: pass # Handles case where pytesseract might not be imported yet if PIL fails first
|
| 39 |
+
# except Exception as e: print(f"Warning: Could not set tesseract_cmd path: {e}")
|
| 40 |
|
| 41 |
+
|
| 42 |
+
# --- Validate API Keys ---
|
| 43 |
if not gemini_api_key:
|
| 44 |
raise ValueError("GEMINI_API_KEY not found in environment variables.")
|
| 45 |
if not tavily_api_key:
|
| 46 |
+
raise ValueError("TAVILY_API_KEY not found. Required for Tavily search tool.")
|
| 47 |
|
| 48 |
+
# --- Initialize LLM ---
|
| 49 |
+
# Using the model specified in the user's code block
|
| 50 |
llm = ChatGoogleGenerativeAI(
|
| 51 |
+
model="gemini-1.5-flash-latest", # As per user's last provided code
|
| 52 |
google_api_key=gemini_api_key,
|
| 53 |
+
temperature=0.1 # Low temperature for factual tasks
|
| 54 |
)
|
| 55 |
print(f"LLM Initialized: {llm.model}")
|
| 56 |
|
|
|
|
| 58 |
# State Definition
|
| 59 |
# ==============================================================================
|
| 60 |
class AgentState(TypedDict):
|
| 61 |
+
"""Defines the structure of the information the agent tracks during its run."""
|
| 62 |
+
input_question: str # The original question from the benchmark
|
| 63 |
+
messages: Annotated[List[BaseMessage], add_messages] # History of interactions (Human, AI, Tool)
|
| 64 |
+
error: Optional[str] # Stores any error message encountered
|
| 65 |
+
iterations: int # Counter for agent steps to prevent loops
|
| 66 |
|
| 67 |
# ==============================================================================
|
| 68 |
+
# Tools Definitions
|
| 69 |
# ==============================================================================
|
| 70 |
+
print("Defining tools...")
|
| 71 |
|
| 72 |
# --- Search Tool (Tavily) ---
|
| 73 |
search_tool = TavilySearchResults(max_results=3, api_key=tavily_api_key)
|
|
|
|
| 80 |
"""Fetches text content from a webpage URL using html2text. Use after 'web_search'."""
|
| 81 |
print(f"--- [Tool] Browsing (html2text): {url} ---")
|
| 82 |
try:
|
| 83 |
+
headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'}
|
| 84 |
response = requests.get(url, headers=headers, timeout=20)
|
| 85 |
response.raise_for_status()
|
| 86 |
response.encoding = response.apparent_encoding or 'utf-8'
|
| 87 |
+
# Configure html2text
|
| 88 |
+
h = html2text.HTML2Text(bodywidth=0)
|
| 89 |
+
h.ignore_links = True
|
| 90 |
+
h.ignore_images = True
|
| 91 |
+
# Convert HTML to text
|
| 92 |
clean_text = h.handle(response.text)
|
| 93 |
+
# Limit content length
|
| 94 |
max_length = 6000
|
| 95 |
if len(clean_text) > max_length:
|
| 96 |
return clean_text[:max_length] + "\n\n... [Content Truncated]"
|
|
|
|
| 97 |
cleaned_and_stripped = clean_text.strip()
|
| 98 |
return cleaned_and_stripped if cleaned_and_stripped else f"Error: No meaningful content via html2text for {url}."
|
| 99 |
except requests.exceptions.RequestException as e:
|
|
|
|
| 107 |
"""Downloads a file from a URL to a temporary directory. Input: file URL. Returns: path to downloaded file or error."""
|
| 108 |
print(f"--- [Tool] Downloading file from: {url} ---")
|
| 109 |
try:
|
| 110 |
+
# Generate filename if needed
|
| 111 |
if not filename:
|
| 112 |
try: path = urlparse(url).path; filename = os.path.basename(path) if path else None
|
| 113 |
except Exception: filename = None
|
| 114 |
if not filename: import uuid; filename = f"downloaded_{uuid.uuid4().hex[:8]}"
|
| 115 |
+
# Define save path
|
| 116 |
temp_dir = tempfile.gettempdir(); filepath = os.path.join(temp_dir, filename)
|
| 117 |
+
# Download file
|
| 118 |
response = requests.get(url, stream=True, timeout=30); response.raise_for_status()
|
| 119 |
with open(filepath, 'wb') as f:
|
| 120 |
for chunk in response.iter_content(chunk_size=8192): f.write(chunk)
|
|
|
|
| 130 |
def analyze_csv_file(file_path: str) -> str:
|
| 131 |
"""Analyzes a CSV file at the given path using pandas. Returns a summary of content or error."""
|
| 132 |
print(f"--- [Tool] Analyzing CSV: {file_path} ---")
|
| 133 |
+
# GAIA might provide relative paths, ensure they work or adjust logic if needed
|
| 134 |
+
if not os.path.exists(file_path): return f"Error: CSV file not found at path: {file_path}"
|
| 135 |
try:
|
| 136 |
df = pd.read_csv(file_path)
|
| 137 |
+
# Generate summary string
|
| 138 |
summary = f"CSV Analysis Report for {os.path.basename(file_path)}:\n"
|
| 139 |
summary += f"- Shape: {df.shape[0]} rows, {df.shape[1]} columns\n"
|
| 140 |
summary += f"- Columns: {', '.join(df.columns)}\n"
|
|
|
|
| 153 |
def analyze_excel_file(file_path: str) -> str:
|
| 154 |
"""Analyzes an Excel file (.xlsx, .xls) at the given path. Returns a summary of the first sheet or error."""
|
| 155 |
print(f"--- [Tool] Analyzing Excel: {file_path} ---")
|
| 156 |
+
if not os.path.exists(file_path): return f"Error: Excel file not found at path: {file_path}"
|
| 157 |
try:
|
| 158 |
df = pd.read_excel(file_path, engine='openpyxl')
|
| 159 |
+
# Generate summary string
|
| 160 |
summary = f"Excel Analysis Report for {os.path.basename(file_path)} (First Sheet):\n"
|
| 161 |
summary += f"- Shape: {df.shape[0]} rows, {df.shape[1]} columns\n"
|
| 162 |
summary += f"- Columns: {', '.join(df.columns)}\n"
|
|
|
|
| 175 |
def extract_text_from_image(file_path: str) -> str:
|
| 176 |
"""Extracts text from an image file at the given path using Tesseract OCR. Returns extracted text or error."""
|
| 177 |
print(f"--- [Tool] Extracting text from image: {file_path} ---")
|
| 178 |
+
if not os.path.exists(file_path): return f"Error: Image file not found at path: {file_path}"
|
| 179 |
try:
|
| 180 |
# Need to explicitly handle potential empty string from pytesseract
|
| 181 |
text = pytesseract.image_to_string(Image.open(file_path))
|
| 182 |
text_stripped = text.strip()
|
| 183 |
+
# Return a clear message if no text found, otherwise return extracted text
|
| 184 |
return f"Extracted text from image '{os.path.basename(file_path)}':\n{text_stripped}" if text_stripped else "No text found in image."
|
| 185 |
except ImportError: return "Error: 'Pillow' or 'pytesseract' required but not installed."
|
| 186 |
except pytesseract.TesseractNotFoundError: return "Error: Tesseract OCR not installed or not in PATH."
|
|
|
|
| 209 |
if b == 0: return "Error: Cannot divide by zero."
|
| 210 |
return a / b
|
| 211 |
|
| 212 |
+
# --- Compile list of all tools ---
|
| 213 |
tools = [ search_tool, web_browser, download_file_from_url, analyze_csv_file,
|
| 214 |
analyze_excel_file, extract_text_from_image, add, subtract, multiply, divide ]
|
| 215 |
+
|
| 216 |
+
# --- Bind tools to the LLM ---
|
| 217 |
llm_with_tools = llm.bind_tools(tools)
|
| 218 |
print(f"Agent initialized with {len(tools)} tools.")
|
| 219 |
|
| 220 |
# ==============================================================================
|
| 221 |
# Node Definitions
|
| 222 |
# ==============================================================================
|
| 223 |
+
print("Defining graph nodes...")
|
| 224 |
|
| 225 |
+
# --- Agent Node ---
|
| 226 |
def call_agent_node(state: AgentState) -> dict:
|
| 227 |
+
"""Invokes the LLM with current state to decide the next step."""
|
| 228 |
print(f"\n--- [Node] Agent thinking... (Iteration {state['iterations']}) ---")
|
| 229 |
+
MAX_ITERATIONS = 10 # Max steps allowed for the task
|
| 230 |
current_iterations = state.get('iterations', 0)
|
| 231 |
if current_iterations >= MAX_ITERATIONS:
|
| 232 |
print(f"Warning: Reached max iterations ({MAX_ITERATIONS}). Stopping.")
|
|
|
|
| 233 |
return {"error": f"Max iterations ({MAX_ITERATIONS}) reached."}
|
| 234 |
try:
|
| 235 |
+
# Call the LLM
|
| 236 |
response = llm_with_tools.invoke(state['messages'])
|
| 237 |
print("--- [Node] AI Response/Action ---")
|
| 238 |
+
response.pretty_print() # Log the LLM's thoughts and actions
|
| 239 |
+
# Return the response message and incremented iteration count
|
| 240 |
return {"messages": [response], "iterations": current_iterations + 1}
|
| 241 |
except Exception as e:
|
| 242 |
error_message = f"LLM invocation failed: {str(e)}"
|
| 243 |
print(f"--- [Node] ERROR: {error_message} ---")
|
| 244 |
+
traceback.print_exc() # Print full traceback for debugging LLM errors
|
| 245 |
+
# Return an error message and set error state
|
| 246 |
+
return {"messages": [AIMessage(content=f"Sorry, I encountered an error: {error_message}")], "error": error_message, "iterations": current_iterations + 1}
|
| 247 |
|
| 248 |
+
# --- Tool Node ---
|
| 249 |
+
# Use the prebuilt ToolNode to handle execution of the bound tools
|
| 250 |
tool_node = ToolNode(tools)
|
| 251 |
|
| 252 |
# ==============================================================================
|
| 253 |
# Graph Construction (Non-conversational)
|
| 254 |
# ==============================================================================
|
| 255 |
+
print("Building agent graph...")
|
| 256 |
builder = StateGraph(AgentState)
|
| 257 |
|
| 258 |
+
# Add the agent and tool nodes
|
| 259 |
builder.add_node("agent", call_agent_node)
|
| 260 |
builder.add_node("tools", tool_node)
|
| 261 |
|
| 262 |
+
# Set the entry point
|
| 263 |
builder.add_edge(START, "agent")
|
| 264 |
|
| 265 |
+
# Define the conditional logic after the agent node runs
|
| 266 |
builder.add_conditional_edges(
|
| 267 |
"agent",
|
| 268 |
+
tools_condition, # Built-in function checks if the last message has tool_calls
|
| 269 |
{
|
| 270 |
+
"tools": "tools", # If tool calls exist, route to the tools node
|
| 271 |
+
END: END # If no tool calls, the agent is done, route to END
|
| 272 |
}
|
| 273 |
)
|
| 274 |
|
| 275 |
+
# Define the edge after the tools node runs
|
| 276 |
+
builder.add_edge("tools", "agent") # Always return to the agent node to process tool results
|
| 277 |
|
| 278 |
+
# Compile the graph into a runnable object
|
| 279 |
+
# NOTE: This compilation happens when the script is imported by app.py
|
| 280 |
+
try:
|
| 281 |
+
graph = builder.compile()
|
| 282 |
+
print("GAIA agent graph compiled successfully.")
|
| 283 |
+
except Exception as e:
|
| 284 |
+
print(f"ERROR: Failed to compile LangGraph graph: {e}")
|
| 285 |
+
traceback.print_exc()
|
| 286 |
+
# Raise or handle appropriately - app might fail to start if graph doesn't compile
|
| 287 |
+
raise
|
| 288 |
|
| 289 |
# ==============================================================================
|
| 290 |
+
# Main Execution Function for GAIA Benchmark <<<< WRAPPER FUNCTION >>>>
|
| 291 |
# ==============================================================================
|
| 292 |
+
def answer_gaia_task(question: str, file_path: Optional[str] = None) -> str:
|
| 293 |
+
"""
|
| 294 |
+
Runs the compiled GAIA agent graph for a given question and optional file path.
|
| 295 |
+
This is the main entry point expected by the benchmark runner.
|
| 296 |
+
|
| 297 |
+
Args:
|
| 298 |
+
question: The question text from the GAIA benchmark.
|
| 299 |
+
file_path: Optional path to a file associated with the question.
|
| 300 |
+
|
| 301 |
+
Returns:
|
| 302 |
+
A string containing the final answer extracted by the agent, or an error message.
|
| 303 |
+
"""
|
| 304 |
+
# Ensure the compiled graph is available
|
| 305 |
+
if 'graph' not in globals():
|
| 306 |
+
return "Error: Agent graph was not compiled successfully."
|
| 307 |
+
|
| 308 |
+
print(f"\n{'='*20} Running Agent for GAIA Task {'='*20}")
|
| 309 |
+
print(f"Question: {question}")
|
| 310 |
+
file_context_info = ""
|
| 311 |
+
if file_path:
|
| 312 |
+
print(f"Associated File Path: {file_path}")
|
| 313 |
+
file_context_info = f"An associated file is provided at path: '{file_path}'. Your tools should use this path if they require a file path not explicitly mentioned in the question."
|
| 314 |
+
|
| 315 |
+
# Define the initial prompt sent to the agent
|
| 316 |
+
prompt_content = f"""Your task is to accurately answer the following question based *only* on information obtained using your tools (web search, web browser, file download, csv/excel analysis, image OCR, math).
|
| 317 |
+
|
| 318 |
+
{file_context_info}
|
| 319 |
|
| 320 |
Follow these steps methodically:
|
| 321 |
1. Analyze the question to understand required information and tools needed.
|
| 322 |
+
2. If external files are mentioned (e.g., 'data.csv', 'image.png'), use the appropriate analysis tool directly on the provided file path/name. Assume files are accessible in the current directory unless a URL or the separate file path is given.
|
| 323 |
+
3. If a URL is given for a file, use 'download_file_from_url' first, then analyze the downloaded file using its returned path.
|
| 324 |
4. If web information is needed, use 'web_search' then 'web_browser' on relevant URLs.
|
| 325 |
5. If calculations are needed, use the math tools.
|
| 326 |
6. Synthesize the information gathered from tools to arrive at the final answer.
|
| 327 |
7. **CRITICAL:** Your final output MUST contain ONLY the precise numerical or text answer requested by the question. Do NOT include explanations, reasoning steps, units unless explicitly asked for, context, apologies, or any introductory phrases like "The final answer is...". Just the required answer string or number itself.
|
| 328 |
|
| 329 |
+
Question: {question}
|
| 330 |
+
"""
|
| 331 |
+
|
| 332 |
+
# Create the initial state for the graph run
|
| 333 |
+
initial_state = AgentState(
|
| 334 |
+
input_question=question,
|
| 335 |
+
messages=[HumanMessage(content=prompt_content)],
|
| 336 |
+
error=None,
|
| 337 |
+
iterations=0
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
final_answer = "Error: Agent execution did not complete successfully." # Default fallback
|
| 341 |
+
|
| 342 |
+
try:
|
| 343 |
+
# Invoke the compiled graph
|
| 344 |
+
final_state = graph.invoke(initial_state, config={"recursion_limit": 15}) # Set recursion limit
|
| 345 |
+
|
| 346 |
+
# Process the final state to extract the answer
|
| 347 |
+
if final_state:
|
| 348 |
+
if final_state.get("error"):
|
| 349 |
+
print(f"--- Agent stopped due to ERROR: {final_state['error']} ---")
|
| 350 |
+
final_answer = f"Error: {final_state['error']}"
|
| 351 |
+
# Check if the last message is an AIMessage and capture its content
|
| 352 |
+
elif final_state.get('messages') and isinstance(final_state['messages'][-1], AIMessage):
|
| 353 |
+
# Extract content from the last AI message - relies on prompt working
|
| 354 |
+
potential_answer = final_state['messages'][-1].content
|
| 355 |
+
print(f"--- Final Answer (from AI): {potential_answer} ---")
|
| 356 |
+
final_answer = potential_answer
|
| 357 |
+
else:
|
| 358 |
+
print("--- Could not determine final answer (last message not AI or missing). Check logs. ---")
|
| 359 |
+
# Log final state details for debugging
|
| 360 |
+
print(f"Final State: Error={final_state.get('error')}, Iterations={final_state.get('iterations')}")
|
| 361 |
+
|
| 362 |
+
except Exception as e:
|
| 363 |
+
print(f"--- Graph execution failed unexpectedly: {e} ---")
|
| 364 |
+
traceback.print_exc()
|
| 365 |
+
final_answer = f"Error: Graph execution failed - {str(e)}"
|
| 366 |
+
|
| 367 |
+
print(f"{'='*20} Agent Run Finished {'='*20}")
|
| 368 |
+
# Return the final answer string
|
| 369 |
+
return str(final_answer)
|
| 370 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
|
| 372 |
# ==============================================================================
|
| 373 |
+
# Local Testing Block (Optional)
|
| 374 |
# ==============================================================================
|
| 375 |
+
# This block allows you to test the agent by running final_agent.py directly.
|
| 376 |
+
# It will not run when the script is imported by app.py in the Space.
|
| 377 |
+
if __name__ == "__main__":
|
| 378 |
+
print("\n--- Running Local Test ---")
|
| 379 |
+
# --- Define Test Question ---
|
| 380 |
+
test_question = "What is the result of multiplying the number of rows (excluding the header) in 'data.csv' by the number found after the phrase 'total items:' in 'image.png'?"
|
| 381 |
+
|
| 382 |
+
# --- Create Dummy Files for Local Test ---
|
| 383 |
+
print("Creating dummy files for local test...")
|
| 384 |
+
dummy_files_created = True
|
| 385 |
+
try:
|
| 386 |
+
# Dummy CSV with 3 data rows + header
|
| 387 |
+
with open("data.csv", "w") as f:
|
| 388 |
+
f.write("Header1,Header2\nRow1Val1,Row1Val2\nRow2Val1,Row2Val2\nRow3Val1,Row3Val2")
|
| 389 |
+
|
| 390 |
+
# Dummy Image containing the required text
|
| 391 |
+
try:
|
| 392 |
+
img = Image.new('RGB', (300, 50), color = (255, 255, 255)) # White background
|
| 393 |
+
from PIL import ImageDraw, ImageFont # Import drawing tools locally
|
| 394 |
+
draw = ImageDraw.Draw(img)
|
| 395 |
+
# Use a basic font if specific ones aren't found
|
| 396 |
+
try: font = ImageFont.truetype("arial.ttf", 15)
|
| 397 |
+
except IOError: font = ImageFont.load_default()
|
| 398 |
+
draw.text((10,10), "Some random info... total items: 7 ... more text", fill=(0,0,0), font=font) # Black text
|
| 399 |
+
img.save("image.png")
|
| 400 |
+
print("Dummy data.csv and image.png created successfully.")
|
| 401 |
+
except ImportError:
|
| 402 |
+
print("Pillow/ImageDraw/ImageFont not installed. Cannot create dummy image file.")
|
| 403 |
+
dummy_files_created = False
|
| 404 |
+
except Exception as img_e:
|
| 405 |
+
print(f"Error creating dummy image: {img_e}")
|
| 406 |
+
dummy_files_created = False
|
| 407 |
+
|
| 408 |
+
except Exception as file_e:
|
| 409 |
+
print(f"Error creating dummy files: {file_e}")
|
| 410 |
+
dummy_files_created = False
|
| 411 |
+
# ---------------------------------------------
|
| 412 |
+
|
| 413 |
+
# --- Run the Test ---
|
| 414 |
+
if dummy_files_created:
|
| 415 |
+
# Call the main function, simulating how the benchmark runner would call it.
|
| 416 |
+
# For this specific question, file_path argument is None as paths are in the question text.
|
| 417 |
+
result = answer_gaia_task(question=test_question, file_path=None)
|
| 418 |
+
|
| 419 |
+
print(f"\n--- Local Test Result ---")
|
| 420 |
+
# Expected answer for dummy files: 3 data rows * 7 = 21
|
| 421 |
+
print(f"Returned Answer: {result}")
|
| 422 |
+
print(f"Expected Answer (for dummy files): 21")
|
| 423 |
else:
|
| 424 |
+
print("Skipping test execution due to issues creating dummy files.")
|
| 425 |
+
|
| 426 |
+
# --- Clean up Dummy Files ---
|
| 427 |
+
print("\nCleaning up dummy files...")
|
| 428 |
+
for dummy_file in ["data.csv", "image.png"]:
|
| 429 |
+
if os.path.exists(dummy_file):
|
| 430 |
+
try: os.remove(dummy_file)
|
| 431 |
+
except Exception as e: print(f"Could not remove {dummy_file}: {e}")
|
| 432 |
+
print("Dummy file cleanup attempted.")
|
|
|
|
|
|
|
|
|