Spaces:
Sleeping
Sleeping
feat: Add AI metadata extraction, latency prediction, context-aware routing, and tool output schemas
d1e5882
| """ | |
| Tool Metadata and Latency Prediction System | |
| Provides: | |
| 1. Per-tool latency predictions (expected latency ranges) | |
| 2. Tool output schemas (strict JSON type definitions) | |
| 3. Context-aware routing hints | |
| """ | |
| from typing import Dict, Any, Optional, List | |
| from dataclasses import dataclass | |
| from enum import Enum | |
class ToolType(str, Enum):
    """Enumeration of the tool categories known to this module.

    Inherits from ``str`` so members compare equal to their plain
    string values (e.g. ``ToolType.RAG == "rag"``).
    """

    RAG = "rag"      # vector-store retrieval
    WEB = "web"      # external web search
    ADMIN = "admin"  # rule/violation checking
    LLM = "llm"      # language-model generation
@dataclass
class ToolLatencyMetadata:
    """Latency metadata for a single tool.

    Bug fix: the class declared annotated fields but was missing the
    ``@dataclass`` decorator (already imported at module top), so calls
    like ``ToolLatencyMetadata(tool_name=..., min_ms=...)`` raised
    ``TypeError``. The decorator generates the expected ``__init__``.
    """

    tool_name: str    # registry key ("rag", "web", "admin", "llm")
    min_ms: int       # lower bound of the latency envelope, milliseconds
    max_ms: int       # upper bound of the latency envelope, milliseconds
    avg_ms: int       # typical latency, milliseconds
    description: str  # human-readable summary of the tool

    def estimate_latency(self, context: Optional[Dict[str, Any]] = None) -> int:
        """Estimate expected latency based on context.

        Args:
            context: Optional hints; ``query_length`` (int) is used for the
                "rag" tool, ``query_complexity`` ("low"/"medium"/"high")
                for the "web" tool.

        Returns:
            Estimated latency in milliseconds, clamped to
            ``[min_ms, max_ms]``.
        """
        # Base estimate is the average; context only scales it.
        estimate = self.avg_ms
        if context:
            if self.tool_name == "rag":
                # RAG: longer queries tend to retrieve/re-rank more text.
                query_length = context.get("query_length", 0)
                if query_length > 100:
                    estimate = int(self.avg_ms * 1.2)
                elif query_length < 20:
                    estimate = int(self.avg_ms * 0.8)
            elif self.tool_name == "web":
                # Web: complexity is a coarse proxy for search cost.
                query_complexity = context.get("query_complexity", "medium")
                if query_complexity == "high":
                    estimate = int(self.avg_ms * 1.5)
                elif query_complexity == "low":
                    estimate = int(self.avg_ms * 0.7)
        # Never promise less than min_ms or more than max_ms.
        return min(max(estimate, self.min_ms), self.max_ms)
@dataclass
class ToolOutputSchema:
    """JSON schema definition for a tool's output.

    Bug fix: the class declared annotated fields but was missing the
    ``@dataclass`` decorator, so the keyword-argument construction used by
    ``TOOL_OUTPUT_SCHEMAS`` raised ``TypeError``.
    """

    tool_name: str           # registry key the schema belongs to
    schema: Dict[str, Any]   # JSON Schema (draft-style) for the tool output
    description: str         # human-readable summary of the output
    example: Dict[str, Any]  # representative example payload
# Tool latency metadata
# Registry of latency envelopes keyed by tool name. All values are in
# milliseconds; estimate_latency scales avg_ms by context and clamps the
# result to [min_ms, max_ms]. NOTE(review): the numbers look like rough
# empirical estimates — confirm against real measurements.
TOOL_LATENCY_METADATA: Dict[str, ToolLatencyMetadata] = {
    "rag": ToolLatencyMetadata(
        tool_name="rag",
        min_ms=60,
        max_ms=120,
        avg_ms=90,
        description="RAG search with vector similarity and re-ranking"
    ),
    "web": ToolLatencyMetadata(
        tool_name="web",
        min_ms=400,
        max_ms=1800,
        avg_ms=800,
        description="Web search via Google Custom Search API"
    ),
    "admin": ToolLatencyMetadata(
        tool_name="admin",
        min_ms=5,
        max_ms=20,
        avg_ms=10,
        description="Admin rule checking and violation logging"
    ),
    "llm": ToolLatencyMetadata(
        tool_name="llm",
        min_ms=500,
        max_ms=5000,
        avg_ms=2000,
        description="LLM generation and reasoning"
    )
}
# Tool output schemas (JSON Schema format)
# Registry of per-tool output contracts, keyed by tool name. Each entry
# carries the schema used by validate_tool_output plus a worked example.
# Only top-level "required" and "properties"/"type" are enforced by the
# lightweight validator in this module.
TOOL_OUTPUT_SCHEMAS: Dict[str, ToolOutputSchema] = {
    # RAG: list of retrieved chunks with similarity scores in [0, 1].
    "rag": ToolOutputSchema(
        tool_name="rag",
        schema={
            "type": "object",
            "required": ["results", "query", "tenant_id"],
            "properties": {
                "results": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "required": ["text", "similarity"],
                        "properties": {
                            "text": {"type": "string"},
                            "similarity": {"type": "number", "minimum": 0, "maximum": 1},
                            "metadata": {"type": "object"},
                            "doc_id": {"type": "string"}
                        }
                    }
                },
                "query": {"type": "string"},
                "tenant_id": {"type": "string"},
                "hits_count": {"type": "integer"},
                "avg_score": {"type": "number"},
                "top_score": {"type": "number"},
                "latency_ms": {"type": "integer"}
            }
        },
        description="RAG search results with similarity scores",
        example={
            "results": [
                {
                    "text": "Document chunk text...",
                    "similarity": 0.85,
                    "metadata": {"title": "API Docs", "source_type": "pdf"},
                    "doc_id": "doc123"
                }
            ],
            "query": "user query",
            "tenant_id": "tenant1",
            "hits_count": 3,
            "avg_score": 0.75,
            "top_score": 0.85,
            "latency_ms": 90
        }
    ),
    # Web: Google Custom Search-shaped results (title/snippet/link).
    "web": ToolOutputSchema(
        tool_name="web",
        schema={
            "type": "object",
            "required": ["results", "query"],
            "properties": {
                "results": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "required": ["title", "snippet", "link"],
                        "properties": {
                            "title": {"type": "string"},
                            "snippet": {"type": "string"},
                            "link": {"type": "string"},
                            "displayLink": {"type": "string"}
                        }
                    }
                },
                "query": {"type": "string"},
                "total_results": {"type": "integer"},
                "latency_ms": {"type": "integer"}
            }
        },
        description="Web search results from Google Custom Search",
        example={
            "results": [
                {
                    "title": "Search Result Title",
                    "snippet": "Result snippet text...",
                    "link": "https://example.com",
                    "displayLink": "example.com"
                }
            ],
            "query": "search query",
            "total_results": 10,
            "latency_ms": 800
        }
    ),
    # Admin: rule-violation report; severity is a closed enum.
    "admin": ToolOutputSchema(
        tool_name="admin",
        schema={
            "type": "object",
            "required": ["violations", "checked"],
            "properties": {
                "violations": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "required": ["rule_id", "severity", "matched_text"],
                        "properties": {
                            "rule_id": {"type": "string"},
                            "rule_pattern": {"type": "string"},
                            "severity": {"type": "string", "enum": ["low", "medium", "high", "critical"]},
                            "matched_text": {"type": "string"},
                            "confidence": {"type": "number", "minimum": 0, "maximum": 1},
                            "message_preview": {"type": "string"}
                        }
                    }
                },
                "checked": {"type": "boolean"},
                "rules_count": {"type": "integer"},
                "latency_ms": {"type": "integer"}
            }
        },
        description="Admin rule violations and safety checks",
        example={
            "violations": [
                {
                    "rule_id": "rule1",
                    "rule_pattern": ".*password.*",
                    "severity": "high",
                    "matched_text": "password",
                    "confidence": 0.95,
                    "message_preview": "User asked for password"
                }
            ],
            "checked": True,
            "rules_count": 5,
            "latency_ms": 10
        }
    ),
    # LLM: generated text plus token/latency accounting.
    "llm": ToolOutputSchema(
        tool_name="llm",
        schema={
            "type": "object",
            "required": ["text", "tokens_used"],
            "properties": {
                "text": {"type": "string"},
                "tokens_used": {"type": "integer"},
                "latency_ms": {"type": "integer"},
                "model": {"type": "string"},
                "temperature": {"type": "number"}
            }
        },
        description="LLM-generated response",
        example={
            "text": "Generated response text...",
            "tokens_used": 150,
            "latency_ms": 2000,
            "model": "llama3.1:latest",
            "temperature": 0.0
        }
    )
}
def get_tool_latency_estimate(tool_name: str, context: Optional[Dict[str, Any]] = None) -> int:
    """
    Get estimated latency for a tool in milliseconds.

    Args:
        tool_name: Name of the tool (rag, web, admin, llm)
        context: Optional context for more accurate estimation

    Returns:
        Estimated latency in milliseconds
    """
    try:
        metadata = TOOL_LATENCY_METADATA[tool_name]
    except KeyError:
        # Tool not in the registry: fall back to a conservative default.
        return 1000
    return metadata.estimate_latency(context)
def get_tool_schema(tool_name: str) -> Optional[ToolOutputSchema]:
    """Look up the output schema registered for *tool_name* (None if unknown)."""
    schema_obj = TOOL_OUTPUT_SCHEMAS.get(tool_name)
    return schema_obj
def validate_tool_output(tool_name: str, output: Dict[str, Any]) -> tuple[bool, Optional[str]]:
    """
    Validate tool output against its registered schema.

    Performs lightweight structural validation only: required-field
    presence plus top-level type checks (full JSON Schema validation would
    require the jsonschema library).

    Bug fix: ``bool`` is a subclass of ``int`` in Python, so ``True`` and
    ``False`` previously passed the "integer" and "number" checks; they are
    now rejected for those types.

    Args:
        tool_name: Name of the tool whose schema to validate against.
        output: The tool's output payload.

    Returns:
        (is_valid, error_message) — ``error_message`` is None when valid.
    """
    schema_obj = get_tool_schema(tool_name)
    if not schema_obj:
        return True, None  # Unknown tool, skip validation

    schema = schema_obj.schema

    # Required-field presence.
    for field in schema.get("required", []):
        if field not in output:
            return False, f"Missing required field: {field}"

    # Type checking for top-level fields.
    properties = schema.get("properties", {})
    for field, value in output.items():
        if field not in properties:
            continue
        expected_type = properties[field].get("type")
        if expected_type and not _json_type_matches(value, expected_type):
            return False, f"Field '{field}' must be {expected_type}, got {type(value).__name__}"
    return True, None


def _json_type_matches(value: Any, expected_type: str) -> bool:
    """Return True if *value* conforms to the JSON Schema *expected_type*."""
    if expected_type == "array":
        return isinstance(value, list)
    if expected_type == "object":
        return isinstance(value, dict)
    if expected_type == "string":
        return isinstance(value, str)
    if expected_type == "integer":
        # Exclude bool: isinstance(True, int) is True in Python.
        return isinstance(value, int) and not isinstance(value, bool)
    if expected_type == "number":
        return isinstance(value, (int, float)) and not isinstance(value, bool)
    if expected_type == "boolean":
        return isinstance(value, bool)
    # Unrecognized type keyword: don't fail validation (matches original
    # behavior of only checking the six known keywords).
    return True
def estimate_path_latency(tool_sequence: List[str], context: Optional[Dict[str, Any]] = None) -> int:
    """
    Estimate total latency for a sequence of tools.

    Args:
        tool_sequence: List of tool names in execution order
        context: Optional per-tool context, keyed by tool name

    Returns:
        Total estimated latency in milliseconds
    """
    ctx = context or {}
    # Tools run sequentially, so the path estimate is a simple sum.
    return sum(
        get_tool_latency_estimate(name, ctx.get(name, {}))
        for name in tool_sequence
    )
def get_fastest_path(
    required_tools: List[str],
    context: Optional[Dict[str, Any]] = None
) -> List[str]:
    """
    Determine the fastest execution order for required tools.

    Currently tools are executed sequentially, but this could be extended
    to suggest parallel execution for independent tools.

    Args:
        required_tools: List of required tool names
        context: Optional per-tool context for latency estimation

    Returns:
        Tool names ordered by estimated latency, fastest first (ties keep
        their original relative order — Python's sort is stable).
    """
    ctx = context or {}
    return sorted(
        required_tools,
        key=lambda name: get_tool_latency_estimate(name, ctx.get(name, {})),
    )