Spaces:
Sleeping
Sleeping
| # ============================================================= | |
| # File: backend/tests/test_conversation_memory.py | |
| # ============================================================= | |
| """ | |
| Comprehensive tests for short-term conversation memory with expiration. | |
| Tests: | |
| 1. Memory storage and retrieval | |
| 2. Memory injection into tool payloads | |
| 3. Session isolation (different session_ids don't share memory) | |
| 4. Memory expiration (TTL) | |
| 5. Memory bounded size (only last N items) | |
| 6. Session clearing (end_session flag) | |
| 7. Memory is NOT keyed by tenant_id (same session_id across tenants shares memory) | |
| """ | |
| import sys | |
| from pathlib import Path | |
| import pytest | |
| import time | |
| from unittest.mock import AsyncMock, MagicMock, patch | |
| import asyncio | |
| # Add backend directory to Python path | |
| backend_dir = Path(__file__).parent.parent | |
| sys.path.insert(0, str(backend_dir)) | |
| from mcp_server.common import memory | |
| from mcp_server.common.utils import execute_tool, ToolHandler | |
| from mcp_server.common.tenant import TenantContext | |
| # ============================================================= | |
| # FIXTURES | |
| # ============================================================= | |
@pytest.fixture(autouse=True)
def clear_memory():
    """Empty the shared memory store before and after every test.

    Registered as an autouse fixture so that each test in this module starts
    and ends with ``memory._MEMORY`` empty; without the decorator this
    generator is never invoked and state leaks from one test to the next.
    """
    # Setup: drop anything left behind by a previous test.
    memory._MEMORY.clear()
    yield
    # Teardown: do not leak this test's entries into the next one.
    memory._MEMORY.clear()
@pytest.fixture
def mock_tool_handler():
    """Provide an async tool handler that records every payload it receives.

    Declared as a pytest fixture because the integration tests request it by
    argument name; without the decorator pytest cannot resolve it.

    Returns:
        An ``async`` callable taking ``(context, payload)`` that always
        returns ``{"result": "success", "tool_output": "test_data"}``. The
        payloads it was called with are available, in call order, through
        its ``captured`` attribute.
    """
    captured_payloads = []

    async def handler(context: TenantContext, payload: dict) -> dict:
        # Store the payload object itself (not a copy) so tests can inspect
        # exactly what the handler was given.
        captured_payloads.append(payload)
        return {"result": "success", "tool_output": "test_data"}

    # Expose the recording list on the function object for assertions.
    handler.captured = captured_payloads
    return handler
| # ============================================================= | |
| # UNIT TESTS: Memory Module | |
| # ============================================================= | |
def test_extract_session_id():
    """Check which payload keys ``extract_session_id`` recognises.

    Covers the four accepted key spellings, precedence when two of them are
    present, and the payloads for which no session id is expected.
    """
    # (payload, expected id) pairs; the last one checks that the first
    # matching key wins when several are present.
    recognised = [
        ({"session_id": "s1"}, "s1"),
        ({"sessionId": "s2"}, "s2"),
        ({"conversation_id": "s3"}, "s3"),
        ({"conversationId": "s4"}, "s4"),
        ({"session_id": "s1", "sessionId": "s2"}, "s1"),
    ]
    for payload, expected in recognised:
        assert memory.extract_session_id(payload) == expected

    # Missing key, empty payload, empty string and blank string all yield None.
    unrecognised = [
        {"tenant_id": "t1"},
        {},
        {"session_id": ""},
        {"session_id": " "},
    ]
    for payload in unrecognised:
        assert memory.extract_session_id(payload) is None
def test_add_and_get_entry():
    """Stored entries are returned in insertion order with their metadata."""
    sid = "test-session-1"

    # Record three tool outputs: tool1/data1, tool2/data2, tool3/data3.
    for idx in (1, 2, 3):
        memory.add_entry(
            sid,
            f"tool{idx}",
            {"output": f"data{idx}"},
            max_items=10,
            ttl_seconds=900,
        )

    stored = memory.get_recent(sid, ttl_seconds=900)

    assert len(stored) == 3
    assert [item["tool"] for item in stored] == ["tool1", "tool2", "tool3"]

    # Each entry carries the original output plus a timestamp field.
    oldest = stored[0]
    assert oldest["output"] == {"output": "data1"}
    assert "timestamp" in oldest
def test_memory_bounded_size():
    """Only the newest ``max_items`` entries survive when more are added."""
    sid = "test-session-2"
    cap = 3

    # Insert five entries against a cap of three.
    for n in range(5):
        memory.add_entry(sid, f"tool{n}", {"data": n}, max_items=cap, ttl_seconds=900)

    kept = memory.get_recent(sid, ttl_seconds=900)

    # The two oldest (tool0, tool1) are expected to have been dropped.
    assert len(kept) == 3
    assert [item["tool"] for item in kept] == ["tool2", "tool3", "tool4"]
def test_memory_expiration():
    """An entry is visible right away and gone once its TTL has elapsed."""
    sid = "test-session-3"
    ttl = 1  # seconds

    memory.add_entry(sid, "tool1", {"data": "old"}, max_items=10, ttl_seconds=ttl)

    # Freshly written entry is still within its TTL.
    fresh = memory.get_recent(sid, ttl_seconds=ttl)
    assert len(fresh) == 1

    # Sleep slightly longer than the TTL so the entry ages out.
    time.sleep(1.1)

    stale = memory.get_recent(sid, ttl_seconds=ttl)
    assert len(stale) == 0
def test_session_isolation():
    """Entries written under one session id are not visible under another."""
    # (session id, tool name, payload marker) for two independent sessions.
    scenarios = (
        ("session-1", "tool1", "s1"),
        ("session-2", "tool2", "s2"),
    )
    for sid, tool, marker in scenarios:
        memory.add_entry(sid, tool, {"data": marker}, max_items=10, ttl_seconds=900)

    first = memory.get_recent("session-1", ttl_seconds=900)
    second = memory.get_recent("session-2", ttl_seconds=900)

    # Each session holds exactly its own single entry.
    assert len(first) == 1
    assert len(second) == 1
    assert first[0]["tool"] == "tool1"
    assert second[0]["tool"] == "tool2"
def test_clear_session():
    """``clear_session`` wipes every entry stored for the given session."""
    sid = "test-session-4"

    for tool, value in (("tool1", "d1"), ("tool2", "d2")):
        memory.add_entry(sid, tool, {"data": value}, max_items=10, ttl_seconds=900)

    before = memory.get_recent(sid, ttl_seconds=900)
    assert len(before) == 2

    memory.clear_session(sid)

    after = memory.get_recent(sid, ttl_seconds=900)
    assert len(after) == 0
def test_memory_not_keyed_by_tenant():
    """Entries are grouped by session id only; the tenant plays no part.

    Two writes that differ in tenant but share a session id end up in the
    same history. The memory calls used here take no tenant argument, so
    keeping session ids unique per tenant is presumably the caller's job.
    """
    sid = "shared-session"
    tenants = ("tenant-a", "tenant-b")

    # First write, as if issued on behalf of tenant-a.
    memory.add_entry(
        sid,
        "tool1",
        {"tenant": tenants[0], "data": "from-tenant1"},
        max_items=10,
        ttl_seconds=900,
    )
    # Second write, same session id but on behalf of tenant-b.
    memory.add_entry(
        sid,
        "tool2",
        {"tenant": tenants[1], "data": "from-tenant2"},
        max_items=10,
        ttl_seconds=900,
    )

    shared = memory.get_recent(sid, ttl_seconds=900)

    # Both tenants' entries are visible, in write order.
    assert len(shared) == 2
    assert [item["output"]["tenant"] for item in shared] == ["tenant-a", "tenant-b"]
def test_get_recent_with_limit():
    """``get_recent`` returns everything for ``limit=None`` and a tail otherwise."""
    sid = "test-session-5"

    for n in range(5):
        memory.add_entry(sid, f"tool{n}", {"data": n}, max_items=10, ttl_seconds=900)

    # No limit: all five entries come back.
    everything = memory.get_recent(sid, limit=None, ttl_seconds=900)
    assert len(everything) == 5

    # limit=2: only the two newest entries, oldest of the pair first.
    tail = memory.get_recent(sid, limit=2, ttl_seconds=900)
    assert len(tail) == 2
    assert [item["tool"] for item in tail] == ["tool3", "tool4"]
| # ============================================================= | |
| # INTEGRATION TESTS: execute_tool with Memory | |
| # ============================================================= | |
@pytest.mark.asyncio
async def test_execute_tool_stores_memory(mock_tool_handler):
    """A successful ``execute_tool`` call records the tool output in memory.

    Marked for pytest-asyncio so the coroutine is actually awaited; an
    unmarked ``async def`` test is not run under the plugin's strict mode.
    """
    session_id = "test-session-6"
    payload = {
        "tenant_id": "test-tenant",
        "session_id": session_id,
        "query": "test query",
    }

    result = await execute_tool("test.tool", payload, mock_tool_handler)

    # The call itself must report success.
    assert result["status"] == "ok"

    # Exactly one entry is stored, named after the tool and holding the
    # dict the mock handler returns.
    entries = memory.get_recent(session_id, ttl_seconds=900)
    assert len(entries) == 1
    assert entries[0]["tool"] == "test.tool"
    assert entries[0]["output"] == {"result": "success", "tool_output": "test_data"}
@pytest.mark.asyncio
async def test_execute_tool_injects_memory(mock_tool_handler):
    """The second call in a session receives the first call's output as memory.

    Marked for pytest-asyncio so the coroutine is actually awaited; an
    unmarked ``async def`` test is not run under the plugin's strict mode.
    """
    session_id = "test-session-7"

    # First call: nothing has been stored for this session yet.
    first_request = {
        "tenant_id": "test-tenant",
        "session_id": session_id,
        "query": "first query",
    }
    await execute_tool("tool1", first_request, mock_tool_handler)

    # Second call in the same session.
    second_request = {
        "tenant_id": "test-tenant",
        "session_id": session_id,
        "query": "second query",
    }
    await execute_tool("tool2", second_request, mock_tool_handler)

    assert len(mock_tool_handler.captured) == 2

    # The payload handed to the handler on the second call must carry one
    # memory entry, the one produced by "tool1".
    seen_by_second = mock_tool_handler.captured[1]
    assert "memory" in seen_by_second
    assert len(seen_by_second["memory"]) == 1
    assert seen_by_second["memory"][0]["tool"] == "tool1"
@pytest.mark.asyncio
async def test_execute_tool_clears_memory_on_end_session(mock_tool_handler):
    """A payload with ``end_session=True`` leaves the session's memory empty.

    Marked for pytest-asyncio so the coroutine is actually awaited; an
    unmarked ``async def`` test is not run under the plugin's strict mode.
    """
    session_id = "test-session-8"

    # Step 1: an ordinary call that should leave one entry behind.
    opening = {
        "tenant_id": "test-tenant",
        "session_id": session_id,
        "query": "first query",
    }
    await execute_tool("tool1", opening, mock_tool_handler)
    assert len(memory.get_recent(session_id, ttl_seconds=900)) == 1

    # Step 2: the closing call; afterwards nothing may remain for the session.
    closing = {
        "tenant_id": "test-tenant",
        "session_id": session_id,
        "end_session": True,
        "query": "closing",
    }
    await execute_tool("tool2", closing, mock_tool_handler)
    assert len(memory.get_recent(session_id, ttl_seconds=900)) == 0

    # Step 3: a new call under the same id starts from an empty history.
    reopened = {
        "tenant_id": "test-tenant",
        "session_id": session_id,
        "query": "new query",
    }
    await execute_tool("tool3", reopened, mock_tool_handler)

    # The handler still gets a "memory" key on the third call, but it is empty.
    seen_by_third = mock_tool_handler.captured[2]
    assert "memory" in seen_by_third
    assert len(seen_by_third["memory"]) == 0
@pytest.mark.asyncio
async def test_execute_tool_no_memory_without_session_id(mock_tool_handler):
    """Without a session id the handler payload gets no ``memory`` key.

    Marked for pytest-asyncio so the coroutine is actually awaited; an
    unmarked ``async def`` test is not run under the plugin's strict mode.
    """
    # Deliberately omits every session-id style key.
    payload = {
        "tenant_id": "test-tenant",
        "query": "test query",
    }

    await execute_tool("test.tool", payload, mock_tool_handler)

    # Storage cannot be queried without a session id, so the check is made
    # on what the handler received instead.
    received = mock_tool_handler.captured[0]
    assert "memory" not in received
@pytest.mark.asyncio
async def test_execute_tool_multi_step_workflow(mock_tool_handler):
    """Each step of a three-step workflow sees the outputs of earlier steps.

    Marked for pytest-asyncio so the coroutine is actually awaited; an
    unmarked ``async def`` test is not run under the plugin's strict mode.
    """
    session_id = "test-session-9"

    # (tool name, query) for each step, executed in order in one session.
    steps = (
        ("rag.search", "search for X"),
        ("web.search", "search web for Y"),
        ("llm.synthesize", "synthesize results"),
    )
    for tool_name, query in steps:
        request = {
            "tenant_id": "test-tenant",
            "session_id": session_id,
            "query": query,
        }
        await execute_tool(tool_name, request, mock_tool_handler)

    captured = mock_tool_handler.captured
    assert len(captured) == 3

    # Step 1 ran with no history: the key is either absent or an empty list.
    assert len(captured[0].get("memory", [])) == 0

    # Step 2 sees only the RAG search.
    assert len(captured[1].get("memory", [])) == 1
    assert captured[1]["memory"][0]["tool"] == "rag.search"

    # Step 3 sees the RAG search and then the web search, in that order.
    assert len(captured[2].get("memory", [])) == 2
    assert captured[2]["memory"][0]["tool"] == "rag.search"
    assert captured[2]["memory"][1]["tool"] == "web.search"
@pytest.mark.asyncio
async def test_execute_tool_end_session_variants(mock_tool_handler):
    """Both ``end_session`` and ``endSession`` flags empty the session memory.

    Marked for pytest-asyncio so the coroutine is actually awaited; an
    unmarked ``async def`` test is not run under the plugin's strict mode.
    """
    session_id = "test-session-10"

    def stored_count():
        # Number of entries currently retrievable for this session.
        return len(memory.get_recent(session_id, ttl_seconds=900))

    # Seed the session with one entry.
    seed_payload = {
        "tenant_id": "test-tenant",
        "session_id": session_id,
        "query": "first",
    }
    await execute_tool("tool1", seed_payload, mock_tool_handler)
    assert stored_count() == 1

    # snake_case flag.
    snake_case_end = {
        "tenant_id": "test-tenant",
        "session_id": session_id,
        "end_session": True,
        "query": "end",
    }
    await execute_tool("tool2", snake_case_end, mock_tool_handler)
    assert stored_count() == 0

    # Seed again, intentionally reusing the very same payload object.
    await execute_tool("tool3", seed_payload, mock_tool_handler)
    assert stored_count() == 1

    # camelCase flag.
    camel_case_end = {
        "tenant_id": "test-tenant",
        "session_id": session_id,
        "endSession": True,
        "query": "end",
    }
    await execute_tool("tool4", camel_case_end, mock_tool_handler)
    assert stored_count() == 0
| # ============================================================= | |
| # EDGE CASES | |
| # ============================================================= | |
def test_empty_session_id():
    """Writing under an empty session id neither fails nor stores anything."""
    blank = ""

    memory.add_entry(blank, "tool1", {"data": "test"}, max_items=10, ttl_seconds=900)

    # Nothing should be retrievable under the empty id.
    leftovers = memory.get_recent(blank, ttl_seconds=900)
    assert len(leftovers) == 0
def test_none_session_id():
    """Reading with ``None`` as the session id yields an empty list."""
    # Not expected from real callers; this guards against a crash if it
    # ever happens.
    missing = None
    assert memory.get_recent(missing, ttl_seconds=900) == []  # type: ignore
@pytest.mark.asyncio
async def test_concurrent_sessions(mock_tool_handler):
    """Two sessions executed via ``asyncio.gather`` keep separate histories.

    Marked for pytest-asyncio so the coroutine is actually awaited; an
    unmarked ``async def`` test is not run under the plugin's strict mode.
    """
    session1 = "session-concurrent-1"
    session2 = "session-concurrent-2"

    # Schedule one tool call per session and wait for both to finish.
    await asyncio.gather(
        execute_tool(
            "tool1",
            {"tenant_id": "tenant1", "session_id": session1, "query": "q1"},
            mock_tool_handler,
        ),
        execute_tool(
            "tool2",
            {"tenant_id": "tenant2", "session_id": session2, "query": "q2"},
            mock_tool_handler,
        ),
    )

    entries1 = memory.get_recent(session1, ttl_seconds=900)
    entries2 = memory.get_recent(session2, ttl_seconds=900)

    # Each session ends up with exactly its own tool's entry.
    assert len(entries1) == 1
    assert len(entries2) == 1
    assert entries1[0]["tool"] == "tool1"
    assert entries2[0]["tool"] == "tool2"
if __name__ == "__main__":
    # Run this file's tests verbosely when executed as a script, and pass
    # pytest's exit code on to the shell so that failures yield a non-zero
    # status instead of always exiting with 0.
    sys.exit(pytest.main([__file__, "-v"]))