# Source: water3/agent/core/optimized_executor.py (uploaded by onewayto, commit 070daf8, "Upload 187 files")
"""
Optimized Tool Executor - Parallel execution, auto-retry, and output validation
"""
import asyncio
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Callable
from agent.core.level2_config import level2_config
from agent.core.semantic_cache import semantic_cache
from agent.core.observability import observability, ExecutionEvent
logger = logging.getLogger(__name__)
@dataclass
class ToolCall:
    """A single tool invocation queued for execution."""

    # Unique identifier; used as the key in dependency graphs and result maps.
    id: str
    # Registered tool name, forwarded to the executor's execute_fn.
    name: str
    # Keyword arguments for the tool.
    params: Dict[str, Any]
    # Rough runtime estimate in seconds; seeds the execution timeout budget.
    estimated_duration: float = 30.0
    # Free-form description of what the output should look like (may be empty).
    expected_output: str = ""
@dataclass
class ToolResult:
    """Outcome of one tool execution (or of a cache hit)."""

    # Mirrors the originating ToolCall.id.
    tool_id: str
    # Mirrors the originating ToolCall.name.
    tool_name: str
    # The tool's return value, or an error description on failure.
    output: Any
    # True when the tool ran to completion without error/timeout.
    success: bool
    # Wall-clock seconds spent executing (0.0 for cache hits).
    execution_time: float
    # True when the result was served from the semantic cache.
    from_cache: bool = False
    # Number of retries that preceded this result (0 = first attempt).
    retry_count: int = 0
class OptimizedToolExecutor:
    """
    Executes tools with:
    - Parallel execution when safe
    - Pre-execution validation
    - Dynamic timeout adjustment
    - Automatic retries with backoff
    - Tool result caching
    - Output validation & refinement
    """

    def __init__(self):
        # Module-level singletons shared by every executor instance.
        self.config = level2_config
        self.cache = semantic_cache
        self.observability = observability

    def build_dependency_graph(self, tools: List[ToolCall]) -> Dict[str, List[str]]:
        """Build a dependency graph for tool execution.

        For now every tool maps to an empty dependency list (all tools may
        run in parallel). A future version could analyze params to detect
        real data dependencies.
        """
        return {tool.id: [] for tool in tools}

    def find_parallelizable_tools(self, graph: Dict[str, List[str]]) -> List[List[str]]:
        """Group tool ids into batches whose members can execute in parallel.

        Simple approach: one batch holding every tool. Once
        build_dependency_graph reports real edges, this should become a
        topological sort yielding one batch per dependency level.
        """
        return [list(graph.keys())]

    async def execute_tool_with_guarantee(
        self,
        tool: ToolCall,
        execute_fn: Callable[[str, Dict[str, Any]], Any],
        context: Optional[Dict[str, Any]] = None,
    ) -> ToolResult:
        """Execute one tool with caching, automatic retries, and adaptive timeouts.

        Args:
            tool: The call to run.
            execute_fn: Awaitable factory invoked as ``execute_fn(name, params)``.
            context: Reserved for future use; currently unread.

        Returns:
            A ToolResult. Failures are reported via ``success=False`` with an
            error description in ``output`` — this method never raises.
        """
        max_retries = self.config.max_tool_retries if self.config.enable_auto_retry else 1
        base_timeout = tool.estimated_duration * 1.5
        # Hoisted: the same key is used for both the cache lookup and the store.
        cache_key = f"{tool.name}:{str(tool.params)}"

        # Serve from cache when enabled — a hit costs no execution time.
        if self.config.enable_semantic_cache:
            cached = await self.cache.check(cache_key)
            if cached:
                logger.info(f"Cache hit for tool {tool.name}")
                return ToolResult(
                    tool_id=tool.id,
                    tool_name=tool.name,
                    output=cached.result,
                    success=True,
                    execution_time=0.0,
                    from_cache=True,
                    retry_count=0,
                )

        # get_running_loop() is the supported accessor inside a coroutine;
        # get_event_loop() here has been deprecated since Python 3.10.
        loop = asyncio.get_running_loop()

        for attempt in range(max_retries):
            # Widen the timeout 1.5x on every retry (exponential backoff).
            timeout = base_timeout * (1.5 ** attempt)
            start_time = loop.time()
            try:
                self.observability.track_execution(ExecutionEvent(
                    event_type="tool_execution_start",
                    data={
                        "tool": tool.name,
                        "attempt": attempt + 1,
                        "timeout": timeout,
                        "params": tool.params,
                    },
                ))
                result = await asyncio.wait_for(
                    execute_fn(tool.name, tool.params),
                    timeout=timeout,
                )
                execution_time = loop.time() - start_time
                self.observability.track_execution(ExecutionEvent(
                    event_type="tool_execution_complete",
                    data={
                        "tool": tool.name,
                        "success": True,
                        "duration": execution_time,
                        "output_size": len(str(result)),
                        "cached": False,
                    },
                ))
                # Only successful results are worth caching.
                if self.config.enable_semantic_cache:
                    await self.cache.store(
                        query=cache_key,
                        result=result,
                        metadata={
                            "tool": tool.name,
                            "execution_time": execution_time,
                        },
                    )
                return ToolResult(
                    tool_id=tool.id,
                    tool_name=tool.name,
                    output=result,
                    success=True,
                    execution_time=execution_time,
                    from_cache=False,
                    retry_count=attempt,
                )
            except asyncio.TimeoutError:
                logger.warning(f"Tool {tool.name} timeout on attempt {attempt + 1}")
                if attempt < max_retries - 1:
                    self.observability.track_execution(ExecutionEvent(
                        event_type="tool_execution_retry",
                        data={
                            "tool": tool.name,
                            "reason": "timeout",
                            "attempt": attempt + 1,
                            "next_timeout": base_timeout * (1.5 ** (attempt + 1)),
                        },
                    ))
                    continue
                # Final attempt timed out: report failure instead of raising.
                execution_time = loop.time() - start_time
                self.observability.track_execution(ExecutionEvent(
                    event_type="tool_execution_complete",
                    data={
                        "tool": tool.name,
                        "success": False,
                        "duration": execution_time,
                        "error": "timeout",
                    },
                ))
                return ToolResult(
                    tool_id=tool.id,
                    tool_name=tool.name,
                    output=f"Timeout after {max_retries} attempts",
                    success=False,
                    execution_time=execution_time,
                    from_cache=False,
                    retry_count=attempt,
                )
            except Exception as e:
                logger.error(f"Tool {tool.name} error on attempt {attempt + 1}: {e}")
                if attempt < max_retries - 1:
                    # Consistency fix: record the retry like the timeout path does.
                    self.observability.track_execution(ExecutionEvent(
                        event_type="tool_execution_retry",
                        data={
                            "tool": tool.name,
                            "reason": "error",
                            "attempt": attempt + 1,
                            "error": str(e),
                        },
                    ))
                    await asyncio.sleep(1 * (attempt + 1))  # linear backoff delay
                    continue
                # Final attempt errored: report failure instead of raising.
                execution_time = loop.time() - start_time
                self.observability.track_execution(ExecutionEvent(
                    event_type="tool_execution_complete",
                    data={
                        "tool": tool.name,
                        "success": False,
                        "duration": execution_time,
                        "error": str(e),
                    },
                ))
                return ToolResult(
                    tool_id=tool.id,
                    tool_name=tool.name,
                    output=f"Error: {str(e)}",
                    success=False,
                    execution_time=execution_time,
                    from_cache=False,
                    retry_count=attempt,
                )

        # Defensive: unreachable, since every loop iteration returns or continues.
        return ToolResult(
            tool_id=tool.id,
            tool_name=tool.name,
            output="Unknown error",
            success=False,
            execution_time=0.0,
            from_cache=False,
            retry_count=max_retries,
        )

    async def execute_with_optimization(
        self,
        tools: List[ToolCall],
        execute_fn: Callable[[str, Dict[str, Any]], Any],
        context: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, ToolResult]:
        """Execute multiple tools with intelligent batching.

        Args:
            tools: Calls to execute; an empty list short-circuits to ``{}``.
            execute_fn: Forwarded to execute_tool_with_guarantee per tool.
            context: Forwarded to execute_tool_with_guarantee (unused there).

        Returns:
            Mapping of tool id -> ToolResult.
        """
        if not tools:
            return {}
        graph = self.build_dependency_graph(tools)
        batches = self.find_parallelizable_tools(graph)
        self.observability.track_execution(ExecutionEvent(
            event_type="execution_optimization",
            data={
                "total_tools": len(tools),
                "parallelizable": len(tools),  # All are parallelizable for now
                "batches": len(batches),
            },
        ))
        results: Dict[str, ToolResult] = {}
        for batch in batches:
            batch_tools = [t for t in tools if t.id in batch]
            # Run the whole batch concurrently. Each coroutine handles its own
            # errors and always resolves to a ToolResult, so gather won't raise.
            batch_results = await asyncio.gather(*[
                self.execute_tool_with_guarantee(tool, execute_fn, context)
                for tool in batch_tools
            ])
            for result in batch_results:
                results[result.tool_id] = result
        return results

    def validate_output_quality(
        self,
        result: ToolResult,
        expected_output: str,
    ) -> bool:
        """Heuristically judge whether a tool result looks usable.

        Args:
            result: The result to inspect.
            expected_output: Currently unused; reserved for semantic checks.

        Returns:
            False for failed executions, empty output, or short outputs that
            contain an error keyword; True otherwise.
        """
        if not result.success:
            return False
        if not result.output:
            return False
        output_str = str(result.output).lower()
        error_indicators = ["error", "exception", "failed", "timeout"]
        for indicator in error_indicators:
            # A short payload containing an error keyword is almost certainly
            # an error message rather than real output.
            if indicator in output_str and len(output_str) < 200:
                return False
        return True
# Module-level singleton executor; importers share this instance rather than
# constructing their own (it wires in the level2_config / cache / observability
# singletons at import time).
tool_executor = OptimizedToolExecutor()