"""
Optimized Tool Executor - Parallel execution, auto-retry, and output validation
"""

import asyncio
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Callable

from agent.core.level2_config import level2_config
from agent.core.semantic_cache import semantic_cache
from agent.core.observability import observability, ExecutionEvent

logger = logging.getLogger(__name__)


@dataclass
class ToolCall:
    """A tool call to execute."""

    id: str  # caller-assigned identifier; results are keyed by it
    name: str  # passed as the first argument to the execute function
    params: Dict[str, Any]  # passed as the second argument to the execute function
    estimated_duration: float = 30.0  # seconds; first-attempt timeout is 1.5x this
    expected_output: str = ""  # not read anywhere in this module


@dataclass
class ToolResult:
    """Result of a tool execution."""

    tool_id: str
    tool_name: str
    output: Any  # the tool's return value on success, an error string on failure
    success: bool
    execution_time: float  # seconds on the event-loop clock; 0.0 for cache hits
    from_cache: bool = False
    retry_count: int = 0  # zero-based index of the attempt that produced this result


class OptimizedToolExecutor:
    """
    Executes tools with:
    - Parallel execution (every tool currently lands in a single batch)
    - Per-attempt timeouts that grow on each retry
    - Automatic retries, with a linear delay after non-timeout errors
    - Tool result caching
    - A heuristic output-quality check

    Caching and telemetry are treated as best-effort: a failure in either is
    logged and never changes the outcome of a tool execution.
    """

    # First-attempt timeout = estimated_duration * this factor.
    _TIMEOUT_SAFETY_FACTOR = 1.5
    # Each retry multiplies the previous attempt's timeout by this factor.
    _TIMEOUT_GROWTH = 1.5
    # Substrings that mark a *short* output as an error message.
    _ERROR_INDICATORS = ("error", "exception", "failed", "timeout")
    # Outputs at least this long are never rejected by the indicator scan.
    _SHORT_OUTPUT_LIMIT = 200

    def __init__(self):
        self.config = level2_config
        self.cache = semantic_cache
        self.observability = observability

    def build_dependency_graph(self, tools: List[ToolCall]) -> Dict[str, List[str]]:
        """Build dependency graph for tool execution.

        Returns a mapping of tool id to the ids it depends on. No dependency
        analysis is done yet, so every tool maps to an empty list.
        """
        # For now, assume no dependencies (all can run in parallel)
        # In future, could analyze params to detect dependencies
        return {tool.id: [] for tool in tools}

    def find_parallelizable_tools(self, graph: Dict[str, List[str]]) -> List[List[str]]:
        """Find batches of tools that can execute in parallel.

        Returns a list of batches (lists of tool ids). Currently a single
        batch holding every id in ``graph``.
        """
        # Simple approach: all tools in one batch
        # More sophisticated: topological sort
        return [list(graph.keys())]

    @staticmethod
    def _cache_key(tool: ToolCall) -> str:
        """Return the cache key for ``tool`` (name plus stringified params)."""
        return f"{tool.name}:{str(tool.params)}"

    def _track(self, event_type: str, data: Dict[str, Any]) -> None:
        """Record an observability event; failures are logged, never raised."""
        try:
            self.observability.track_execution(
                ExecutionEvent(event_type=event_type, data=data)
            )
        except Exception:
            logger.warning("Failed to record %s event", event_type, exc_info=True)

    async def _cache_lookup(self, tool: ToolCall, cache_key: str) -> Any:
        """Return the truthy cache entry for ``cache_key``, or None on miss/failure."""
        try:
            cached = await self.cache.check(cache_key)
        except Exception:
            logger.warning(
                "Cache lookup failed for tool %s", tool.name, exc_info=True
            )
            return None
        return cached if cached else None

    async def _cache_store(
        self, tool: ToolCall, cache_key: str, result: Any, execution_time: float
    ) -> None:
        """Store a successful result in the cache; failures are logged, never raised."""
        try:
            await self.cache.store(
                query=cache_key,
                result=result,
                metadata={
                    "tool": tool.name,
                    "execution_time": execution_time,
                },
            )
        except Exception:
            logger.warning(
                "Cache store failed for tool %s", tool.name, exc_info=True
            )

    def _failure_result(
        self,
        tool: ToolCall,
        *,
        execution_time: float,
        attempt: int,
        error: str,
        output: str,
    ) -> ToolResult:
        """Emit the failed ``tool_execution_complete`` event and build the result."""
        self._track(
            "tool_execution_complete",
            {
                "tool": tool.name,
                "success": False,
                "duration": execution_time,
                "error": error,
            },
        )
        return ToolResult(
            tool_id=tool.id,
            tool_name=tool.name,
            output=output,
            success=False,
            execution_time=execution_time,
            from_cache=False,
            retry_count=attempt,
        )

    async def execute_tool_with_guarantee(
        self,
        tool: ToolCall,
        execute_fn: Callable[[str, Dict[str, Any]], Any],
        context: Optional[Dict[str, Any]] = None,
    ) -> ToolResult:
        """Execute with automatic retries and timeout adaptation.

        Args:
            tool: The call to run.
            execute_fn: Called as ``execute_fn(tool.name, tool.params)``; its
                return value is awaited under a timeout, so it must return
                an awaitable.
            context: Accepted for interface compatibility; not read here.

        Returns:
            A ``ToolResult``. Timeouts and exceptions raised by the tool are
            reported through ``success=False`` rather than raised.
        """
        use_cache = self.config.enable_semantic_cache
        # Computed once, before execution, so lookup and store always agree
        # even if execute_fn mutates tool.params.
        cache_key = self._cache_key(tool) if use_cache else ""

        # Check cache first
        if use_cache:
            cached = await self._cache_lookup(tool, cache_key)
            if cached is not None:
                logger.info("Cache hit for tool %s", tool.name)
                return ToolResult(
                    tool_id=tool.id,
                    tool_name=tool.name,
                    output=cached.result,
                    success=True,
                    execution_time=0.0,
                    from_cache=True,
                    retry_count=0,
                )

        configured = (
            self.config.max_tool_retries if self.config.enable_auto_retry else 1
        )
        # Always make at least one attempt: a configured value of 0 (or less)
        # would otherwise skip execution entirely.
        max_retries = max(1, configured)
        base_timeout = tool.estimated_duration * self._TIMEOUT_SAFETY_FACTOR
        clock = asyncio.get_running_loop().time

        for attempt in range(max_retries):
            is_last_attempt = attempt == max_retries - 1
            # The timeout itself grows geometrically with each attempt.
            timeout = base_timeout * (self._TIMEOUT_GROWTH ** attempt)
            start_time = clock()

            # Track execution start
            self._track(
                "tool_execution_start",
                {
                    "tool": tool.name,
                    "attempt": attempt + 1,
                    "timeout": timeout,
                    "params": tool.params,
                },
            )

            # Only the tool call lives inside the try: a failure in telemetry
            # or caching after a successful run must not trigger a re-run.
            try:
                result = await asyncio.wait_for(
                    execute_fn(tool.name, tool.params),
                    timeout=timeout,
                )
            except asyncio.TimeoutError:
                logger.warning(
                    "Tool %s timeout on attempt %d", tool.name, attempt + 1
                )
                if not is_last_attempt:
                    self._track(
                        "tool_execution_retry",
                        {
                            "tool": tool.name,
                            "reason": "timeout",
                            "attempt": attempt + 1,
                            "next_timeout": base_timeout
                            * (self._TIMEOUT_GROWTH ** (attempt + 1)),
                        },
                    )
                    continue
                # Final attempt failed
                return self._failure_result(
                    tool,
                    execution_time=clock() - start_time,
                    attempt=attempt,
                    error="timeout",
                    output=f"Timeout after {max_retries} attempts",
                )
            except Exception as e:
                logger.error(
                    "Tool %s error on attempt %d: %s", tool.name, attempt + 1, e
                )
                if not is_last_attempt:
                    # Mirror the timeout path so every retry is observable.
                    self._track(
                        "tool_execution_retry",
                        {
                            "tool": tool.name,
                            "reason": "error",
                            "attempt": attempt + 1,
                            "error": str(e),
                        },
                    )
                    await asyncio.sleep(1 * (attempt + 1))  # Linear backoff delay
                    continue
                return self._failure_result(
                    tool,
                    execution_time=clock() - start_time,
                    attempt=attempt,
                    error=str(e),
                    output=f"Error: {str(e)}",
                )

            execution_time = clock() - start_time

            # Track execution complete
            self._track(
                "tool_execution_complete",
                {
                    "tool": tool.name,
                    "success": True,
                    "duration": execution_time,
                    "output_size": len(str(result)),
                    "cached": False,
                },
            )

            # Cache successful result
            if use_cache:
                await self._cache_store(tool, cache_key, result, execution_time)

            return ToolResult(
                tool_id=tool.id,
                tool_name=tool.name,
                output=result,
                success=True,
                execution_time=execution_time,
                from_cache=False,
                retry_count=attempt,
            )

        # Defensive fallback: unreachable because max_retries >= 1 and the
        # last attempt always returns.
        return ToolResult(
            tool_id=tool.id,
            tool_name=tool.name,
            output="Unknown error",
            success=False,
            execution_time=0.0,
            from_cache=False,
            retry_count=max_retries,
        )

    async def execute_with_optimization(
        self,
        tools: List[ToolCall],
        execute_fn: Callable[[str, Dict[str, Any]], Any],
        context: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, ToolResult]:
        """Execute multiple tools with intelligent batching.

        Batches run one after another; the tools inside a batch run
        concurrently.

        Args:
            tools: Calls to run. An empty list returns ``{}`` immediately.
            execute_fn: Forwarded to ``execute_tool_with_guarantee``.
            context: Forwarded to ``execute_tool_with_guarantee``.

        Returns:
            Mapping of tool id to its ``ToolResult``. If two calls share an
            id, the later result in the batch overwrites the earlier one.
        """
        if not tools:
            return {}

        # Build execution graph
        graph = self.build_dependency_graph(tools)

        # Find parallelizable batches
        batches = self.find_parallelizable_tools(graph)

        # Track optimization
        self._track(
            "execution_optimization",
            {
                "total_tools": len(tools),
                "parallelizable": len(tools),  # All are parallelizable for now
                "batches": len(batches),
            },
        )

        results: Dict[str, ToolResult] = {}

        # Execute batches
        for batch in batches:
            # Set membership keeps the filter linear in the number of tools.
            batch_ids = set(batch)
            batch_tools = [t for t in tools if t.id in batch_ids]

            # Execute batch in parallel
            batch_results = await asyncio.gather(
                *(
                    self.execute_tool_with_guarantee(tool, execute_fn, context)
                    for tool in batch_tools
                )
            )

            for result in batch_results:
                results[result.tool_id] = result

        return results

    @staticmethod
    def _is_empty_output(output: Any) -> bool:
        """Return True for None and zero-length values; False for everything else."""
        if output is None:
            return True
        try:
            return len(output) == 0
        except TypeError:
            # Unsized values (numbers, booleans, arbitrary objects) are not
            # "empty": 0 and False are legitimate tool outputs.
            return False

    def validate_output_quality(
        self,
        result: ToolResult,
        expected_output: str
    ) -> bool:
        """Validate if output meets quality expectations.

        Args:
            result: The execution result to inspect.
            expected_output: Accepted for interface compatibility; not read.

        Returns:
            False when the execution failed, the output is None or empty, or
            the output is a short string (under 200 characters) containing an
            error indicator; True otherwise.
        """
        if not result.success:
            return False

        # Simple validation: check if output is not empty
        if self._is_empty_output(result.output):
            return False

        # Check if output contains error indicators; only short outputs are
        # treated as error messages, since long ones may merely mention them.
        output_str = str(result.output).lower()
        if len(output_str) < self._SHORT_OUTPUT_LIMIT and any(
            indicator in output_str for indicator in self._ERROR_INDICATORS
        ):
            return False

        return True


# Global executor
tool_executor = OptimizedToolExecutor()