# IntegraChat: backend/api/services/tool_metadata.py
"""
Tool Metadata and Latency Prediction System
Provides:
1. Per-tool latency predictions (expected latency ranges)
2. Tool output schemas (strict JSON type definitions)
3. Context-aware routing hints
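
Typical usage (a minimal sketch: the import path is inferred from the
repository layout, `rag_output` is a placeholder for a result dict produced
by the RAG tool, and all latency figures in this module are illustrative
defaults rather than live measurements):

    from api.services.tool_metadata import (
        get_tool_latency_estimate,
        validate_tool_output,
    )

    eta_ms = get_tool_latency_estimate("web", {"query_complexity": "high"})
    ok, error = validate_tool_output("rag", rag_output)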
"""
from typing import Dict, Any, Optional, List
from dataclasses import dataclass
from enum import Enum
class ToolType(str, Enum):
"""Tool type enumeration"""
RAG = "rag"
WEB = "web"
ADMIN = "admin"
LLM = "llm"
@dataclass
class ToolLatencyMetadata:
"""Latency metadata for a tool"""
tool_name: str
min_ms: int
max_ms: int
avg_ms: int
description: str
def estimate_latency(self, context: Optional[Dict[str, Any]] = None) -> int:
"""
Estimate expected latency based on context.
Returns estimated latency in milliseconds.
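
        Doctest-style sketch, using the "rag" entry defined in
        TOOL_LATENCY_METADATA below (min 60 / avg 90 / max 120 ms):

            >>> rag = TOOL_LATENCY_METADATA["rag"]
            >>> rag.estimate_latency()
            90
            >>> rag.estimate_latency({"query_length": 150})
            108
            >>> rag.estimate_latency({"query_length": 10})
            72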
"""
# Base estimate is average
estimate = self.avg_ms
# Context-aware adjustments
if context:
# RAG: Higher latency for longer queries or more chunks
if self.tool_name == "rag":
query_length = context.get("query_length", 0)
if query_length > 100:
estimate = int(self.avg_ms * 1.2)
elif query_length < 20:
estimate = int(self.avg_ms * 0.8)
# Web: Higher latency for complex queries
elif self.tool_name == "web":
query_complexity = context.get("query_complexity", "medium")
if query_complexity == "high":
estimate = int(self.avg_ms * 1.5)
elif query_complexity == "low":
estimate = int(self.avg_ms * 0.7)
return min(max(estimate, self.min_ms), self.max_ms)
@dataclass
class ToolOutputSchema:
"""JSON schema definition for tool output"""
tool_name: str
schema: Dict[str, Any]
description: str
example: Dict[str, Any]
# Tool latency metadata
TOOL_LATENCY_METADATA: Dict[str, ToolLatencyMetadata] = {
"rag": ToolLatencyMetadata(
tool_name="rag",
min_ms=60,
max_ms=120,
avg_ms=90,
description="RAG search with vector similarity and re-ranking"
),
"web": ToolLatencyMetadata(
tool_name="web",
min_ms=400,
max_ms=1800,
avg_ms=800,
description="Web search via Google Custom Search API"
),
"admin": ToolLatencyMetadata(
tool_name="admin",
min_ms=5,
max_ms=20,
avg_ms=10,
description="Admin rule checking and violation logging"
),
"llm": ToolLatencyMetadata(
tool_name="llm",
min_ms=500,
max_ms=5000,
avg_ms=2000,
description="LLM generation and reasoning"
)
}
# Tool output schemas (JSON Schema format)
TOOL_OUTPUT_SCHEMAS: Dict[str, ToolOutputSchema] = {
"rag": ToolOutputSchema(
tool_name="rag",
schema={
"type": "object",
"required": ["results", "query", "tenant_id"],
"properties": {
"results": {
"type": "array",
"items": {
"type": "object",
"required": ["text", "similarity"],
"properties": {
"text": {"type": "string"},
"similarity": {"type": "number", "minimum": 0, "maximum": 1},
"metadata": {"type": "object"},
"doc_id": {"type": "string"}
}
}
},
"query": {"type": "string"},
"tenant_id": {"type": "string"},
"hits_count": {"type": "integer"},
"avg_score": {"type": "number"},
"top_score": {"type": "number"},
"latency_ms": {"type": "integer"}
}
},
description="RAG search results with similarity scores",
example={
"results": [
{
"text": "Document chunk text...",
"similarity": 0.85,
"metadata": {"title": "API Docs", "source_type": "pdf"},
"doc_id": "doc123"
}
],
"query": "user query",
"tenant_id": "tenant1",
"hits_count": 3,
"avg_score": 0.75,
"top_score": 0.85,
"latency_ms": 90
}
),
"web": ToolOutputSchema(
tool_name="web",
schema={
"type": "object",
"required": ["results", "query"],
"properties": {
"results": {
"type": "array",
"items": {
"type": "object",
"required": ["title", "snippet", "link"],
"properties": {
"title": {"type": "string"},
"snippet": {"type": "string"},
"link": {"type": "string"},
"displayLink": {"type": "string"}
}
}
},
"query": {"type": "string"},
"total_results": {"type": "integer"},
"latency_ms": {"type": "integer"}
}
},
description="Web search results from Google Custom Search",
example={
"results": [
{
"title": "Search Result Title",
"snippet": "Result snippet text...",
"link": "https://example.com",
"displayLink": "example.com"
}
],
"query": "search query",
"total_results": 10,
"latency_ms": 800
}
),
"admin": ToolOutputSchema(
tool_name="admin",
schema={
"type": "object",
"required": ["violations", "checked"],
"properties": {
"violations": {
"type": "array",
"items": {
"type": "object",
"required": ["rule_id", "severity", "matched_text"],
"properties": {
"rule_id": {"type": "string"},
"rule_pattern": {"type": "string"},
"severity": {"type": "string", "enum": ["low", "medium", "high", "critical"]},
"matched_text": {"type": "string"},
"confidence": {"type": "number", "minimum": 0, "maximum": 1},
"message_preview": {"type": "string"}
}
}
},
"checked": {"type": "boolean"},
"rules_count": {"type": "integer"},
"latency_ms": {"type": "integer"}
}
},
description="Admin rule violations and safety checks",
example={
"violations": [
{
"rule_id": "rule1",
"rule_pattern": ".*password.*",
"severity": "high",
"matched_text": "password",
"confidence": 0.95,
"message_preview": "User asked for password"
}
],
"checked": True,
"rules_count": 5,
"latency_ms": 10
}
),
"llm": ToolOutputSchema(
tool_name="llm",
schema={
"type": "object",
"required": ["text", "tokens_used"],
"properties": {
"text": {"type": "string"},
"tokens_used": {"type": "integer"},
"latency_ms": {"type": "integer"},
"model": {"type": "string"},
"temperature": {"type": "number"}
}
},
description="LLM-generated response",
example={
"text": "Generated response text...",
"tokens_used": 150,
"latency_ms": 2000,
"model": "llama3.1:latest",
"temperature": 0.0
}
)
}
def get_tool_latency_estimate(tool_name: str, context: Optional[Dict[str, Any]] = None) -> int:
"""
Get estimated latency for a tool in milliseconds.
Args:
tool_name: Name of the tool (rag, web, admin, llm)
context: Optional context for more accurate estimation
Returns:
Estimated latency in milliseconds
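
    Doctest-style sketch (values follow the illustrative defaults above):

        >>> get_tool_latency_estimate("web", {"query_complexity": "high"})
        1200
        >>> get_tool_latency_estimate("unknown_tool")
        1000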
"""
metadata = TOOL_LATENCY_METADATA.get(tool_name)
if not metadata:
# Default estimate for unknown tools
return 1000
return metadata.estimate_latency(context)
def get_tool_schema(tool_name: str) -> Optional[ToolOutputSchema]:
"""Get the output schema for a tool"""
return TOOL_OUTPUT_SCHEMAS.get(tool_name)
def validate_tool_output(tool_name: str, output: Dict[str, Any]) -> tuple[bool, Optional[str]]:
"""
Validate tool output against its schema.
Returns:
(is_valid, error_message)
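
    Doctest-style sketch, checking a schema's own bundled example:

        >>> validate_tool_output("rag", TOOL_OUTPUT_SCHEMAS["rag"].example)
        (True, None)
        >>> validate_tool_output("rag", {"query": "q"})
        (False, 'Missing required field: results')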
"""
schema_obj = get_tool_schema(tool_name)
if not schema_obj:
return True, None # Unknown tool, skip validation
    # Simple structural validation; a stricter variant using the third-party
    # jsonschema package is sketched in validate_tool_output_strict below.
schema = schema_obj.schema
required = schema.get("required", [])
for field in required:
if field not in output:
return False, f"Missing required field: {field}"
# Type checking for top-level fields
properties = schema.get("properties", {})
for field, value in output.items():
if field in properties:
expected_type = properties[field].get("type")
if expected_type:
if expected_type == "array" and not isinstance(value, list):
return False, f"Field '{field}' must be array, got {type(value).__name__}"
elif expected_type == "object" and not isinstance(value, dict):
return False, f"Field '{field}' must be object, got {type(value).__name__}"
elif expected_type == "string" and not isinstance(value, str):
return False, f"Field '{field}' must be string, got {type(value).__name__}"
elif expected_type == "integer" and not isinstance(value, int):
return False, f"Field '{field}' must be integer, got {type(value).__name__}"
elif expected_type == "number" and not isinstance(value, (int, float)):
return False, f"Field '{field}' must be number, got {type(value).__name__}"
elif expected_type == "boolean" and not isinstance(value, bool):
return False, f"Field '{field}' must be boolean, got {type(value).__name__}"
return True, None
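
def validate_tool_output_strict(tool_name: str, output: Dict[str, Any]) -> tuple[bool, Optional[str]]:
    """
    Hypothetical stricter variant (a sketch, not wired into the pipeline):
    delegates to the third-party jsonschema package, which also validates
    nested array items, enums, and numeric bounds that the hand-rolled check
    in validate_tool_output skips. Assumes `pip install jsonschema`.
    """
    try:
        import jsonschema
    except ImportError:
        # Fall back to the simple structural check when jsonschema is absent
        return validate_tool_output(tool_name, output)
    schema_obj = get_tool_schema(tool_name)
    if not schema_obj:
        return True, None  # Unknown tool, skip validation
    try:
        jsonschema.validate(instance=output, schema=schema_obj.schema)
        return True, None
    except jsonschema.ValidationError as exc:
        return False, exc.message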
def estimate_path_latency(tool_sequence: List[str], context: Optional[Dict[str, Any]] = None) -> int:
"""
Estimate total latency for a sequence of tools.
Args:
tool_sequence: List of tool names in execution order
context: Optional context for each tool
Returns:
Total estimated latency in milliseconds
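
    Doctest-style sketch (90 + 10 + 2000 ms, using the averages above):

        >>> estimate_path_latency(["rag", "admin", "llm"])
        2100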
"""
total = 0
for tool in tool_sequence:
tool_context = context.get(tool, {}) if context else {}
total += get_tool_latency_estimate(tool, tool_context)
return total
def get_fastest_path(
required_tools: List[str],
context: Optional[Dict[str, Any]] = None
) -> List[str]:
"""
    Determine an execution order for the required tools, fastest first.
    With strictly sequential execution the total latency is order-independent,
    but running cheap tools first (e.g. admin checks) surfaces their results
    sooner, and the ordering doubles as a priority hint if this is later
    extended to run independent tools in parallel.
Args:
required_tools: List of required tool names
context: Optional context for latency estimation
Returns:
Optimized tool sequence
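
    Doctest-style sketch (admin 10 ms < rag 90 ms < web 800 ms by default):

        >>> get_fastest_path(["web", "rag", "admin"])
        ['admin', 'rag', 'web']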
"""
# Sort by estimated latency (fastest first)
tool_latencies = [
(tool, get_tool_latency_estimate(tool, context.get(tool, {}) if context else {}))
for tool in required_tools
]
tool_latencies.sort(key=lambda x: x[1])
return [tool for tool, _ in tool_latencies]