File size: 6,166 Bytes
e44e5dd
 
 
 
 
 
 
b13e570
b65ef75
e44e5dd
 
 
 
 
 
 
 
 
 
b65ef75
 
 
 
e44e5dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b13e570
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e44e5dd
b13e570
e44e5dd
b65ef75
 
 
 
 
 
 
 
 
 
b13e570
e44e5dd
 
b13e570
 
 
 
 
 
e44e5dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
from __future__ import annotations
import inspect
import time
from typing import Any, Awaitable, Callable, Mapping, Optional

from .logging import log_tool_usage
from .tenant import TenantContext, TenantValidationError, build_tenant_context
from . import memory
from . import access_control


class ToolValidationError(ValueError):
    """Signals that the request payload supplied by the caller is invalid.

    ``execute_tool`` catches this type (and its subclasses) and reports it
    with ``error_type="validation_error"`` instead of the generic runtime
    error type.
    """


class ToolExecutionError(RuntimeError):
    """Signals an unexpected failure at runtime, as opposed to invalid input."""


class AuthorizationError(ToolValidationError):
    """Signals that the caller's role is not permitted to perform an action.

    Because this derives from ``ToolValidationError``, any handler that
    catches validation errors also catches authorization failures.
    """


# Request payload handed to every tool: a mapping from string keys to
# arbitrary values (typed as the read-only ``Mapping`` interface).
Payload = Mapping[str, Any]
# A tool implementation: called with ``(TenantContext, Payload)`` and allowed to
# return the result dict directly or an awaitable that resolves to it.
# NOTE: this is a runtime expression, not an annotation, so the ``|`` between
# the two generic aliases is evaluated at import time regardless of
# ``from __future__ import annotations`` (requires Python 3.10+).
ToolHandler = Callable[[TenantContext, Payload], Awaitable[dict[str, Any]] | dict[str, Any]]


def success_response(
    tool_name: str,
    context: TenantContext,
    data: Any,
    latency_ms: int,
    metadata: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Build the response envelope for a tool call that completed normally.

    Args:
        tool_name: Name of the tool that was executed.
        context: Tenant context of the call; only ``tenant_id`` is read.
        data: The tool's result, passed through untouched under ``"data"``.
        latency_ms: Elapsed time of the call in milliseconds.
        metadata: Optional extra information; a falsy value becomes ``{}``.

    Returns:
        A dict with the keys ``status`` (always ``"ok"``), ``tool``,
        ``tenant_id``, ``latency_ms``, ``metadata`` and ``data``.
    """
    envelope: dict[str, Any] = {"status": "ok", "tool": tool_name}
    envelope["tenant_id"] = context.tenant_id
    envelope["latency_ms"] = latency_ms
    # A missing or empty metadata dict is normalised to a fresh empty dict.
    envelope["metadata"] = metadata if metadata else {}
    envelope["data"] = data
    return envelope


def error_response(
    tool_name: str,
    context: Optional[TenantContext],
    error: Exception,
    latency_ms: int,
    error_type: str = "runtime_error",
) -> dict[str, Any]:
    """Build the response envelope for a tool call that failed.

    Args:
        tool_name: Name of the tool that was executed.
        context: Tenant context, or a falsy value when none was established;
            in that case ``tenant_id`` is reported as ``None``.
        error: The exception that ended the call; ``str(error)`` becomes the
            ``message`` field.
        latency_ms: Elapsed time of the call in milliseconds.
        error_type: Category label for the failure.

    Returns:
        A dict with the keys ``status`` (always ``"error"``), ``tool``,
        ``tenant_id``, ``latency_ms``, ``error_type`` and ``message``.
    """
    if context:
        tenant_id = context.tenant_id
    else:
        tenant_id = None

    envelope: dict[str, Any] = {"status": "error", "tool": tool_name}
    envelope["tenant_id"] = tenant_id
    envelope["latency_ms"] = latency_ms
    envelope["error_type"] = error_type
    envelope["message"] = str(error)
    return envelope


async def maybe_await(result: Any) -> Any:
    """Resolve ``result`` if it is awaitable; otherwise hand it back as-is.

    Lets callers treat synchronous and asynchronous handlers uniformly.
    """
    if not inspect.isawaitable(result):
        return result
    resolved = await result
    return resolved


def _truncate(value: Any, max_length: int = 200) -> Any:
    if isinstance(value, str) and len(value) > max_length:
        return value[: max_length - 3] + "..."
    return value


def _trim_payload(payload: Payload) -> dict[str, Any]:
    trimmed: dict[str, Any] = {}
    for key, value in payload.items():
        if key in {"content", "query"} and isinstance(value, str):
            trimmed[key] = _truncate(value)
        elif isinstance(value, (str, int, float, bool)) or value is None:
            trimmed[key] = value
        else:
            trimmed[key] = "<complex>"
    return trimmed


def _loggable_payload(payload: Any) -> dict[str, Any]:
    """Return the trimmed payload for logging, or ``{}`` if it is not a mapping."""
    # ``_trim_payload`` calls ``.items()``; guarding here keeps the logging path
    # from raising a second exception while a first one is being reported.
    return _trim_payload(payload) if isinstance(payload, Mapping) else {}


def _failure_response(
    tool_name: str,
    payload: Payload,
    context: Optional[TenantContext],
    error: Exception,
    start: float,
    error_type: str = "runtime_error",
) -> dict[str, Any]:
    """Log a failed tool call and build the matching error envelope."""
    latency_ms = int((time.perf_counter() - start) * 1000)
    log_tool_usage(
        tool_name,
        context.tenant_id if context else None,
        success=False,
        latency_ms=latency_ms,
        error_message=str(error),
        metadata={"payload": _loggable_payload(payload)},
        user_id=context.user_id if context else None,
    )
    return error_response(tool_name, context, error, latency_ms, error_type)


async def execute_tool(
    tool_name: str,
    payload: Payload,
    handler: ToolHandler,
) -> dict[str, Any]:
    """Run ``handler`` for one tool call and wrap the outcome in an envelope.

    Steps, in order:
      1. Extract the session id and, unless the payload sets ``end_session``
         or ``endSession`` to ``True``, pass the handler a copy of the payload
         with the session's recent memory added under ``"memory"``.
      2. Build the tenant context from the ORIGINAL payload.
      3. If the tool maps to a required action, check that the context's role
         allows it; otherwise raise ``AuthorizationError``.
      4. Call the handler (awaiting the result if it is awaitable).
      5. Record the result in session memory, or clear the session when the
         end-session flag is set.
      6. Log the call and return the success envelope.

    Args:
        tool_name: Name used for access control, memory entries, logging and
            the response envelope.
        payload: Caller-supplied request payload.
        handler: Sync or async callable invoked as ``handler(context, payload)``.

    Returns:
        The dict produced by ``success_response`` on success, or by
        ``error_response`` when an ``Exception`` is raised along the way.
        ``TenantValidationError`` and ``ToolValidationError`` (including
        ``AuthorizationError``) are reported as ``"validation_error"``; any
        other ``Exception`` uses the default error type.
    """
    start = time.perf_counter()
    context: Optional[TenantContext] = None

    try:
        # --- Short-term conversation memory (per session, not per tenant) ---
        # Kept inside the ``try`` so that a failure while reading memory or
        # copying the payload is logged and returned as an error envelope
        # like every other failure, instead of escaping to the caller.
        session_id = memory.extract_session_id(payload)
        end_session_flag = isinstance(payload, Mapping) and (
            payload.get("end_session") is True
            or payload.get("endSession") is True
        )

        # The handler gets a copy with memory injected; the caller's payload
        # object is never mutated. Handlers are free to ignore the field.
        handler_payload: Payload = payload
        if session_id and not end_session_flag:
            recent_memory = memory.get_recent(session_id)
            enriched = dict(payload)
            enriched["memory"] = recent_memory
            handler_payload = enriched
        # --------------------------------------------------------------------

        # Tenant context still comes from the original payload
        context = build_tenant_context(payload)

        # Enforce role-based permissions for sensitive tool actions
        required_action = access_control.get_required_action_for_tool(tool_name)
        if required_action and not access_control.role_allows(context.role, required_action):
            allowed_roles = access_control.describe_allowed_roles(required_action)
            raise AuthorizationError(
                f"Role '{context.role}' is not permitted to perform '{required_action}'. "
                f"Allowed roles: {allowed_roles}."
            )

        result = await maybe_await(handler(context, handler_payload))
        # Latency covers everything up to the handler's result; the memory
        # bookkeeping and logging below are not included.
        latency_ms = int((time.perf_counter() - start) * 1000)

        # Store tool output in short-term memory unless the session is ending
        if session_id and not end_session_flag:
            memory.add_entry(session_id, tool_name, result)
        elif session_id and end_session_flag:
            memory.clear_session(session_id)

        log_tool_usage(
            tool_name,
            context.tenant_id,
            success=True,
            latency_ms=latency_ms,
            metadata={"payload": _loggable_payload(payload)},
            user_id=context.user_id,
        )
        return success_response(
            tool_name,
            context,
            result,
            latency_ms,
        )
    except (TenantValidationError, ToolValidationError) as exc:
        return _failure_response(
            tool_name, payload, context, exc, start, "validation_error"
        )
    except Exception as exc:  # pragma: no cover - safety net
        return _failure_response(tool_name, payload, context, exc, start)


def tool_handler(tool_name: str):
    """Decorator factory that routes a handler through ``execute_tool``.

    The decorated object is an async function taking only ``payload``; it
    awaits ``execute_tool(tool_name, payload, func)`` and returns its result.
    Only ``__name__`` and ``__doc__`` are copied from the original handler.

    NOTE: ``functools.wraps`` would additionally set ``__wrapped__``, which
    makes ``inspect.signature`` report the handler's ``(context, payload)``
    signature rather than the wrapper's ``(payload)`` one.
    """

    def decorator(func: ToolHandler):
        async def wrapper(payload: Payload) -> dict[str, Any]:
            outcome = await execute_tool(tool_name, payload, func)
            return outcome

        # Carry over the handler's identity for introspection and docs.
        for attr in ("__name__", "__doc__"):
            setattr(wrapper, attr, getattr(func, attr))
        return wrapper

    return decorator