File size: 7,699 Bytes
225a75e 64e1704 225a75e 64e1704 225a75e 64e1704 225a75e 64e1704 225a75e 64e1704 225a75e 64e1704 225a75e 64e1704 225a75e 64e1704 225a75e 64e1704 225a75e 64e1704 225a75e 64e1704 225a75e 64e1704 225a75e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 |
#!/usr/bin/env python3
"""
LangGraph State Schema for GAIA Agent System
Defines the state structure for agent communication and coordination
"""
from typing import Dict, Any, List, Optional, Literal
from dataclasses import dataclass, field
from enum import Enum
import time
import uuid
class QuestionType(Enum):
    """Categories a GAIA question can be classified into.

    Every member's value is the lowercase string form of its name, so a
    member can be recovered from that string with ``QuestionType(value)``.
    """

    # Categories named after an information source or input kind.
    WIKIPEDIA = "wikipedia"
    WEB_RESEARCH = "web_research"
    YOUTUBE = "youtube"
    FILE_PROCESSING = "file_processing"

    # Categories named after the kind of work to perform.
    MATHEMATICAL = "mathematical"
    CODE_EXECUTION = "code_execution"
    TEXT_MANIPULATION = "text_manipulation"
    REASONING = "reasoning"

    # Fallback member for questions that fit none of the above.
    UNKNOWN = "unknown"
class ModelTier(Enum):
    """Model complexity tiers, listed from smallest to largest model."""

    # 7B - Fast classification
    ROUTER = "router"
    # 32B - Balanced tasks
    MAIN = "main"
    # 72B - Complex reasoning
    COMPLEX = "complex"
class AgentRole(Enum):
    """Roles that the different agents in the system can take.

    Each value is the lowercase string form of the member name.
    """

    ROUTER = "router"
    WEB_RESEARCHER = "web_researcher"
    FILE_PROCESSOR = "file_processor"
    CODE_EXECUTOR = "code_executor"
    REASONING_AGENT = "reasoning_agent"
    SYNTHESIZER = "synthesizer"
@dataclass
class ToolResult:
    """Record describing the outcome of one tool execution."""

    # Required fields (positional order is part of the constructor interface).
    tool_name: str
    success: bool
    result: Any

    # Optional details; ``error`` stays None unless the caller supplies one.
    error: Optional[str] = None
    execution_time: float = 0.0

    # A new dict is created per instance so records never share metadata.
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class AgentResult:
    """Record describing the outcome of one agent's processing."""

    # Required fields (positional order is part of the constructor interface).
    agent_role: AgentRole
    success: bool
    result: str
    # Expected to lie in the range 0.0 to 1.0.
    confidence: float
    reasoning: str

    # Tool records attached to this result; a new list is created per instance.
    tools_used: List[ToolResult] = field(default_factory=list)

    # Bookkeeping values, all optional.
    model_used: str = ""
    processing_time: float = 0.0
    cost_estimate: float = 0.0
class GAIAAgentState:
    """
    Central state for the GAIA agent system.

    One instance is passed between all agents in the LangGraph workflow. It
    holds the question and optional attached file, the router's analysis,
    every agent and tool result recorded so far, accumulated cost and
    processing time, a timestamped step log, and the final answer.
    """

    def __init__(
        self,
        question: str,
        question_id: Optional[str] = None,
        file_name: Optional[str] = None,
        file_content: Optional[bytes] = None,
    ):
        """Create a fresh state for one question.

        Args:
            question: The question text.
            question_id: Identifier for the question. A falsy value (None or
                an empty string) is replaced by a random UUID4 string.
            file_name: Name of an attached file, if any.
            file_content: Raw bytes of the attached file, if any.
        """
        self.question = question
        self.question_id = question_id or str(uuid.uuid4())
        self.file_name = file_name
        self.file_content = file_content

        # Analysis results
        self.question_type: Optional[QuestionType] = None
        self.question_types: List[QuestionType] = []
        self.primary_question_type: Optional[QuestionType] = None
        self.complexity_assessment: str = "medium"
        self.selected_agents: List[AgentRole] = []

        # Enhanced router analysis
        self.router_analysis: Optional[Dict[str, Any]] = None
        self.agent_sequence: List[str] = []

        # Processing tracking
        self.processing_steps: List[str] = []
        self.agent_results: List[AgentResult] = []
        self.errors: List[str] = []
        # Reference point for the elapsed-time prefix in add_processing_step.
        self.start_time: float = time.time()
        self.total_cost: float = 0.0

        # Final results
        self.final_answer: Optional[str] = None
        self.final_confidence: float = 0.0
        self.synthesis_reasoning: str = ""

        # Routing decision tracking
        self.routing_decision: Dict[str, Any] = {}

        # Question information
        self.difficulty_level: int = 1  # 1, 2, or 3
        self.file_path: Optional[str] = None
        self.metadata: Dict[str, Any] = {}

        # Routing decisions
        self.estimated_cost: float = 0.0

        # Agent results
        self.tool_results: List[ToolResult] = []

        # Final answer
        self.answer_source: str = ""  # Which agent provided the final answer

        # System tracking
        self.total_processing_time: float = 0.0

        # Status flags
        self.is_complete: bool = False
        self.requires_human_review: bool = False
        self.confidence_threshold_met: bool = False

    def add_processing_step(self, step: str) -> None:
        """Append ``step`` to the history, prefixed with seconds since creation."""
        self.processing_steps.append(f"[{time.time() - self.start_time:.2f}s] {step}")

    def add_agent_result(self, result: AgentResult) -> None:
        """Record an agent result and fold its cost and time into the totals.

        A step of the form ``"<role value>: <preview>"`` is also logged. The
        preview is the first 50 characters of the result text; ``...`` is
        appended only when the text was actually cut.
        """
        self.agent_results.append(result)
        self.total_cost += result.cost_estimate
        self.total_processing_time += result.processing_time
        # str() keeps a non-string result from raising here, after the totals
        # above have already been updated.
        text = str(result.result)
        preview = text if len(text) <= 50 else f"{text[:50]}..."
        self.add_processing_step(f"{result.agent_role.value}: {preview}")

    def add_tool_result(self, result: ToolResult) -> None:
        """Record a tool result and log whether it succeeded."""
        self.tool_results.append(result)
        self.add_processing_step(f"Tool {result.tool_name}: {'✅' if result.success else '❌'}")

    def add_error(self, error_message: str) -> None:
        """Record an error message and mirror it into the step history."""
        self.errors.append(error_message)
        self.add_processing_step(f"ERROR: {error_message}")

    def get_best_result(self) -> Optional[AgentResult]:
        """Return the recorded agent result with the highest confidence.

        Returns None when no agent result has been recorded. The ``success``
        flag is not consulted; on a tie the earliest recorded result wins.
        """
        if not self.agent_results:
            return None
        return max(self.agent_results, key=lambda r: r.confidence)

    def should_use_complex_model(self, budget_limit: float = 0.07) -> bool:
        """Decide whether the complex model should be used.

        Args:
            budget_limit: Accumulated cost must be strictly below this value
                for the complex model to be allowed. Defaults to 0.07.

        Returns:
            True when at least one trigger holds (difficulty level 3 or
            higher, a "complex" complexity assessment, or any logged step
            containing "reasoning", case-insensitively) and ``total_cost``
            is below ``budget_limit``.
        """
        needs_complex = (
            self.difficulty_level >= 3
            or self.complexity_assessment == "complex"
            or any("reasoning" in step.lower() for step in self.processing_steps)
        )
        # Keep some budget for complex tasks
        return needs_complex and self.total_cost < budget_limit

    def get_summary(self) -> Dict[str, Any]:
        """Return a compact summary of the current state.

        ``question_type`` is reported as "unknown" while no type is set.
        """
        return {
            "task_id": self.question_id,
            "question_type": self.question_type.value if self.question_type else "unknown",
            "agents_used": [role.value for role in self.selected_agents],
            "tools_used": [tool.tool_name for tool in self.tool_results],
            "final_answer": self.final_answer,
            "confidence": self.final_confidence,
            "processing_time": self.total_processing_time,
            "total_cost": self.total_cost,
            "steps_count": len(self.processing_steps),
            "is_complete": self.is_complete,
            "error_count": len(self.errors),
        }

    def to_dict(self) -> Dict[str, Any]:
        """Convert the state to a dictionary for serialization.

        ``file_content`` is not included. List and dict values such as
        ``processing_steps`` and ``routing_decision`` are the objects held by
        the state, not copies. The result of ``get_summary`` is nested under
        the "summary" key.
        """
        return {
            "task_id": self.question_id,
            "question": self.question,
            "question_type": self.question_type.value if self.question_type else "unknown",
            "difficulty_level": self.difficulty_level,
            "file_name": self.file_name,
            "file_path": self.file_path,
            "routing_decision": self.routing_decision,
            "selected_agents": [agent.value for agent in self.selected_agents],
            "complexity_assessment": self.complexity_assessment,
            "final_answer": self.final_answer,
            "final_confidence": self.final_confidence,
            "final_reasoning": self.synthesis_reasoning,
            "answer_source": self.answer_source,
            "processing_steps": self.processing_steps,
            "total_cost": self.total_cost,
            "total_processing_time": self.total_processing_time,
            "error_messages": self.errors,
            "is_complete": self.is_complete,
            "summary": self.get_summary(),
            "router_analysis": self.router_analysis,
            "agent_sequence": self.agent_sequence,
        }
# Type alias for LangGraph: `AgentState` is another name for `GAIAAgentState`.
AgentState = GAIAAgentState