from __future__ import annotations import inspect import time from typing import Any, Awaitable, Callable, Mapping, Optional from .logging import log_tool_usage from .tenant import TenantContext, TenantValidationError, build_tenant_context class ToolValidationError(ValueError): """Raised when the caller request payload is invalid.""" class ToolExecutionError(RuntimeError): """Raised for unexpected runtime failures.""" Payload = Mapping[str, Any] ToolHandler = Callable[[TenantContext, Payload], Awaitable[dict[str, Any]] | dict[str, Any]] def success_response( tool_name: str, context: TenantContext, data: Any, latency_ms: int, metadata: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: return { "status": "ok", "tool": tool_name, "tenant_id": context.tenant_id, "latency_ms": latency_ms, "metadata": metadata or {}, "data": data, } def error_response( tool_name: str, context: Optional[TenantContext], error: Exception, latency_ms: int, error_type: str = "runtime_error", ) -> dict[str, Any]: return { "status": "error", "tool": tool_name, "tenant_id": context.tenant_id if context else None, "latency_ms": latency_ms, "error_type": error_type, "message": str(error), } async def maybe_await(result: Any) -> Any: if inspect.isawaitable(result): return await result return result def _truncate(value: Any, max_length: int = 200) -> Any: if isinstance(value, str) and len(value) > max_length: return value[: max_length - 3] + "..." return value def _trim_payload(payload: Payload) -> dict[str, Any]: trimmed: dict[str, Any] = {} for key, value in payload.items(): if key in {"content", "query"} and isinstance(value, str): trimmed[key] = _truncate(value) elif isinstance(value, (str, int, float, bool)) or value is None: trimmed[key] = value else: trimmed[key] = "" return trimmed async def execute_tool( tool_name: str, payload: Payload, handler: ToolHandler, ) -> dict[str, Any]: start = time.perf_counter() context: Optional[TenantContext] = None try: context = build_tenant_context(payload) result = await maybe_await(handler(context, payload)) latency_ms = int((time.perf_counter() - start) * 1000) log_tool_usage( tool_name, context.tenant_id, success=True, latency_ms=latency_ms, metadata={"payload": _trim_payload(payload)}, user_id=context.user_id, ) return success_response( tool_name, context, result, latency_ms, ) except (TenantValidationError, ToolValidationError) as exc: latency_ms = int((time.perf_counter() - start) * 1000) log_tool_usage( tool_name, context.tenant_id if context else None, success=False, latency_ms=latency_ms, error_message=str(exc), metadata={"payload": _trim_payload(payload)}, user_id=context.user_id if context else None, ) return error_response(tool_name, context, exc, latency_ms, "validation_error") except Exception as exc: # pragma: no cover - safety net latency_ms = int((time.perf_counter() - start) * 1000) log_tool_usage( tool_name, context.tenant_id if context else None, success=False, latency_ms=latency_ms, error_message=str(exc), metadata={"payload": _trim_payload(payload)}, user_id=context.user_id if context else None, ) return error_response(tool_name, context, exc, latency_ms) def tool_handler(tool_name: str): """ Decorator that wires tenant validation, analytics logging, and error handling. """ def decorator(func: ToolHandler): async def wrapper(payload: Payload) -> dict[str, Any]: return await execute_tool(tool_name, payload, func) wrapper.__name__ = func.__name__ wrapper.__doc__ = func.__doc__ return wrapper return decorator