agentbee / src /agent /graph.py
mangubee's picture
Stage 5: Performance optimization - retry logic, Groq integration, improved prompts
5890f66
raw
history blame
21.4 kB
"""
LangGraph Agent Core - StateGraph Definition
Author: @mangobee
Date: 2026-01-01
Stage 1: Skeleton with placeholder nodes
Stage 2: Tool integration (CURRENT)
Stage 3: Planning and reasoning logic implementation
Based on:
- Level 3: Sequential workflow with dynamic planning
- Level 4: Goal-based reasoning, coarse-grained generalist
- Level 6: LangGraph framework
"""
import logging
import os
from typing import TypedDict, List, Optional
from langgraph.graph import StateGraph, END
from src.config import Settings
from src.tools import TOOLS, search, parse_file, safe_eval, analyze_image
from src.agent.llm_client import (
plan_question,
select_tools_with_function_calling,
synthesize_answer,
)
# ============================================================================
# Logging Setup
# ============================================================================
logger = logging.getLogger(__name__)
# ============================================================================
# Helper Functions
# ============================================================================
def is_vision_question(question: str) -> bool:
    """
    Report whether a question looks like it needs the vision tool.

    The check is a case-insensitive substring match against a fixed list of
    visual-content keywords (images, videos, YouTube links, ...). Because it
    is a substring match, a keyword embedded in a longer word also counts.

    Args:
        question: GAIA question text

    Returns:
        True if any vision keyword occurs in the question, False otherwise
    """
    lowered = question.lower()
    for keyword in (
        "image",
        "video",
        "youtube",
        "photo",
        "picture",
        "watch",
        "screenshot",
        "visual",
    ):
        if keyword in lowered:
            return True
    return False
# ============================================================================
# Agent State Definition
# ============================================================================
class AgentState(TypedDict):
    """
    Shared state passed between the nodes of the GAIA agent graph.

    One instance follows a single question from input, through planning and
    tool execution, to the final answer.
    """

    # Input question from GAIA
    question: str
    # Optional file paths for file-based questions
    file_paths: Optional[List[str]]
    # Execution plan written by the planning node
    plan: Optional[str]
    # Tool invocations chosen by the execution node
    tool_calls: List[dict]
    # Per-tool execution records (result or error, plus status)
    tool_results: List[dict]
    # Evidence strings collected from successful tool runs
    evidence: List[str]
    # Final factoid answer
    answer: Optional[str]
    # Error messages accumulated across nodes
    errors: List[str]
# ============================================================================
# Environment Validation
# ============================================================================
def validate_environment() -> List[str]:
    """
    Report which expected API keys are absent from the environment.

    A key counts as missing when its environment variable is unset or empty.

    Returns:
        Labels of the missing keys in a fixed order (empty if all present)
    """
    expected_keys = (
        ("GOOGLE_API_KEY", "Gemini"),
        ("HF_TOKEN", "HuggingFace"),
        ("ANTHROPIC_API_KEY", "Claude"),
        ("TAVILY_API_KEY", "Search"),
    )
    return [
        f"{env_var} ({service})"
        for env_var, service in expected_keys
        if not os.getenv(env_var)
    ]
# ============================================================================
# Helper Functions
# ============================================================================
def fallback_tool_selection(question: str, plan: str) -> List[dict]:
    """
    MVP Fallback: Simple keyword-based tool selection when LLM fails.

    Keywords are matched as substrings against the lower-cased question and
    plan. Only web_search and calculator calls can be produced here; file and
    image keywords are detected but skipped because no path can be extracted
    from the text alone. If nothing is selected, a web search on the full
    question is used as the default.

    Args:
        question: The user question
        plan: The execution plan (a missing/None plan is treated as empty)

    Returns:
        List of tool calls, each a dict with "tool" and "params" keys
    """
    # Local import: `re` is not among the module-level imports.
    import re

    logger.info("[fallback_tool_selection] Using keyword-based fallback for tool selection")
    tool_calls = []

    # Planning may have failed and left no usable plan; do not crash on None.
    combined = f"{question.lower()} {(plan or '').lower()}"

    # Search tool: keywords like "search", "find", "look up", "who", "what", "when", "where"
    search_keywords = ["search", "find", "look up", "who is", "what is", "when", "where", "google"]
    if any(keyword in combined for keyword in search_keywords):
        # Extract search query - use first sentence or full question
        query = question.split('.')[0] if '.' in question else question
        tool_calls.append({
            "tool": "web_search",
            "params": {"query": query}
        })
        logger.info(f"[fallback_tool_selection] Added web_search tool with query: {query}")

    # Math tool: keywords like "calculate", "compute", "+", "-", "*", "/", "="
    math_keywords = ["calculate", "compute", "math", "sum", "multiply", "divide", "+", "-", "*", "/", "="]
    if any(keyword in combined for keyword in math_keywords):
        # The character class below also matches plain whitespace, so the
        # first match in ordinary prose is usually a single space. Scan every
        # candidate run and keep the first one that contains both a digit and
        # an arithmetic operator; otherwise no calculator call is queued.
        expression = None
        for candidate in re.findall(r'[\d\s\+\-\*/\(\)\.]+', question):
            if re.search(r'\d', candidate) and re.search(r'[\+\-\*/]', candidate):
                expression = candidate.strip()
                break
        if expression:
            tool_calls.append({
                "tool": "calculator",
                "params": {"expression": expression}
            })
            logger.info(f"[fallback_tool_selection] Added calculator tool with expression: {expression}")
        else:
            logger.info("[fallback_tool_selection] Math keyword detected but no expression could be extracted")

    # File tool: keywords like "file", "parse", "read", "csv", "json", "txt"
    file_keywords = ["file", "parse", "read", "csv", "json", "txt", "document"]
    if any(keyword in combined for keyword in file_keywords):
        # Cannot extract filename without more info, skip for now
        logger.warning("[fallback_tool_selection] File operation detected but cannot extract filename")

    # Image tool: keywords like "image", "picture", "photo", "analyze", "vision"
    image_keywords = ["image", "picture", "photo", "analyze image", "vision"]
    if any(keyword in combined for keyword in image_keywords):
        # Cannot extract image path without more info, skip for now
        logger.warning("[fallback_tool_selection] Image operation detected but cannot extract image path")

    if not tool_calls:
        logger.warning("[fallback_tool_selection] No tools selected by fallback - adding default search")
        # Default: just search the question
        tool_calls.append({
            "tool": "web_search",
            "params": {"query": question}
        })

    logger.info(f"[fallback_tool_selection] Fallback selected {len(tool_calls)} tool(s)")
    return tool_calls
# ============================================================================
# Graph Node Functions (plan → execute → answer)
# ============================================================================
def plan_node(state: AgentState) -> AgentState:
    """
    Planning node: ask the LLM for an execution plan for the question.

    Calls plan_question() with the question, the TOOLS registry and any file
    paths, and stores the returned plan in the state. If anything raises, the
    error is appended to state["errors"] and a placeholder plan string is
    stored so that downstream nodes still receive a plan value.

    Args:
        state: Current agent state with question

    Returns:
        The same state object, with "plan" set
    """
    logger.info("[plan_node] ========== PLAN NODE START ==========")
    question = state["question"]
    file_paths = state.get("file_paths")
    logger.info(f"[plan_node] Question: {question}")
    logger.info(f"[plan_node] File paths: {file_paths}")
    logger.info(f"[plan_node] Available tools: {list(TOOLS.keys())}")

    try:
        logger.info("[plan_node] Calling plan_question() with LLM...")
        generated_plan = plan_question(
            question=question,
            available_tools=TOOLS,
            file_paths=file_paths,
        )
        state["plan"] = generated_plan
        logger.info(f"[plan_node] ✓ Plan created successfully ({len(generated_plan)} chars)")
        logger.debug(f"[plan_node] Plan content: {generated_plan}")
    except Exception as e:
        # Never let a planning failure abort the graph; record it and move on.
        detail = f"{type(e).__name__}: {str(e)}"
        logger.error(f"[plan_node] ✗ Planning failed: {detail}", exc_info=True)
        state["errors"].append(f"Planning error: {detail}")
        state["plan"] = "Error: Unable to create plan"

    logger.info("[plan_node] ========== PLAN NODE END ==========")
    return state
def _is_quota_error(error: Exception) -> bool:
    """Return True if the error text mentions a quota or HTTP 429 condition."""
    message = str(error)
    return "quota" in message.lower() or "429" in message


def _select_fallback_tools(question: str, plan: Optional[str]) -> List[dict]:
    """Run keyword-based fallback selection; return [] if the fallback itself fails."""
    try:
        # A missing/None plan is passed as "" so the fallback can lower-case it.
        return fallback_tool_selection(question, plan or "")
    except Exception as fallback_error:
        logger.error(f"[execute_node] Fallback also failed: {fallback_error}")
        return []


def _run_tool_calls(
    tool_calls: List[dict],
    tool_results: List[dict],
    evidence: List[str],
    errors: List[str],
) -> None:
    """Execute each tool call, appending results, evidence and errors in place."""
    # NOTE: Keys must match TOOLS registry in src/tools/__init__.py
    tool_functions = {
        "web_search": search,
        "parse_file": parse_file,
        "calculator": safe_eval,
        "vision": analyze_image,
    }
    total = len(tool_calls)

    for idx, tool_call in enumerate(tool_calls, 1):
        # Defaults keep the failure record well-formed when the entry itself
        # is malformed (not a dict, or missing "tool"/"params").
        tool_name = "<unknown>"
        params = None
        try:
            tool_name = tool_call["tool"]
            params = tool_call["params"]
            logger.info(f"[execute_node] --- Tool {idx}/{total}: {tool_name} ---")
            logger.info(f"[execute_node] Parameters: {params}")

            tool_func = tool_functions.get(tool_name)
            if not tool_func:
                raise ValueError(f"Tool '{tool_name}' not found in TOOL_FUNCTIONS")

            logger.info(f"[execute_node] Executing {tool_name}...")
            result = tool_func(**params)
            logger.info(f"[execute_node] ✓ {tool_name} completed successfully")
            logger.debug(f"[execute_node] Result: {result[:200] if isinstance(result, str) else result}...")

            tool_results.append(
                {
                    "tool": tool_name,
                    "params": params,
                    "result": result,
                    "status": "success",
                }
            )
            evidence.append(f"[{tool_name}] {result}")
        except Exception as tool_error:
            # One failing (or malformed) call must not stop the remaining ones.
            logger.error(
                f"[execute_node] ✗ Tool {tool_name} failed: {type(tool_error).__name__}: {str(tool_error)}",
                exc_info=True,
            )
            tool_results.append(
                {
                    "tool": tool_name,
                    "params": params,
                    "error": str(tool_error),
                    "status": "failed",
                }
            )
            # Provide specific error message for vision tool failures
            if tool_name == "vision" and _is_quota_error(tool_error):
                errors.append("Vision analysis failed: LLM quota exhausted. Vision requires multimodal LLM (Gemini/Claude).")
            else:
                errors.append(f"Tool {tool_name} failed: {type(tool_error).__name__}: {str(tool_error)}")


def execute_node(state: AgentState) -> AgentState:
    """
    Execution node: select tools for the question and run them.

    Tool selection is delegated to select_tools_with_function_calling(). When
    that call raises, returns nothing, or returns a non-list, the keyword-based
    fallback_tool_selection() is used instead. Every selected call is then run
    through the same execution loop, so successes and failures are recorded
    identically regardless of how the tools were chosen.

    Args:
        state: Current agent state with plan

    Returns:
        The same state object with "tool_calls", "tool_results" and
        "evidence" replaced, and any failures appended to "errors"
    """
    logger.info("[execute_node] ========== EXECUTE NODE START ==========")
    question = state["question"]
    plan = state.get("plan")
    logger.info(f"[execute_node] Plan: {plan}")
    logger.info(f"[execute_node] Question: {question}")

    tool_results: List[dict] = []
    evidence: List[str] = []
    tool_calls: List[dict] = []

    try:
        logger.info("[execute_node] Calling select_tools_with_function_calling()...")
        selected = select_tools_with_function_calling(
            question=question, plan=plan, available_tools=TOOLS
        )
    except Exception as e:
        logger.error(f"[execute_node] ✗ Execution failed: {type(e).__name__}: {str(e)}", exc_info=True)
        # Graceful handling for vision questions when LLMs unavailable
        if is_vision_question(question) and _is_quota_error(e):
            logger.warning("[execute_node] Vision question detected with quota error - providing graceful skip")
            state["errors"].append("Vision analysis unavailable (LLM quota exhausted). Vision questions require multimodal LLMs.")
        else:
            state["errors"].append(f"Execution error: {type(e).__name__}: {str(e)}")

        logger.info("[execute_node] Attempting fallback after exception...")
        tool_calls = _select_fallback_tools(question, plan)
        logger.info(f"[execute_node] Fallback after exception returned {len(tool_calls)} tool(s)")
    else:
        if not selected:
            logger.warning("[execute_node] ⚠ LLM returned empty tool_calls list - using fallback")
            state["errors"].append("Tool selection returned no tools - using fallback keyword matching")
            # MVP HACK: Use fallback keyword-based tool selection
            tool_calls = _select_fallback_tools(question, plan)
            logger.info(f"[execute_node] Fallback returned {len(tool_calls)} tool(s)")
        elif not isinstance(selected, list):
            logger.error(f"[execute_node] ✗ Invalid tool_calls type: {type(selected)} - using fallback")
            state["errors"].append(f"Tool selection returned invalid type: {type(selected)} - using fallback")
            # MVP HACK: Use fallback
            tool_calls = _select_fallback_tools(question, plan)
        else:
            tool_calls = selected
            logger.info(f"[execute_node] ✓ LLM selected {len(tool_calls)} tool(s)")
            logger.debug(f"[execute_node] Tool calls: {tool_calls}")

    _run_tool_calls(tool_calls, tool_results, evidence, state["errors"])

    logger.info(f"[execute_node] Summary: {len(tool_results)} tool(s) executed, {len(evidence)} evidence items collected")
    logger.debug(f"[execute_node] Evidence: {evidence}")

    # Always update state, even if there were errors
    state["tool_calls"] = tool_calls
    state["tool_results"] = tool_results
    state["evidence"] = evidence

    logger.info("[execute_node] ========== EXECUTE NODE END ==========")
    return state
def answer_node(state: AgentState) -> AgentState:
    """
    Answer synthesis node: turn collected evidence into the final answer.

    With evidence present, synthesize_answer() is called and its result is
    stored in state["answer"]. With no evidence, an "ERROR: ..." answer that
    summarises the accumulated errors is stored and the node returns at once.
    A failure during synthesis is recorded in state["errors"] and also
    surfaced as an "ERROR: ..." answer.

    Args:
        state: Current agent state with evidence from tools

    Returns:
        The same state object, with "answer" set
    """
    logger.info("[answer_node] ========== ANSWER NODE START ==========")
    evidence = state["evidence"]
    logger.info(f"[answer_node] Evidence items collected: {len(evidence)}")
    logger.debug(f"[answer_node] Evidence: {evidence}")
    errors = state["errors"]
    logger.info(f"[answer_node] Errors accumulated: {len(errors)}")
    if errors:
        logger.warning(f"[answer_node] Error list: {errors}")

    try:
        if evidence:
            logger.info(f"[answer_node] Calling synthesize_answer() with {len(evidence)} evidence items...")
            synthesized = synthesize_answer(
                question=state["question"], evidence=evidence
            )
            state["answer"] = synthesized
            logger.info(f"[answer_node] ✓ Answer generated successfully: {synthesized}")
        else:
            logger.warning(
                "[answer_node] ✗ No evidence collected, cannot generate answer"
            )
            # Surface WHY nothing was collected by including the error details
            if errors:
                error_summary = "; ".join(errors)
            else:
                error_summary = "No errors logged - check API keys and logs"
            state["answer"] = f"ERROR: No evidence collected. Details: {error_summary}"
            logger.error(f"[answer_node] Returning error answer: {state['answer']}")
            return state
    except Exception as e:
        detail = f"{type(e).__name__}: {str(e)}"
        logger.error(f"[answer_node] ✗ Answer synthesis failed: {detail}", exc_info=True)
        errors.append(f"Answer synthesis error: {detail}")
        state["answer"] = f"ERROR: Answer synthesis failed - {detail}"

    logger.info("[answer_node] ========== ANSWER NODE END ==========")
    return state
# ============================================================================
# StateGraph Construction
# ============================================================================
def create_gaia_graph() -> StateGraph:
    """
    Build and compile the LangGraph workflow for the GAIA agent.

    The workflow is strictly sequential:
        plan → execute → answer → END

    Returns:
        The object returned by StateGraph.compile(), ready for execution
    """
    # Instantiated but not otherwise used in this function.
    # NOTE(review): kept in case Settings() has construction side effects - confirm.
    settings = Settings()

    workflow = StateGraph(AgentState)

    # Register the three processing nodes
    for node_name, handler in (
        ("plan", plan_node),
        ("execute", execute_node),
        ("answer", answer_node),
    ):
        workflow.add_node(node_name, handler)

    # Wire the nodes into a linear pipeline
    workflow.set_entry_point("plan")
    for source, target in (
        ("plan", "execute"),
        ("execute", "answer"),
        ("answer", END),
    ):
        workflow.add_edge(source, target)

    compiled_graph = workflow.compile()
    print("[create_gaia_graph] StateGraph compiled successfully")
    return compiled_graph
# ============================================================================
# Agent Wrapper Class
# ============================================================================
class GAIAAgent:
    """
    GAIA Benchmark Agent - Main interface.

    Wraps the compiled LangGraph workflow behind a simple call interface:
    calling the instance with a question returns the answer string.
    Compatible with existing BasicAgent interface in app.py.
    """

    def __init__(self):
        """Check API keys (warning only, never raising) and compile the graph."""
        print("GAIAAgent initializing...")

        # Validate environment - check API keys
        missing_keys = validate_environment()
        if missing_keys:
            warning_msg = f"⚠️ WARNING: Missing API keys: {', '.join(missing_keys)}"
            print(warning_msg)
            logger.warning(warning_msg)
            print(" Agent may fail to answer questions. Set keys in environment variables.")
        else:
            print("✓ All API keys present")

        self.graph = create_gaia_graph()
        self.last_state = None  # Store last execution state for diagnostics
        print("GAIAAgent initialized successfully")

    def __call__(self, question: str, file_paths: Optional[List[str]] = None) -> str:
        """
        Process question and return answer.

        Args:
            question: GAIA question text
            file_paths: Optional file paths to place in the initial state;
                defaults to None, which matches calling with a question only

        Returns:
            Factoid answer string, or "Error: No answer generated" when the
            graph finishes without producing an answer
        """
        print(f"GAIAAgent processing question (first 50 chars): {question[:50]}...")

        # Initialize state
        initial_state: AgentState = {
            "question": question,
            "file_paths": file_paths,
            "plan": None,
            "tool_calls": [],
            "tool_results": [],
            "evidence": [],
            "answer": None,
            "errors": [],
        }

        # Invoke graph
        final_state = self.graph.invoke(initial_state)

        # Store state for diagnostics
        self.last_state = final_state

        # "answer" is always present in the state (initialised to None), so a
        # .get() default would never apply; check for None explicitly.
        answer = final_state.get("answer")
        if answer is None:
            answer = "Error: No answer generated"

        print(f"GAIAAgent returning answer: {answer}")
        return answer