| | """ |
| | Optimized Tool Executor - Parallel execution, auto-retry, and output validation |
| | """ |
| |
|
| | import asyncio |
| | import logging |
| | from dataclasses import dataclass |
| | from typing import Any, Dict, List, Optional, Callable |
| |
|
| | from agent.core.level2_config import level2_config |
| | from agent.core.semantic_cache import semantic_cache |
| | from agent.core.observability import observability, ExecutionEvent |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
@dataclass
class ToolCall:
    """A single tool invocation request, as planned by the agent."""
    id: str  # unique id; used as the key in the executor's results dict
    name: str  # tool name; passed as the first argument to execute_fn
    params: Dict[str, Any]  # tool arguments; passed as the second argument to execute_fn
    estimated_duration: float = 30.0  # seconds; the executor bases its timeout on 1.5x this value
    expected_output: str = ""  # description of the expected result (not read by the executor in this file)
|
| |
|
@dataclass
class ToolResult:
    """Outcome of a single tool execution (success, failure, or cache hit)."""
    tool_id: str  # id of the originating ToolCall
    tool_name: str  # name of the originating ToolCall
    output: Any  # tool return value on success; error/timeout message string on failure
    success: bool  # False when the tool timed out or raised after all retries
    execution_time: float  # wall-clock seconds for the final attempt; 0.0 on a cache hit
    from_cache: bool = False  # True when the result came from the semantic cache
    retry_count: int = 0  # zero-based index of the attempt that produced this result
| |
|
| |
|
class OptimizedToolExecutor:
    """
    Executes tools with:
    - Parallel execution when safe
    - Pre-execution validation
    - Dynamic timeout adjustment
    - Automatic retries with backoff
    - Tool result caching
    - Output validation & refinement
    """

    def __init__(self):
        # Module-level singletons shared across the agent (config, cache, telemetry).
        self.config = level2_config
        self.cache = semantic_cache
        self.observability = observability

    def build_dependency_graph(self, tools: List[ToolCall]) -> Dict[str, List[str]]:
        """Build a dependency graph for tool execution.

        NOTE(review): tools are currently treated as fully independent —
        every tool id maps to an empty dependency list. Real dependency
        inference (e.g. from shared params) would go here.
        """
        return {tool.id: [] for tool in tools}

    def find_parallelizable_tools(self, graph: Dict[str, List[str]]) -> List[List[str]]:
        """Group tool ids into batches that can execute in parallel.

        With the all-empty dependency graph produced above, every tool
        lands in a single parallel batch.
        """
        return [list(graph.keys())]

    def _cache_key(self, tool: ToolCall) -> str:
        """Canonical semantic-cache key for a tool invocation (name + params repr)."""
        return f"{tool.name}:{str(tool.params)}"

    def _fail(
        self,
        tool: ToolCall,
        error: str,
        output: str,
        execution_time: float,
        retry_count: int,
    ) -> ToolResult:
        """Emit a failed `tool_execution_complete` event and build the failure result.

        Shared by the timeout and generic-exception terminal paths so the
        telemetry payload and ToolResult stay consistent between them.
        """
        self.observability.track_execution(ExecutionEvent(
            event_type="tool_execution_complete",
            data={
                "tool": tool.name,
                "success": False,
                "duration": execution_time,
                "error": error
            }
        ))
        return ToolResult(
            tool_id=tool.id,
            tool_name=tool.name,
            output=output,
            success=False,
            execution_time=execution_time,
            from_cache=False,
            retry_count=retry_count
        )

    async def execute_tool_with_guarantee(
        self,
        tool: ToolCall,
        execute_fn: Callable[[str, Dict[str, Any]], Any],
        context: Optional[Dict[str, Any]] = None
    ) -> ToolResult:
        """Execute one tool with caching, automatic retries, and adaptive timeouts.

        Args:
            tool: The call to execute.
            execute_fn: Awaitable-returning callable invoked as
                ``execute_fn(tool.name, tool.params)``.
            context: Extra execution context; currently unused (reserved).

        Returns:
            A ToolResult. ``success=False`` with an error-message ``output``
            when every attempt timed out or raised.
        """
        # A single attempt when auto-retry is disabled.
        max_retries = self.config.max_tool_retries if self.config.enable_auto_retry else 1
        # Timeouts start at 1.5x the estimate and grow 1.5x per retry.
        base_timeout = tool.estimated_duration * 1.5

        # Fast path: serve a prior result from the semantic cache.
        if self.config.enable_semantic_cache:
            cached = await self.cache.check(self._cache_key(tool))
            if cached:
                logger.info(f"Cache hit for tool {tool.name}")
                return ToolResult(
                    tool_id=tool.id,
                    tool_name=tool.name,
                    output=cached.result,
                    success=True,
                    execution_time=0.0,
                    from_cache=True,
                    retry_count=0
                )

        for attempt in range(max_retries):
            timeout = base_timeout * (1.5 ** attempt)

            # get_running_loop() is the supported accessor inside a coroutine;
            # get_event_loop() here is deprecated since Python 3.10.
            loop = asyncio.get_running_loop()
            start_time = loop.time()

            try:
                self.observability.track_execution(ExecutionEvent(
                    event_type="tool_execution_start",
                    data={
                        "tool": tool.name,
                        "attempt": attempt + 1,
                        "timeout": timeout,
                        "params": tool.params
                    }
                ))

                result = await asyncio.wait_for(
                    execute_fn(tool.name, tool.params),
                    timeout=timeout
                )

                execution_time = loop.time() - start_time

                self.observability.track_execution(ExecutionEvent(
                    event_type="tool_execution_complete",
                    data={
                        "tool": tool.name,
                        "success": True,
                        "duration": execution_time,
                        "output_size": len(str(result)),
                        "cached": False
                    }
                ))

                # Persist the successful result for future semantic-cache hits.
                if self.config.enable_semantic_cache:
                    await self.cache.store(
                        query=self._cache_key(tool),
                        result=result,
                        metadata={
                            "tool": tool.name,
                            "execution_time": execution_time
                        }
                    )

                return ToolResult(
                    tool_id=tool.id,
                    tool_name=tool.name,
                    output=result,
                    success=True,
                    execution_time=execution_time,
                    from_cache=False,
                    retry_count=attempt
                )

            except asyncio.TimeoutError:
                logger.warning(f"Tool {tool.name} timeout on attempt {attempt + 1}")

                if attempt < max_retries - 1:
                    self.observability.track_execution(ExecutionEvent(
                        event_type="tool_execution_retry",
                        data={
                            "tool": tool.name,
                            "reason": "timeout",
                            "attempt": attempt + 1,
                            "next_timeout": base_timeout * (1.5 ** (attempt + 1))
                        }
                    ))
                    continue
                return self._fail(
                    tool,
                    error="timeout",
                    output=f"Timeout after {max_retries} attempts",
                    execution_time=loop.time() - start_time,
                    retry_count=attempt
                )

            except Exception as e:
                logger.error(f"Tool {tool.name} error on attempt {attempt + 1}: {e}")

                if attempt < max_retries - 1:
                    # Consistency with the timeout path: record the retry in
                    # telemetry, then back off linearly before trying again.
                    self.observability.track_execution(ExecutionEvent(
                        event_type="tool_execution_retry",
                        data={
                            "tool": tool.name,
                            "reason": str(e),
                            "attempt": attempt + 1
                        }
                    ))
                    await asyncio.sleep(1 * (attempt + 1))
                    continue
                return self._fail(
                    tool,
                    error=str(e),
                    output=f"Error: {str(e)}",
                    execution_time=loop.time() - start_time,
                    retry_count=attempt
                )

        # Defensive fallback; unreachable because every loop exit returns.
        return ToolResult(
            tool_id=tool.id,
            tool_name=tool.name,
            output="Unknown error",
            success=False,
            execution_time=0.0,
            from_cache=False,
            retry_count=max_retries
        )

    async def execute_with_optimization(
        self,
        tools: List[ToolCall],
        execute_fn: Callable[[str, Dict[str, Any]], Any],
        context: Optional[Dict[str, Any]] = None
    ) -> Dict[str, ToolResult]:
        """Execute multiple tools with intelligent batching.

        Builds the (currently trivial) dependency graph, partitions tools
        into parallel batches, and gathers each batch concurrently.

        Returns:
            Mapping of ToolCall.id -> ToolResult for every input tool.
        """
        if not tools:
            return {}

        graph = self.build_dependency_graph(tools)
        batches = self.find_parallelizable_tools(graph)

        self.observability.track_execution(ExecutionEvent(
            event_type="execution_optimization",
            data={
                "total_tools": len(tools),
                "parallelizable": len(tools),
                "batches": len(batches)
            }
        ))

        results = {}

        for batch in batches:
            # Resolve the batch's tool ids back to ToolCall objects.
            batch_tools = [t for t in tools if t.id in batch]

            batch_results = await asyncio.gather(*[
                self.execute_tool_with_guarantee(tool, execute_fn, context)
                for tool in batch_tools
            ])

            for result in batch_results:
                results[result.tool_id] = result

        return results

    def validate_output_quality(
        self,
        result: ToolResult,
        expected_output: str
    ) -> bool:
        """Heuristically validate that a tool's output looks usable.

        Rejects failed results, empty outputs, and short outputs (< 200
        chars) that contain an error-like keyword. Long outputs may
        legitimately mention "error", so they are not rejected on that
        basis. NOTE(review): ``expected_output`` is not yet consulted.
        """
        if not result.success:
            return False

        if not result.output:
            return False

        output_str = str(result.output).lower()
        error_indicators = ["error", "exception", "failed", "timeout"]

        for indicator in error_indicators:
            if indicator in output_str and len(output_str) < 200:
                return False

        return True
| |
|
| |
|
| | |
# Module-level singleton; importers share this one executor instance.
tool_executor = OptimizedToolExecutor()
| |
|