# Page header captured with this file (not code): author "nothingworry",
# commit b65ef75 — "feat: add RBAC enforcement for MCP tools and API endpoints"
# (file size 6.17 kB).
from __future__ import annotations
import inspect
import time
from typing import Any, Awaitable, Callable, Mapping, Optional
from .logging import log_tool_usage
from .tenant import TenantContext, TenantValidationError, build_tenant_context
from . import memory
from . import access_control
class ToolValidationError(ValueError):
    """Raised when the caller's request payload is invalid.

    ``execute_tool`` converts this exception (and its subclasses) into an
    error response with ``error_type="validation_error"``.
    """


class ToolExecutionError(RuntimeError):
    """Raised for unexpected runtime failures while running a tool."""


class AuthorizationError(ToolValidationError):
    """Raised when the caller's role is not permitted to perform a tool's action.

    Because it subclasses ``ToolValidationError``, ``execute_tool`` reports it
    with ``error_type="validation_error"``.
    """
# Request body handed to every tool: a string-keyed mapping.
Payload = Mapping[str, Any]

# A tool handler receives (tenant context, payload) and returns its result dict
# either directly or as an awaitable; execute_tool awaits only when needed.
# NOTE: this alias is evaluated at import time, so the ``|`` union below needs
# Python 3.10+ even though the module uses ``from __future__ import annotations``.
ToolHandler = Callable[
    [TenantContext, Payload],
    Awaitable[dict[str, Any]] | dict[str, Any],
]
def success_response(
    tool_name: str,
    context: TenantContext,
    data: Any,
    latency_ms: int,
    metadata: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Build the envelope returned to the caller after a successful tool run.

    Args:
        tool_name: Name of the tool that ran.
        context: Tenant context; only its ``tenant_id`` is read.
        data: The handler's result, passed through unchanged.
        latency_ms: Elapsed time to report.
        metadata: Optional extra metadata; any falsy value is replaced by ``{}``.

    Returns:
        A dict with keys ``status`` (always ``"ok"``), ``tool``, ``tenant_id``,
        ``latency_ms``, ``metadata`` and ``data``, in that order.
    """
    envelope: dict[str, Any] = dict(
        status="ok",
        tool=tool_name,
        tenant_id=context.tenant_id,
        latency_ms=latency_ms,
        metadata=metadata if metadata else {},
        data=data,
    )
    return envelope
def error_response(
    tool_name: str,
    context: Optional[TenantContext],
    error: Exception,
    latency_ms: int,
    error_type: str = "runtime_error",
) -> dict[str, Any]:
    """Build the envelope returned to the caller when a tool run fails.

    Args:
        tool_name: Name of the tool that was invoked.
        context: Tenant context if one was built before the failure; when it is
            falsy (e.g. ``None``) the reported ``tenant_id`` is ``None``.
        error: The exception; its ``str()`` becomes the ``message`` field.
        latency_ms: Elapsed time to report.
        error_type: Category label for the failure.

    Returns:
        A dict with keys ``status`` (always ``"error"``), ``tool``,
        ``tenant_id``, ``latency_ms``, ``error_type`` and ``message``.
    """
    tenant_id = None
    if context:
        tenant_id = context.tenant_id
    return dict(
        status="error",
        tool=tool_name,
        tenant_id=tenant_id,
        latency_ms=latency_ms,
        error_type=error_type,
        message=str(error),
    )
async def maybe_await(result: Any) -> Any:
    """Return *result*, awaiting it first if it is awaitable.

    Lets tool handlers be either plain functions or coroutine functions.
    """
    if not inspect.isawaitable(result):
        return result
    return await result
def _truncate(value: Any, max_length: int = 200) -> Any:
if isinstance(value, str) and len(value) > max_length:
return value[: max_length - 3] + "..."
return value
def _trim_payload(payload: Payload) -> dict[str, Any]:
trimmed: dict[str, Any] = {}
for key, value in payload.items():
if key in {"content", "query"} and isinstance(value, str):
trimmed[key] = _truncate(value)
elif isinstance(value, (str, int, float, bool)) or value is None:
trimmed[key] = value
else:
trimmed[key] = "<complex>"
return trimmed
def _elapsed_ms(start: float) -> int:
    """Return whole milliseconds elapsed since *start* (a ``time.perf_counter()`` reading)."""
    return int((time.perf_counter() - start) * 1000)


def _log_failure(
    tool_name: str,
    context: Optional[TenantContext],
    payload: Payload,
    exc: Exception,
    latency_ms: int,
) -> None:
    """Record a failed tool call; tenant/user ids are ``None`` when no context was built."""
    log_tool_usage(
        tool_name,
        context.tenant_id if context else None,
        success=False,
        latency_ms=latency_ms,
        error_message=str(exc),
        metadata={"payload": _trim_payload(payload)},
        user_id=context.user_id if context else None,
    )


async def execute_tool(
    tool_name: str,
    payload: Payload,
    handler: ToolHandler,
) -> dict[str, Any]:
    """Run *handler* for *tool_name* inside the standard tool envelope.

    Steps:
      1. Build the tenant context from the caller's original *payload*.
      2. If the tool requires an action, check that the caller's role allows it
         and raise ``AuthorizationError`` otherwise.
      3. If the payload carries a session id and does not end the session, hand
         the handler a copy of the payload with recent session memory under the
         ``"memory"`` key (the handler may ignore it).
      4. Await the handler. Then store its result in session memory, or clear
         the session when ``end_session``/``endSession`` is ``True``.
      5. Log usage and wrap the outcome in a success or error response.

    All of this, including the memory lookups, runs inside the ``try`` block,
    so any failure is logged and returned as an error response rather than
    propagating to the caller.

    Args:
        tool_name: Name used for RBAC lookup, memory entries and logging.
        payload: The caller's request body.
        handler: Sync or async callable taking ``(context, payload)``.

    Returns:
        ``success_response(...)`` on success. ``error_response(...)`` with
        ``error_type="validation_error"`` for ``TenantValidationError`` or
        ``ToolValidationError`` (including ``AuthorizationError``), and with the
        default error type for any other ``Exception``.
    """
    start = time.perf_counter()
    context: Optional[TenantContext] = None
    try:
        # Tenant context always comes from the caller's original payload.
        context = build_tenant_context(payload)

        # Enforce role-based permissions before touching session memory or
        # running the handler.
        required_action = access_control.get_required_action_for_tool(tool_name)
        if required_action and not access_control.role_allows(context.role, required_action):
            allowed_roles = access_control.describe_allowed_roles(required_action)
            raise AuthorizationError(
                f"Role '{context.role}' is not permitted to perform '{required_action}'. "
                f"Allowed roles: {allowed_roles}."
            )

        # --- Short-term conversation memory (per session, not per tenant) ---
        session_id = memory.extract_session_id(payload)
        end_session = isinstance(payload, Mapping) and (
            payload.get("end_session") is True or payload.get("endSession") is True
        )
        handler_payload: Payload = payload
        if session_id and not end_session:
            # Work on a copy so the caller's payload is never mutated.
            augmented = dict(payload)
            augmented["memory"] = memory.get_recent(session_id)
            handler_payload = augmented

        result = await maybe_await(handler(context, handler_payload))
        latency_ms = _elapsed_ms(start)

        # Store tool output in short-term memory unless the session is ending.
        if session_id and not end_session:
            memory.add_entry(session_id, tool_name, result)
        elif session_id and end_session:
            memory.clear_session(session_id)

        log_tool_usage(
            tool_name,
            context.tenant_id,
            success=True,
            latency_ms=latency_ms,
            metadata={"payload": _trim_payload(payload)},
            user_id=context.user_id,
        )
        return success_response(tool_name, context, result, latency_ms)
    except (TenantValidationError, ToolValidationError) as exc:
        latency_ms = _elapsed_ms(start)
        _log_failure(tool_name, context, payload, exc, latency_ms)
        return error_response(tool_name, context, exc, latency_ms, "validation_error")
    except Exception as exc:  # pragma: no cover - safety net
        latency_ms = _elapsed_ms(start)
        _log_failure(tool_name, context, payload, exc, latency_ms)
        return error_response(tool_name, context, exc, latency_ms)
def tool_handler(tool_name: str):
    """Return a decorator that routes a tool function through ``execute_tool``.

    The decorated result is a coroutine function taking only ``payload``; it
    delegates to ``execute_tool(tool_name, payload, func)``, which builds the
    tenant context, applies permission checks, logs usage and wraps errors.
    """
    def decorator(func: ToolHandler):
        async def wrapper(payload: Payload) -> dict[str, Any]:
            return await execute_tool(tool_name, payload, func)

        # Copy only the name and docstring. functools.wraps would also set
        # __wrapped__, making inspect.signature report func's
        # (context, payload) signature instead of the wrapper's (payload).
        wrapper.__name__, wrapper.__doc__ = func.__name__, func.__doc__
        return wrapper

    return decorator