nothingworry commited on
Commit
b13e570
·
1 Parent(s): e3ebaba

feat: Add short-term conversation memory with TTL for MCP tools

Browse files
backend/mcp_server/common/memory.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import time
5
+ from dataclasses import dataclass
6
+ from typing import Any, Dict, List, Mapping, Optional
7
+
8
+ # Short-term memory configuration
9
+ # -------------------------------
10
+ # These environment variables let you tune behavior without code changes:
11
+ # - MCP_MEMORY_MAX_ITEMS: max number of tool outputs to keep per session (default: 10)
12
+ # - MCP_MEMORY_TTL_SECONDS: how long entries live before expiring (default: 900 = 15 minutes)
13
+
14
+ DEFAULT_MAX_ITEMS = int(os.getenv("MCP_MEMORY_MAX_ITEMS", "10"))
15
+ DEFAULT_TTL_SECONDS = int(os.getenv("MCP_MEMORY_TTL_SECONDS", "900"))
16
+
17
+
18
+ @dataclass
19
+ class MemoryEntry:
20
+ ts: float
21
+ tool_name: str
22
+ output: Any
23
+
24
+
25
+ # NOTE: For safety, this store is intentionally **not** keyed by tenant.
26
+ # It is keyed only by a logical session identifier (e.g. chat session ID).
27
+ _MEMORY: Dict[str, List[MemoryEntry]] = {}
28
+
29
+
30
+ def _now() -> float:
31
+ return time.time()
32
+
33
+
34
+ def extract_session_id(payload: Mapping[str, Any]) -> Optional[str]:
35
+ """
36
+ Extract a logical session identifier from the payload.
37
+
38
+ Supported keys (first match wins):
39
+ - \"session_id\"
40
+ - \"sessionId\"
41
+ - \"conversation_id\"
42
+ - \"conversationId\"
43
+
44
+ Returns:
45
+ Normalized session_id string or None if not present.
46
+ """
47
+ for key in ("session_id", "sessionId", "conversation_id", "conversationId"):
48
+ value = payload.get(key)
49
+ if isinstance(value, str):
50
+ value = value.strip()
51
+ if value:
52
+ return value
53
+ return None
54
+
55
+
56
+ def _prune_expired(entries: List[MemoryEntry], ttl_seconds: int) -> List[MemoryEntry]:
57
+ if not entries:
58
+ return entries
59
+ cutoff = _now() - ttl_seconds
60
+ return [e for e in entries if e.ts >= cutoff]
61
+
62
+
63
+ def add_entry(
64
+ session_id: str,
65
+ tool_name: str,
66
+ output: Any,
67
+ max_items: int = DEFAULT_MAX_ITEMS,
68
+ ttl_seconds: int = DEFAULT_TTL_SECONDS,
69
+ ) -> None:
70
+ """
71
+ Store a new tool output in short-term memory for this session.
72
+
73
+ - Keeps only the last `max_items` entries
74
+ - Drops entries older than `ttl_seconds`
75
+ """
76
+ if not session_id:
77
+ return
78
+
79
+ entries = _MEMORY.get(session_id, [])
80
+ entries = _prune_expired(entries, ttl_seconds)
81
+
82
+ entries.append(MemoryEntry(ts=_now(), tool_name=tool_name, output=output))
83
+
84
+ # Enforce bounded size: keep the most recent entries
85
+ if len(entries) > max_items:
86
+ entries = entries[-max_items:]
87
+
88
+ _MEMORY[session_id] = entries
89
+
90
+
91
+ def get_recent(
92
+ session_id: str,
93
+ limit: Optional[int] = None,
94
+ ttl_seconds: int = DEFAULT_TTL_SECONDS,
95
+ ) -> List[Dict[str, Any]]:
96
+ """
97
+ Return recent, non-expired entries for this session.
98
+
99
+ Each entry is a dict:
100
+ {\"tool\": str, \"timestamp\": float, \"output\": Any}
101
+ """
102
+ if not session_id:
103
+ return []
104
+
105
+ entries = _MEMORY.get(session_id, [])
106
+ entries = _prune_expired(entries, ttl_seconds)
107
+ _MEMORY[session_id] = entries # write back pruned list
108
+
109
+ if limit is not None and limit > 0:
110
+ entries = entries[-limit:]
111
+
112
+ return [
113
+ {
114
+ "tool": e.tool_name,
115
+ "timestamp": e.ts,
116
+ "output": e.output,
117
+ }
118
+ for e in entries
119
+ ]
120
+
121
+
122
+ def clear_session(session_id: str) -> None:
123
+ """
124
+ Explicitly clear all short-term memory for a session.
125
+ Useful when a chat session ends.
126
+ """
127
+ if session_id in _MEMORY:
128
+ del _MEMORY[session_id]
129
+
130
+
backend/mcp_server/common/utils.py CHANGED
@@ -5,6 +5,7 @@ from typing import Any, Awaitable, Callable, Mapping, Optional
5
 
6
  from .logging import log_tool_usage
7
  from .tenant import TenantContext, TenantValidationError, build_tenant_context
 
8
 
9
 
10
  class ToolValidationError(ValueError):
@@ -84,11 +85,40 @@ async def execute_tool(
84
  ) -> dict[str, Any]:
85
  start = time.perf_counter()
86
  context: Optional[TenantContext] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  try:
 
88
  context = build_tenant_context(payload)
89
- result = await maybe_await(handler(context, payload))
90
  latency_ms = int((time.perf_counter() - start) * 1000)
91
 
 
 
 
 
 
 
92
  log_tool_usage(
93
  tool_name,
94
  context.tenant_id,
 
5
 
6
  from .logging import log_tool_usage
7
  from .tenant import TenantContext, TenantValidationError, build_tenant_context
8
+ from . import memory
9
 
10
 
11
  class ToolValidationError(ValueError):
 
85
  ) -> dict[str, Any]:
86
  start = time.perf_counter()
87
  context: Optional[TenantContext] = None
88
+
89
+ # --- Short-term conversation memory (per session, not per tenant) ---
90
+ session_id = memory.extract_session_id(payload)
91
+ end_session_flag = bool(
92
+ isinstance(payload, Mapping)
93
+ and (
94
+ payload.get("end_session") is True
95
+ or payload.get("endSession") is True
96
+ )
97
+ )
98
+
99
+ # Work on a mutable copy when we want to inject memory
100
+ mutable_payload: Mapping[str, Any] = payload
101
+ if session_id and not end_session_flag:
102
+ recent_memory = memory.get_recent(session_id)
103
+ # Only inject memory for tools that want to use it
104
+ # (handler can choose to ignore this field)
105
+ tmp = dict(payload)
106
+ tmp["memory"] = recent_memory
107
+ mutable_payload = tmp
108
+ # --------------------------------------------------------------------
109
+
110
  try:
111
+ # Tenant context still comes from the original payload
112
  context = build_tenant_context(payload)
113
+ result = await maybe_await(handler(context, mutable_payload))
114
  latency_ms = int((time.perf_counter() - start) * 1000)
115
 
116
+ # Store tool output in short-term memory unless the session is ending
117
+ if session_id and not end_session_flag:
118
+ memory.add_entry(session_id, tool_name, result)
119
+ elif session_id and end_session_flag:
120
+ memory.clear_session(session_id)
121
+
122
  log_tool_usage(
123
  tool_name,
124
  context.tenant_id,
backend/tests/test_conversation_memory.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================
2
+ # File: backend/tests/test_conversation_memory.py
3
+ # =============================================================
4
+ """
5
+ Comprehensive tests for short-term conversation memory with expiration.
6
+
7
+ Tests:
8
+ 1. Memory storage and retrieval
9
+ 2. Memory injection into tool payloads
10
+ 3. Session isolation (different session_ids don't share memory)
11
+ 4. Memory expiration (TTL)
12
+ 5. Memory bounded size (only last N items)
13
+ 6. Session clearing (end_session flag)
14
+ 7. Memory is NOT keyed by tenant_id (same session_id across tenants shares memory)
15
+ """
16
+
17
+ import sys
18
+ from pathlib import Path
19
+ import pytest
20
+ import time
21
+ from unittest.mock import AsyncMock, MagicMock, patch
22
+ import asyncio
23
+
24
+ # Add backend directory to Python path
25
+ backend_dir = Path(__file__).parent.parent
26
+ sys.path.insert(0, str(backend_dir))
27
+
28
+ from mcp_server.common import memory
29
+ from mcp_server.common.utils import execute_tool, ToolHandler
30
+ from mcp_server.common.tenant import TenantContext
31
+
32
+
33
+ # =============================================================
34
+ # FIXTURES
35
+ # =============================================================
36
+
37
+ @pytest.fixture(autouse=True)
38
+ def clear_memory():
39
+ """Clear memory before and after each test."""
40
+ # Clear all memory before test
41
+ memory._MEMORY.clear()
42
+ yield
43
+ # Clear all memory after test
44
+ memory._MEMORY.clear()
45
+
46
+
47
+ @pytest.fixture
48
+ def mock_tool_handler():
49
+ """Create a mock tool handler that captures the payload."""
50
+ captured_payloads = []
51
+
52
+ async def handler(context: TenantContext, payload: dict) -> dict:
53
+ captured_payloads.append(payload)
54
+ return {"result": "success", "tool_output": "test_data"}
55
+
56
+ handler.captured = captured_payloads
57
+ return handler
58
+
59
+
60
+ # =============================================================
61
+ # UNIT TESTS: Memory Module
62
+ # =============================================================
63
+
64
+ def test_extract_session_id():
65
+ """Test session ID extraction from payload."""
66
+ # Test various key formats
67
+ assert memory.extract_session_id({"session_id": "s1"}) == "s1"
68
+ assert memory.extract_session_id({"sessionId": "s2"}) == "s2"
69
+ assert memory.extract_session_id({"conversation_id": "s3"}) == "s3"
70
+ assert memory.extract_session_id({"conversationId": "s4"}) == "s4"
71
+
72
+ # Test first match wins
73
+ assert memory.extract_session_id({
74
+ "session_id": "s1",
75
+ "sessionId": "s2"
76
+ }) == "s1"
77
+
78
+ # Test missing session ID
79
+ assert memory.extract_session_id({"tenant_id": "t1"}) is None
80
+ assert memory.extract_session_id({}) is None
81
+
82
+ # Test empty string
83
+ assert memory.extract_session_id({"session_id": ""}) is None
84
+ assert memory.extract_session_id({"session_id": " "}) is None
85
+
86
+
87
+ def test_add_and_get_entry():
88
+ """Test basic memory storage and retrieval."""
89
+ session_id = "test-session-1"
90
+
91
+ # Add entries
92
+ memory.add_entry(session_id, "tool1", {"output": "data1"}, max_items=10, ttl_seconds=900)
93
+ memory.add_entry(session_id, "tool2", {"output": "data2"}, max_items=10, ttl_seconds=900)
94
+ memory.add_entry(session_id, "tool3", {"output": "data3"}, max_items=10, ttl_seconds=900)
95
+
96
+ # Retrieve entries
97
+ entries = memory.get_recent(session_id, ttl_seconds=900)
98
+
99
+ assert len(entries) == 3
100
+ assert entries[0]["tool"] == "tool1"
101
+ assert entries[1]["tool"] == "tool2"
102
+ assert entries[2]["tool"] == "tool3"
103
+ assert entries[0]["output"] == {"output": "data1"}
104
+ assert "timestamp" in entries[0]
105
+
106
+
107
+ def test_memory_bounded_size():
108
+ """Test that memory only keeps last N items."""
109
+ session_id = "test-session-2"
110
+ max_items = 3
111
+
112
+ # Add more items than max
113
+ for i in range(5):
114
+ memory.add_entry(session_id, f"tool{i}", {"data": i}, max_items=max_items, ttl_seconds=900)
115
+
116
+ entries = memory.get_recent(session_id, ttl_seconds=900)
117
+
118
+ # Should only have last 3 items
119
+ assert len(entries) == 3
120
+ assert entries[0]["tool"] == "tool2"
121
+ assert entries[1]["tool"] == "tool3"
122
+ assert entries[2]["tool"] == "tool4"
123
+
124
+
125
+ def test_memory_expiration():
126
+ """Test that expired entries are automatically removed."""
127
+ session_id = "test-session-3"
128
+ short_ttl = 1 # 1 second TTL
129
+
130
+ # Add entry
131
+ memory.add_entry(session_id, "tool1", {"data": "old"}, max_items=10, ttl_seconds=short_ttl)
132
+
133
+ # Should be present immediately
134
+ entries = memory.get_recent(session_id, ttl_seconds=short_ttl)
135
+ assert len(entries) == 1
136
+
137
+ # Wait for expiration
138
+ time.sleep(1.1)
139
+
140
+ # Should be expired now
141
+ entries = memory.get_recent(session_id, ttl_seconds=short_ttl)
142
+ assert len(entries) == 0
143
+
144
+
145
+ def test_session_isolation():
146
+ """Test that different session_ids don't share memory."""
147
+ session1 = "session-1"
148
+ session2 = "session-2"
149
+
150
+ memory.add_entry(session1, "tool1", {"data": "s1"}, max_items=10, ttl_seconds=900)
151
+ memory.add_entry(session2, "tool2", {"data": "s2"}, max_items=10, ttl_seconds=900)
152
+
153
+ entries1 = memory.get_recent(session1, ttl_seconds=900)
154
+ entries2 = memory.get_recent(session2, ttl_seconds=900)
155
+
156
+ assert len(entries1) == 1
157
+ assert len(entries2) == 1
158
+ assert entries1[0]["tool"] == "tool1"
159
+ assert entries2[0]["tool"] == "tool2"
160
+
161
+
162
+ def test_clear_session():
163
+ """Test that clear_session removes all memory for a session."""
164
+ session_id = "test-session-4"
165
+
166
+ memory.add_entry(session_id, "tool1", {"data": "d1"}, max_items=10, ttl_seconds=900)
167
+ memory.add_entry(session_id, "tool2", {"data": "d2"}, max_items=10, ttl_seconds=900)
168
+
169
+ assert len(memory.get_recent(session_id, ttl_seconds=900)) == 2
170
+
171
+ memory.clear_session(session_id)
172
+
173
+ assert len(memory.get_recent(session_id, ttl_seconds=900)) == 0
174
+
175
+
176
+ def test_memory_not_keyed_by_tenant():
177
+ """Test that memory is keyed by session_id, NOT tenant_id."""
178
+ session_id = "shared-session"
179
+ tenant1 = "tenant-a"
180
+ tenant2 = "tenant-b"
181
+
182
+ # Simulate: tenant1 calls tool, then tenant2 calls tool with same session_id
183
+ # They should see each other's tool outputs (because memory is session-based, not tenant-based)
184
+
185
+ # This is intentional for safety - memory is NOT per-tenant
186
+ # In a real scenario, you'd want to ensure session_ids are unique per tenant
187
+ # But the memory system itself doesn't enforce this
188
+
189
+ # Add entry from tenant1 perspective
190
+ memory.add_entry(session_id, "tool1", {"tenant": tenant1, "data": "from-tenant1"}, max_items=10, ttl_seconds=900)
191
+
192
+ # Add entry from tenant2 perspective (same session_id)
193
+ memory.add_entry(session_id, "tool2", {"tenant": tenant2, "data": "from-tenant2"}, max_items=10, ttl_seconds=900)
194
+
195
+ # Both should see both entries (because same session_id)
196
+ entries = memory.get_recent(session_id, ttl_seconds=900)
197
+ assert len(entries) == 2
198
+ assert entries[0]["output"]["tenant"] == tenant1
199
+ assert entries[1]["output"]["tenant"] == tenant2
200
+
201
+
202
+ def test_get_recent_with_limit():
203
+ """Test that get_recent respects the limit parameter."""
204
+ session_id = "test-session-5"
205
+
206
+ # Add 5 entries
207
+ for i in range(5):
208
+ memory.add_entry(session_id, f"tool{i}", {"data": i}, max_items=10, ttl_seconds=900)
209
+
210
+ # Get all
211
+ all_entries = memory.get_recent(session_id, limit=None, ttl_seconds=900)
212
+ assert len(all_entries) == 5
213
+
214
+ # Get last 2
215
+ recent_2 = memory.get_recent(session_id, limit=2, ttl_seconds=900)
216
+ assert len(recent_2) == 2
217
+ assert recent_2[0]["tool"] == "tool3"
218
+ assert recent_2[1]["tool"] == "tool4"
219
+
220
+
221
+ # =============================================================
222
+ # INTEGRATION TESTS: execute_tool with Memory
223
+ # =============================================================
224
+
225
+ @pytest.mark.asyncio
226
+ async def test_execute_tool_stores_memory(mock_tool_handler):
227
+ """Test that execute_tool stores tool output in memory."""
228
+ payload = {
229
+ "tenant_id": "test-tenant",
230
+ "session_id": "test-session-6",
231
+ "query": "test query"
232
+ }
233
+
234
+ result = await execute_tool("test.tool", payload, mock_tool_handler)
235
+
236
+ # Check that result is successful
237
+ assert result["status"] == "ok"
238
+
239
+ # Check that memory was stored
240
+ entries = memory.get_recent("test-session-6", ttl_seconds=900)
241
+ assert len(entries) == 1
242
+ assert entries[0]["tool"] == "test.tool"
243
+ assert entries[0]["output"] == {"result": "success", "tool_output": "test_data"}
244
+
245
+
246
+ @pytest.mark.asyncio
247
+ async def test_execute_tool_injects_memory(mock_tool_handler):
248
+ """Test that execute_tool injects recent memory into payload."""
249
+ session_id = "test-session-7"
250
+
251
+ # First call - no memory yet
252
+ payload1 = {
253
+ "tenant_id": "test-tenant",
254
+ "session_id": session_id,
255
+ "query": "first query"
256
+ }
257
+
258
+ await execute_tool("tool1", payload1, mock_tool_handler)
259
+
260
+ # Second call - should have memory from first call
261
+ payload2 = {
262
+ "tenant_id": "test-tenant",
263
+ "session_id": session_id,
264
+ "query": "second query"
265
+ }
266
+
267
+ await execute_tool("tool2", payload2, mock_tool_handler)
268
+
269
+ # Check that second call received memory
270
+ assert len(mock_tool_handler.captured) == 2
271
+ second_payload = mock_tool_handler.captured[1]
272
+
273
+ assert "memory" in second_payload
274
+ assert len(second_payload["memory"]) == 1
275
+ assert second_payload["memory"][0]["tool"] == "tool1"
276
+
277
+
278
+ @pytest.mark.asyncio
279
+ async def test_execute_tool_clears_memory_on_end_session(mock_tool_handler):
280
+ """Test that execute_tool clears memory when end_session is True."""
281
+ session_id = "test-session-8"
282
+
283
+ # First call - store memory
284
+ payload1 = {
285
+ "tenant_id": "test-tenant",
286
+ "session_id": session_id,
287
+ "query": "first query"
288
+ }
289
+
290
+ await execute_tool("tool1", payload1, mock_tool_handler)
291
+
292
+ # Verify memory exists
293
+ assert len(memory.get_recent(session_id, ttl_seconds=900)) == 1
294
+
295
+ # Second call with end_session=True
296
+ payload2 = {
297
+ "tenant_id": "test-tenant",
298
+ "session_id": session_id,
299
+ "end_session": True,
300
+ "query": "closing"
301
+ }
302
+
303
+ await execute_tool("tool2", payload2, mock_tool_handler)
304
+
305
+ # Memory should be cleared
306
+ assert len(memory.get_recent(session_id, ttl_seconds=900)) == 0
307
+
308
+ # Third call - should have no memory
309
+ payload3 = {
310
+ "tenant_id": "test-tenant",
311
+ "session_id": session_id,
312
+ "query": "new query"
313
+ }
314
+
315
+ await execute_tool("tool3", payload3, mock_tool_handler)
316
+
317
+ # Check that third call received no memory
318
+ third_payload = mock_tool_handler.captured[2]
319
+ assert "memory" in third_payload
320
+ assert len(third_payload["memory"]) == 0
321
+
322
+
323
+ @pytest.mark.asyncio
324
+ async def test_execute_tool_no_memory_without_session_id(mock_tool_handler):
325
+ """Test that execute_tool doesn't store/inject memory if no session_id."""
326
+ payload = {
327
+ "tenant_id": "test-tenant",
328
+ "query": "test query"
329
+ # No session_id
330
+ }
331
+
332
+ await execute_tool("test.tool", payload, mock_tool_handler)
333
+
334
+ # Should not have stored memory
335
+ # (We can't easily check this without session_id, but handler shouldn't have memory field)
336
+ first_payload = mock_tool_handler.captured[0]
337
+ assert "memory" not in first_payload
338
+
339
+
340
+ @pytest.mark.asyncio
341
+ async def test_execute_tool_multi_step_workflow(mock_tool_handler):
342
+ """Test a multi-step workflow where each step sees previous tool outputs."""
343
+ session_id = "test-session-9"
344
+
345
+ # Step 1: RAG search
346
+ payload1 = {
347
+ "tenant_id": "test-tenant",
348
+ "session_id": session_id,
349
+ "query": "search for X"
350
+ }
351
+
352
+ await execute_tool("rag.search", payload1, mock_tool_handler)
353
+
354
+ # Step 2: Web search (should see RAG results in memory)
355
+ payload2 = {
356
+ "tenant_id": "test-tenant",
357
+ "session_id": session_id,
358
+ "query": "search web for Y"
359
+ }
360
+
361
+ await execute_tool("web.search", payload2, mock_tool_handler)
362
+
363
+ # Step 3: LLM synthesis (should see both RAG and Web results)
364
+ payload3 = {
365
+ "tenant_id": "test-tenant",
366
+ "session_id": session_id,
367
+ "query": "synthesize results"
368
+ }
369
+
370
+ await execute_tool("llm.synthesize", payload3, mock_tool_handler)
371
+
372
+ # Verify all steps captured memory
373
+ assert len(mock_tool_handler.captured) == 3
374
+
375
+ # First call has no memory
376
+ assert "memory" not in mock_tool_handler.captured[0] or len(mock_tool_handler.captured[0].get("memory", [])) == 0
377
+
378
+ # Second call has memory from first
379
+ assert len(mock_tool_handler.captured[1].get("memory", [])) == 1
380
+ assert mock_tool_handler.captured[1]["memory"][0]["tool"] == "rag.search"
381
+
382
+ # Third call has memory from both previous calls
383
+ assert len(mock_tool_handler.captured[2].get("memory", [])) == 2
384
+ assert mock_tool_handler.captured[2]["memory"][0]["tool"] == "rag.search"
385
+ assert mock_tool_handler.captured[2]["memory"][1]["tool"] == "web.search"
386
+
387
+
388
+ @pytest.mark.asyncio
389
+ async def test_execute_tool_end_session_variants(mock_tool_handler):
390
+ """Test that both end_session and endSession flags work."""
391
+ session_id = "test-session-10"
392
+
393
+ # Store some memory
394
+ payload1 = {
395
+ "tenant_id": "test-tenant",
396
+ "session_id": session_id,
397
+ "query": "first"
398
+ }
399
+ await execute_tool("tool1", payload1, mock_tool_handler)
400
+ assert len(memory.get_recent(session_id, ttl_seconds=900)) == 1
401
+
402
+ # Test end_session (snake_case)
403
+ payload2 = {
404
+ "tenant_id": "test-tenant",
405
+ "session_id": session_id,
406
+ "end_session": True,
407
+ "query": "end"
408
+ }
409
+ await execute_tool("tool2", payload2, mock_tool_handler)
410
+ assert len(memory.get_recent(session_id, ttl_seconds=900)) == 0
411
+
412
+ # Store memory again
413
+ await execute_tool("tool3", payload1, mock_tool_handler)
414
+ assert len(memory.get_recent(session_id, ttl_seconds=900)) == 1
415
+
416
+ # Test endSession (camelCase)
417
+ payload3 = {
418
+ "tenant_id": "test-tenant",
419
+ "session_id": session_id,
420
+ "endSession": True,
421
+ "query": "end"
422
+ }
423
+ await execute_tool("tool4", payload3, mock_tool_handler)
424
+ assert len(memory.get_recent(session_id, ttl_seconds=900)) == 0
425
+
426
+
427
+ # =============================================================
428
+ # EDGE CASES
429
+ # =============================================================
430
+
431
+ def test_empty_session_id():
432
+ """Test that empty session_id doesn't cause errors."""
433
+ memory.add_entry("", "tool1", {"data": "test"}, max_items=10, ttl_seconds=900)
434
+ # Should not store anything
435
+ assert len(memory.get_recent("", ttl_seconds=900)) == 0
436
+
437
+
438
+ def test_none_session_id():
439
+ """Test that None session_id doesn't cause errors."""
440
+ # This shouldn't happen in practice, but test for safety
441
+ entries = memory.get_recent(None, ttl_seconds=900) # type: ignore
442
+ assert entries == []
443
+
444
+
445
+ @pytest.mark.asyncio
446
+ async def test_concurrent_sessions(mock_tool_handler):
447
+ """Test that concurrent sessions don't interfere with each other."""
448
+ session1 = "session-concurrent-1"
449
+ session2 = "session-concurrent-2"
450
+
451
+ # Execute tools in both sessions concurrently
452
+ tasks = [
453
+ execute_tool("tool1", {
454
+ "tenant_id": "tenant1",
455
+ "session_id": session1,
456
+ "query": "q1"
457
+ }, mock_tool_handler),
458
+ execute_tool("tool2", {
459
+ "tenant_id": "tenant2",
460
+ "session_id": session2,
461
+ "query": "q2"
462
+ }, mock_tool_handler),
463
+ ]
464
+
465
+ await asyncio.gather(*tasks)
466
+
467
+ # Each session should have its own memory
468
+ entries1 = memory.get_recent(session1, ttl_seconds=900)
469
+ entries2 = memory.get_recent(session2, ttl_seconds=900)
470
+
471
+ assert len(entries1) == 1
472
+ assert len(entries2) == 1
473
+ assert entries1[0]["tool"] == "tool1"
474
+ assert entries2[0]["tool"] == "tool2"
475
+
476
+
477
+ if __name__ == "__main__":
478
+ pytest.main([__file__, "-v"])
479
+