Spaces:
Sleeping
Sleeping
File size: 6,166 Bytes
e44e5dd b13e570 b65ef75 e44e5dd b65ef75 e44e5dd b13e570 e44e5dd b13e570 e44e5dd b65ef75 b13e570 e44e5dd b13e570 e44e5dd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
from __future__ import annotations

import inspect
import time
from typing import Any, Awaitable, Callable, Mapping, Optional, Union

from .logging import log_tool_usage
from .tenant import TenantContext, TenantValidationError, build_tenant_context
from . import memory
from . import access_control
class ToolValidationError(ValueError):
    """Signal that the request payload supplied by the caller is not valid."""
class ToolExecutionError(RuntimeError):
    """Signal an unexpected failure while a tool was running."""
class AuthorizationError(ToolValidationError):
    """Signal that the caller lacks the permissions the request requires.

    Being a ``ToolValidationError`` subclass, it is handled by any
    ``except ToolValidationError`` clause.
    """
# A tool request: a read-only mapping from argument name to value.
Payload = Mapping[str, Any]

# A tool implementation: called with the tenant context and the payload, it
# returns a result dict either directly or through an awaitable.
#
# Union[...] is used instead of the ``X | Y`` operator because this alias is an
# ordinary expression evaluated at import time; ``from __future__ import
# annotations`` only defers annotations, so ``|`` between these generic
# aliases raises TypeError on interpreters older than Python 3.10.
ToolHandler = Callable[
    [TenantContext, Payload],
    Union[Awaitable[dict[str, Any]], dict[str, Any]],
]
def success_response(
    tool_name: str,
    context: TenantContext,
    data: Any,
    latency_ms: int,
    metadata: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Assemble the response envelope for a tool call that succeeded.

    Args:
        tool_name: Name reported under the ``"tool"`` key.
        context: Tenant context; its ``tenant_id`` is copied into the envelope.
        data: The tool's result, stored unchanged under ``"data"``.
        latency_ms: Elapsed time to report, in milliseconds.
        metadata: Optional extra information; a missing or empty value is
            reported as an empty dict.

    Returns:
        A dict with the keys ``status`` (always ``"ok"``), ``tool``,
        ``tenant_id``, ``latency_ms``, ``metadata`` and ``data``.
    """
    envelope: dict[str, Any] = {"status": "ok", "tool": tool_name}
    envelope["tenant_id"] = context.tenant_id
    envelope["latency_ms"] = latency_ms
    envelope["metadata"] = metadata if metadata else {}
    envelope["data"] = data
    return envelope
def error_response(
    tool_name: str,
    context: Optional[TenantContext],
    error: Exception,
    latency_ms: int,
    error_type: str = "runtime_error",
) -> dict[str, Any]:
    """Assemble the response envelope for a tool call that failed.

    Args:
        tool_name: Name reported under the ``"tool"`` key.
        context: Tenant context if one was established before the failure,
            otherwise ``None``.
        error: The exception being reported; ``str(error)`` becomes the
            ``"message"`` value.
        latency_ms: Elapsed time to report, in milliseconds.
        error_type: Category label for the failure.

    Returns:
        A dict with the keys ``status`` (always ``"error"``), ``tool``,
        ``tenant_id``, ``latency_ms``, ``error_type`` and ``message``.
    """
    # The failure may have happened before a tenant context existed.
    tenant_id = None
    if context:
        tenant_id = context.tenant_id

    return {
        "status": "error",
        "tool": tool_name,
        "tenant_id": tenant_id,
        "latency_ms": latency_ms,
        "error_type": error_type,
        "message": str(error),
    }
async def maybe_await(result: Any) -> Any:
    """Resolve ``result`` whether or not it is awaitable.

    Lets callers treat synchronous and asynchronous handlers uniformly: an
    awaitable is awaited and its outcome returned, anything else is returned
    as-is.
    """
    if not inspect.isawaitable(result):
        return result
    return await result
def _truncate(value: Any, max_length: int = 200) -> Any:
if isinstance(value, str) and len(value) > max_length:
return value[: max_length - 3] + "..."
return value
def _trim_payload(payload: Payload) -> dict[str, Any]:
trimmed: dict[str, Any] = {}
for key, value in payload.items():
if key in {"content", "query"} and isinstance(value, str):
trimmed[key] = _truncate(value)
elif isinstance(value, (str, int, float, bool)) or value is None:
trimmed[key] = value
else:
trimmed[key] = "<complex>"
return trimmed
def _elapsed_ms(start: float) -> int:
    """Return whole milliseconds elapsed since ``start`` (a perf_counter value)."""
    return int((time.perf_counter() - start) * 1000)


def _log_failure(
    tool_name: str,
    context: Optional[TenantContext],
    payload: Payload,
    error: Exception,
    latency_ms: int,
) -> None:
    """Record a failed tool invocation through ``log_tool_usage``."""
    log_tool_usage(
        tool_name,
        # The failure may predate tenant validation, so context can be None.
        context.tenant_id if context else None,
        success=False,
        latency_ms=latency_ms,
        error_message=str(error),
        metadata={"payload": _trim_payload(payload)},
        user_id=context.user_id if context else None,
    )


async def execute_tool(
    tool_name: str,
    payload: Payload,
    handler: ToolHandler,
) -> dict[str, Any]:
    """Run ``handler`` for ``tool_name`` with validation, memory and logging.

    Steps, in order:

    1. Look up the session id and, unless the request ends the session
       (``end_session`` / ``endSession`` is ``True``), hand the handler a copy
       of the payload with the session's recent entries under ``"memory"``.
    2. Build the tenant context from the *original* payload.
    3. If the tool maps to a required action, check the caller's role and
       raise ``AuthorizationError`` when it is not allowed.
    4. Call the handler (sync or async) and measure latency.
    5. Store the result in session memory, or clear the session when it is
       ending.
    6. Log the invocation and return the response envelope.

    Args:
        tool_name: Name of the tool, used for permission lookup, logging and
            the response envelope.
        payload: Request payload. It is never mutated.
        handler: The tool implementation, called as
            ``handler(context, payload_with_memory)``.

    Returns:
        ``success_response(...)`` on success. On failure an
        ``error_response(...)`` whose ``error_type`` is ``"validation_error"``
        for ``TenantValidationError`` / ``ToolValidationError`` (which
        includes ``AuthorizationError``) and ``"runtime_error"`` for any other
        ``Exception``. Failures are logged before the envelope is returned.
    """
    start = time.perf_counter()
    context: Optional[TenantContext] = None

    try:
        # --- Short-term conversation memory (per session, not per tenant) ---
        # This preparation lives inside the try block so that a failure in the
        # memory layer is logged and reported as an error envelope like any
        # other failure, instead of escaping the safety net below.
        # NOTE(review): memory is read before the tenant is validated and is
        # keyed by session only - confirm a session id cannot be reused
        # across tenants.
        session_id = memory.extract_session_id(payload)
        end_session_flag = isinstance(payload, Mapping) and (
            payload.get("end_session") is True
            or payload.get("endSession") is True
        )

        # The handler receives a copy so the caller's payload stays untouched;
        # handlers are free to ignore the injected "memory" field.
        handler_payload: Payload = payload
        if session_id and not end_session_flag:
            recent_memory = memory.get_recent(session_id)
            enriched = dict(payload)
            enriched["memory"] = recent_memory
            handler_payload = enriched
        # --------------------------------------------------------------------

        # Tenant context still comes from the original payload
        context = build_tenant_context(payload)

        # Enforce role-based permissions for sensitive tool actions
        required_action = access_control.get_required_action_for_tool(tool_name)
        if required_action and not access_control.role_allows(context.role, required_action):
            allowed_roles = access_control.describe_allowed_roles(required_action)
            raise AuthorizationError(
                f"Role '{context.role}' is not permitted to perform '{required_action}'. "
                f"Allowed roles: {allowed_roles}."
            )

        result = await maybe_await(handler(context, handler_payload))
        # Latency is taken before the memory/log bookkeeping below.
        latency_ms = _elapsed_ms(start)

        # Store tool output in short-term memory unless the session is ending
        if session_id:
            if end_session_flag:
                memory.clear_session(session_id)
            else:
                memory.add_entry(session_id, tool_name, result)

        log_tool_usage(
            tool_name,
            context.tenant_id,
            success=True,
            latency_ms=latency_ms,
            metadata={"payload": _trim_payload(payload)},
            user_id=context.user_id,
        )
        return success_response(tool_name, context, result, latency_ms)
    except (TenantValidationError, ToolValidationError) as exc:
        latency_ms = _elapsed_ms(start)
        _log_failure(tool_name, context, payload, exc, latency_ms)
        return error_response(tool_name, context, exc, latency_ms, "validation_error")
    except Exception as exc:  # pragma: no cover - safety net
        latency_ms = _elapsed_ms(start)
        _log_failure(tool_name, context, payload, exc, latency_ms)
        return error_response(tool_name, context, exc, latency_ms)
def tool_handler(tool_name: str):
    """Decorator factory that routes a tool function through ``execute_tool``.

    The decorated function is replaced by an async callable taking only the
    payload; each call is forwarded to ``execute_tool`` together with
    ``tool_name`` and the original function, which supplies tenant
    validation, analytics logging, and error handling.
    """

    def decorator(func: ToolHandler):
        async def wrapper(payload: Payload) -> dict[str, Any]:
            response = await execute_tool(tool_name, payload, func)
            return response

        # Only the name and docstring are carried over; no ``__wrapped__``
        # attribute is set, so introspection sees the wrapper's own
        # ``(payload)`` signature.
        for attribute in ("__name__", "__doc__"):
            setattr(wrapper, attribute, getattr(func, attribute))
        return wrapper

    return decorator
|