Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import inspect | |
| import time | |
| from typing import Any, Awaitable, Callable, Mapping, Optional | |
| from .logging import log_tool_usage | |
| from .tenant import TenantContext, TenantValidationError, build_tenant_context | |
| from . import memory | |
| from . import access_control | |
class ToolValidationError(ValueError):
    """Signal that the request payload supplied by the caller is invalid.

    ``execute_tool`` catches this error (and its subclasses) and reports it
    to the caller as a ``"validation_error"`` response.
    """
class ToolExecutionError(RuntimeError):
    """Signal an unexpected failure while a tool is running."""
class AuthorizationError(ToolValidationError):
    """Signal that the caller lacks the permissions the request requires.

    Because this derives from ``ToolValidationError``, ``execute_tool``
    reports it to the caller as a ``"validation_error"`` response.
    """
# Request payload handed to every tool: a read-only string-keyed mapping.
Payload = Mapping[str, Any]

# A tool implementation. It receives the tenant context and the payload and
# returns the result dict either directly or as an awaitable resolving to it.
ToolHandler = Callable[
    [TenantContext, Payload],
    Awaitable[dict[str, Any]] | dict[str, Any],
]
def success_response(
    tool_name: str,
    context: TenantContext,
    data: Any,
    latency_ms: int,
    metadata: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Assemble the response envelope for a tool call that succeeded.

    Args:
        tool_name: Name of the tool that was executed.
        context: Tenant context of the call; only ``tenant_id`` is read.
        data: The tool's result, passed through unchanged.
        latency_ms: Elapsed time of the call in milliseconds.
        metadata: Optional extra information about the call.

    Returns:
        A dict with the keys ``status`` (always ``"ok"``), ``tool``,
        ``tenant_id``, ``latency_ms``, ``metadata`` and ``data``.
    """
    envelope: dict[str, Any] = {"status": "ok", "tool": tool_name}
    envelope["tenant_id"] = context.tenant_id
    envelope["latency_ms"] = latency_ms
    # A missing or empty metadata argument is reported as a fresh empty dict.
    envelope["metadata"] = metadata if metadata else {}
    envelope["data"] = data
    return envelope
def error_response(
    tool_name: str,
    context: Optional[TenantContext],
    error: Exception,
    latency_ms: int,
    error_type: str = "runtime_error",
) -> dict[str, Any]:
    """Assemble the response envelope for a tool call that failed.

    Args:
        tool_name: Name of the tool that was executed.
        context: Tenant context of the call, or ``None`` when the failure
            happened before a context could be built.
        error: The exception that caused the failure; its ``str()`` becomes
            the response message.
        latency_ms: Elapsed time of the call in milliseconds.
        error_type: Category label for the failure.

    Returns:
        A dict with the keys ``status`` (always ``"error"``), ``tool``,
        ``tenant_id``, ``latency_ms``, ``error_type`` and ``message``.
    """
    tenant_id = None
    if context:
        tenant_id = context.tenant_id
    return dict(
        status="error",
        tool=tool_name,
        tenant_id=tenant_id,
        latency_ms=latency_ms,
        error_type=error_type,
        message=str(error),
    )
async def maybe_await(result: Any) -> Any:
    """Resolve ``result`` whether or not it is awaitable.

    Lets callers treat synchronous and asynchronous handlers uniformly:
    awaitables are awaited and their outcome returned, anything else is
    returned as-is.
    """
    if not inspect.isawaitable(result):
        return result
    return await result
| def _truncate(value: Any, max_length: int = 200) -> Any: | |
| if isinstance(value, str) and len(value) > max_length: | |
| return value[: max_length - 3] + "..." | |
| return value | |
| def _trim_payload(payload: Payload) -> dict[str, Any]: | |
| trimmed: dict[str, Any] = {} | |
| for key, value in payload.items(): | |
| if key in {"content", "query"} and isinstance(value, str): | |
| trimmed[key] = _truncate(value) | |
| elif isinstance(value, (str, int, float, bool)) or value is None: | |
| trimmed[key] = value | |
| else: | |
| trimmed[key] = "<complex>" | |
| return trimmed | |
async def execute_tool(
    tool_name: str,
    payload: Payload,
    handler: ToolHandler,
) -> dict[str, Any]:
    """Run ``handler`` for ``tool_name`` and wrap the outcome in a response dict.

    Everything below happens inside a single ``try`` so that a failure in any
    step, including the session-memory lookup, is logged and returned as an
    error response instead of propagating to the caller:

    1. Read the session id with ``memory.extract_session_id``. Unless the
       payload sets ``end_session`` or ``endSession`` to ``True``, the handler
       receives a copy of the payload with an added ``"memory"`` key holding
       ``memory.get_recent(session_id)``.
    2. Build the tenant context from the original payload.
    3. When ``access_control`` reports a required action for the tool, reject
       roles that are not allowed with ``AuthorizationError``.
    4. Call the handler (sync or async). Afterwards store the result in the
       session memory, or clear the session when it is ending.
    5. Log the call and return ``success_response(...)``.

    Args:
        tool_name: Name used for permission lookup, logging and the response.
        payload: The caller's request payload.
        handler: The tool implementation, called as
            ``handler(context, payload)``.

    Returns:
        The dict built by ``success_response`` on success. On failure, the
        dict built by ``error_response``: with ``error_type``
        ``"validation_error"`` for ``TenantValidationError`` and
        ``ToolValidationError`` (which includes ``AuthorizationError``), and
        with the default error type for any other exception.
    """
    start = time.perf_counter()
    context: Optional[TenantContext] = None
    try:
        # --- Short-term conversation memory (per session, not per tenant) ---
        session_id = memory.extract_session_id(payload)
        # Only the literal ``True`` ends a session.
        end_session_flag = isinstance(payload, Mapping) and (
            payload.get("end_session") is True
            or payload.get("endSession") is True
        )

        handler_payload: Payload = payload
        if session_id and not end_session_flag:
            # Inject memory into a copy so the caller's payload is never
            # mutated; handlers may ignore the extra field.
            recent_memory = memory.get_recent(session_id)
            enriched = dict(payload)
            enriched["memory"] = recent_memory
            handler_payload = enriched
        # ----------------------------------------------------------------------

        # Tenant context still comes from the original payload
        context = build_tenant_context(payload)

        # Enforce role-based permissions for sensitive tool actions
        required_action = access_control.get_required_action_for_tool(tool_name)
        if required_action and not access_control.role_allows(context.role, required_action):
            allowed_roles = access_control.describe_allowed_roles(required_action)
            raise AuthorizationError(
                f"Role '{context.role}' is not permitted to perform '{required_action}'. "
                f"Allowed roles: {allowed_roles}."
            )

        result = await maybe_await(handler(context, handler_payload))
        latency_ms = int((time.perf_counter() - start) * 1000)

        # Store tool output in short-term memory unless the session is ending
        if session_id:
            if end_session_flag:
                memory.clear_session(session_id)
            else:
                memory.add_entry(session_id, tool_name, result)

        log_tool_usage(
            tool_name,
            context.tenant_id,
            success=True,
            latency_ms=latency_ms,
            metadata={"payload": _trim_payload(payload)},
            user_id=context.user_id,
        )
        return success_response(
            tool_name,
            context,
            result,
            latency_ms,
        )
    except (TenantValidationError, ToolValidationError) as exc:
        latency_ms = _log_tool_failure(tool_name, context, payload, exc, start)
        return error_response(tool_name, context, exc, latency_ms, "validation_error")
    except Exception as exc:  # pragma: no cover - safety net
        latency_ms = _log_tool_failure(tool_name, context, payload, exc, start)
        return error_response(tool_name, context, exc, latency_ms)


def _log_tool_failure(
    tool_name: str,
    context: Optional[TenantContext],
    payload: Any,
    exc: Exception,
    start: float,
) -> int:
    """Log a failed tool call and return its latency in milliseconds.

    ``context`` may be ``None`` when the failure happened before the tenant
    context was built; tenant and user ids are then logged as ``None``.
    """
    latency_ms = int((time.perf_counter() - start) * 1000)
    # The payload may be the very thing that failed validation, so only
    # summarise it when it is actually a mapping.
    loggable_payload = _trim_payload(payload) if isinstance(payload, Mapping) else {}
    log_tool_usage(
        tool_name,
        context.tenant_id if context else None,
        success=False,
        latency_ms=latency_ms,
        error_message=str(exc),
        metadata={"payload": loggable_payload},
        user_id=context.user_id if context else None,
    )
    return latency_ms
def tool_handler(tool_name: str):
    """
    Decorator factory that routes a tool function through ``execute_tool``.

    The decorated function is replaced by an async callable that takes only
    the request payload and returns the response dict produced by
    ``execute_tool(tool_name, payload, func)``.
    """

    def decorator(func: ToolHandler):
        async def wrapper(payload: Payload) -> dict[str, Any]:
            response = await execute_tool(tool_name, payload, func)
            return response

        # Carry over only the handler's name and docstring; no ``__wrapped__``
        # attribute is set on the wrapper.
        for attribute in ("__name__", "__doc__"):
            setattr(wrapper, attribute, getattr(func, attribute))
        return wrapper

    return decorator