nothingworry's picture
Update the backend
e44e5dd
raw
history blame
4.31 kB
from __future__ import annotations
import inspect
import time
from typing import Any, Awaitable, Callable, Mapping, Optional
from .logging import log_tool_usage
from .tenant import TenantContext, TenantValidationError, build_tenant_context
class ToolValidationError(ValueError):
    """Signal that the request payload sent to a tool is malformed or incomplete.

    Derives from :class:`ValueError`, so generic value-checking code still
    catches it; ``execute_tool`` reports it with ``error_type="validation_error"``.
    """
class ToolExecutionError(RuntimeError):
    """Signal an unexpected failure that happens while a tool is running.

    Derives from :class:`RuntimeError`; ``execute_tool``'s catch-all handler
    reports it with the default ``error_type="runtime_error"``.
    """
# Request data handed to every tool, typed as a read-only mapping of field name to value.
Payload = Mapping[str, Any]
# Contract for a tool implementation: it receives the tenant context plus the
# raw payload and returns a result dict, either directly or through an
# awaitable (execute_tool awaits the result only when it is awaitable).
ToolHandler = Callable[
    [TenantContext, Payload],
    Awaitable[dict[str, Any]] | dict[str, Any],
]
def success_response(
    tool_name: str,
    context: TenantContext,
    data: Any,
    latency_ms: int,
    metadata: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Build the envelope returned to the caller after a successful tool call.

    ``data`` is passed through untouched; a missing or empty ``metadata``
    becomes a fresh empty dict. Keys appear in a fixed order: status, tool,
    tenant_id, latency_ms, metadata, data.
    """
    envelope: dict[str, Any] = dict(
        status="ok",
        tool=tool_name,
        tenant_id=context.tenant_id,
        latency_ms=latency_ms,
        metadata=metadata if metadata else {},
        data=data,
    )
    return envelope
def error_response(
    tool_name: str,
    context: Optional[TenantContext],
    error: Exception,
    latency_ms: int,
    error_type: str = "runtime_error",
) -> dict[str, Any]:
    """Build the envelope returned to the caller after a failed tool call.

    ``tenant_id`` is ``None`` when no (truthy) tenant context is available,
    e.g. because building the context itself failed. ``message`` is
    ``str(error)``.
    """
    tenant_id = None
    if context:
        tenant_id = context.tenant_id
    envelope: dict[str, Any] = dict(
        status="error",
        tool=tool_name,
        tenant_id=tenant_id,
        latency_ms=latency_ms,
        error_type=error_type,
        message=str(error),
    )
    return envelope
async def maybe_await(result: Any) -> Any:
    """Resolve ``result`` if it is awaitable, otherwise hand it back untouched.

    This lets tool handlers be written either as plain functions or as coroutines.
    """
    if not inspect.isawaitable(result):
        return result
    return await result
def _truncate(value: Any, max_length: int = 200) -> Any:
if isinstance(value, str) and len(value) > max_length:
return value[: max_length - 3] + "..."
return value
def _trim_payload(payload: Payload) -> dict[str, Any]:
trimmed: dict[str, Any] = {}
for key, value in payload.items():
if key in {"content", "query"} and isinstance(value, str):
trimmed[key] = _truncate(value)
elif isinstance(value, (str, int, float, bool)) or value is None:
trimmed[key] = value
else:
trimmed[key] = "<complex>"
return trimmed
def _elapsed_ms(start: float) -> int:
    """Return whole milliseconds elapsed since ``start``, a ``time.perf_counter()`` reading."""
    return int((time.perf_counter() - start) * 1000)


def _payload_log_metadata(payload: Any) -> dict[str, Any]:
    """Build the analytics metadata describing ``payload``.

    Mappings are summarised with ``_trim_payload``. Anything else is logged as
    ``"<complex>"``, so describing a malformed payload can never raise.
    """
    # The payload comes straight from the caller and may not be a mapping. The
    # failure path in particular must not blow up while describing it.
    if isinstance(payload, Mapping):
        return {"payload": _trim_payload(payload)}
    return {"payload": "<complex>"}


def _log_failure(
    tool_name: str,
    context: Optional[TenantContext],
    payload: Any,
    error: Exception,
    latency_ms: int,
) -> None:
    """Record a failed tool call; tenant and user are None when no context was built."""
    log_tool_usage(
        tool_name,
        context.tenant_id if context else None,
        success=False,
        latency_ms=latency_ms,
        error_message=str(error),
        metadata=_payload_log_metadata(payload),
        user_id=context.user_id if context else None,
    )


async def execute_tool(
    tool_name: str,
    payload: Payload,
    handler: ToolHandler,
) -> dict[str, Any]:
    """Run ``handler`` for one tool call and wrap the outcome in a response envelope.

    Builds the tenant context from ``payload`` and invokes ``handler``, awaiting
    its result only if it is awaitable. The call is recorded via
    ``log_tool_usage`` and the outcome returned as ``success_response`` or
    ``error_response`` output.

    Args:
        tool_name: Name reported in the analytics log and in the response.
        payload: Caller-supplied request data, also used to build the context.
        handler: Tool implementation called as ``handler(context, payload)``.

    Returns:
        A success envelope, or an error envelope whose ``error_type`` depends
        on the exception:

        - ``"validation_error"`` for ``TenantValidationError`` or
          ``ToolValidationError``.
        - ``"runtime_error"`` for any other ``Exception``.

        ``latency_ms`` is whole milliseconds measured with ``time.perf_counter``.

    Exceptions that are not ``Exception`` subclasses (e.g.
    ``asyncio.CancelledError`` on Python 3.8+) are not caught. Neither is an
    exception raised by ``log_tool_usage`` while recording a failure. Both
    propagate to the caller.
    """
    start = time.perf_counter()
    # Stays None if build_tenant_context fails, so failure logs carry no tenant.
    context: Optional[TenantContext] = None
    try:
        context = build_tenant_context(payload)
        result = await maybe_await(handler(context, payload))
        latency_ms = _elapsed_ms(start)
        # NOTE(review): because this call sits inside the try, a failure while
        # logging a *successful* call is caught below and reported as a
        # runtime_error even though the handler succeeded. Confirm intended.
        log_tool_usage(
            tool_name,
            context.tenant_id,
            success=True,
            latency_ms=latency_ms,
            metadata=_payload_log_metadata(payload),
            user_id=context.user_id,
        )
        return success_response(tool_name, context, result, latency_ms)
    except (TenantValidationError, ToolValidationError) as exc:
        latency_ms = _elapsed_ms(start)
        _log_failure(tool_name, context, payload, exc, latency_ms)
        return error_response(tool_name, context, exc, latency_ms, "validation_error")
    except Exception as exc:  # pragma: no cover - safety net
        latency_ms = _elapsed_ms(start)
        _log_failure(tool_name, context, payload, exc, latency_ms)
        return error_response(tool_name, context, exc, latency_ms)
def tool_handler(tool_name: str):
    """
    Decorator that wires tenant validation, analytics logging, and error handling.

    The decorated function is called as ``func(context, payload)``. The
    returned coroutine function takes only ``payload`` and delegates to
    ``execute_tool`` under ``tool_name``. It always produces a response
    envelope rather than raising for ordinary ``Exception``s.
    """
    def decorator(func: ToolHandler):
        async def wrapper(payload: Payload) -> dict[str, Any]:
            return await execute_tool(tool_name, payload, func)

        # Copy identifying metadata so the wrapper reports the handler's name,
        # qualified name, module and docstring. Attributes the handler lacks
        # are skipped instead of raising (e.g. functools.partial objects and
        # callable instances have no __name__).
        # __wrapped__ is deliberately NOT set: the wrapper's call signature
        # (payload only) differs from the handler's (context, payload), and
        # inspect.signature() follows __wrapped__ by default.
        for attr in ("__module__", "__name__", "__qualname__", "__doc__"):
            if hasattr(func, attr):
                setattr(wrapper, attr, getattr(func, attr))
        return wrapper

    return decorator