# =============================================================
# File: backend/tests/test_conversation_memory.py
# =============================================================
"""
Comprehensive tests for short-term conversation memory with expiration.

Tests:
1. Memory storage and retrieval
2. Memory injection into tool payloads
3. Session isolation (different session_ids don't share memory)
4. Memory expiration (TTL)
5. Memory bounded size (only last N items)
6. Session clearing (end_session flag)
7. Memory is NOT keyed by tenant_id (same session_id across tenants shares memory)
"""

import asyncio
import copy
import sys
import time
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

# Add backend directory to Python path so `mcp_server` resolves when the
# tests are run from outside the backend directory.
backend_dir = Path(__file__).parent.parent
sys.path.insert(0, str(backend_dir))

from mcp_server.common import memory
from mcp_server.common.utils import execute_tool, ToolHandler
from mcp_server.common.tenant import TenantContext


# Values passed explicitly to the memory API throughout this module.
DEFAULT_TTL_SECONDS = 900
DEFAULT_MAX_ITEMS = 10


# =============================================================
# FIXTURES
# =============================================================

@pytest.fixture(autouse=True)
def clear_memory():
    """Clear the module-level memory store before and after each test."""
    memory._MEMORY.clear()
    yield
    memory._MEMORY.clear()


@pytest.fixture
def mock_tool_handler():
    """Create a mock tool handler that records every payload it receives.

    The recorded payloads are exposed as ``handler.captured`` (a list, in
    call order). Each entry is a deep copy taken at call time, so later
    mutation of the payload dict (or of a nested ``memory`` list) cannot
    retroactively change what a test observes.
    """
    captured_payloads = []

    async def handler(context: TenantContext, payload: dict) -> dict:
        # Snapshot rather than alias: assertions are about what the handler
        # *received*, not about the dict's state after the call returned.
        captured_payloads.append(copy.deepcopy(payload))
        return {"result": "success", "tool_output": "test_data"}

    handler.captured = captured_payloads
    return handler


# =============================================================
# UNIT TESTS: Memory Module
# =============================================================

def test_extract_session_id():
    """Test session ID extraction from payload."""
    # Each supported key spelling is recognised on its own.
    assert memory.extract_session_id({"session_id": "s1"}) == "s1"
    assert memory.extract_session_id({"sessionId": "s2"}) == "s2"
    assert memory.extract_session_id({"conversation_id": "s3"}) == "s3"
    assert memory.extract_session_id({"conversationId": "s4"}) == "s4"

    # When several keys are present, the first match wins.
    assert memory.extract_session_id({
        "session_id": "s1",
        "sessionId": "s2",
    }) == "s1"

    # Missing session ID
    assert memory.extract_session_id({"tenant_id": "t1"}) is None
    assert memory.extract_session_id({}) is None

    # Empty and whitespace-only strings are treated as missing.
    assert memory.extract_session_id({"session_id": ""}) is None
    assert memory.extract_session_id({"session_id": "   "}) is None


def test_add_and_get_entry():
    """Test basic memory storage and retrieval."""
    session_id = "test-session-1"

    memory.add_entry(
        session_id, "tool1", {"output": "data1"},
        max_items=DEFAULT_MAX_ITEMS, ttl_seconds=DEFAULT_TTL_SECONDS,
    )
    memory.add_entry(
        session_id, "tool2", {"output": "data2"},
        max_items=DEFAULT_MAX_ITEMS, ttl_seconds=DEFAULT_TTL_SECONDS,
    )
    memory.add_entry(
        session_id, "tool3", {"output": "data3"},
        max_items=DEFAULT_MAX_ITEMS, ttl_seconds=DEFAULT_TTL_SECONDS,
    )

    entries = memory.get_recent(session_id, ttl_seconds=DEFAULT_TTL_SECONDS)

    # Entries come back in insertion order.
    assert len(entries) == 3
    assert entries[0]["tool"] == "tool1"
    assert entries[1]["tool"] == "tool2"
    assert entries[2]["tool"] == "tool3"
    assert entries[0]["output"] == {"output": "data1"}
    assert "timestamp" in entries[0]


def test_memory_bounded_size():
    """Test that memory only keeps last N items."""
    session_id = "test-session-2"
    max_items = 3

    # Add more items than max
    for i in range(5):
        memory.add_entry(
            session_id, f"tool{i}", {"data": i},
            max_items=max_items, ttl_seconds=DEFAULT_TTL_SECONDS,
        )

    entries = memory.get_recent(session_id, ttl_seconds=DEFAULT_TTL_SECONDS)

    # Only the last 3 items survive; the two oldest were dropped.
    assert len(entries) == 3
    assert entries[0]["tool"] == "tool2"
    assert entries[1]["tool"] == "tool3"
    assert entries[2]["tool"] == "tool4"


def test_memory_expiration():
    """Test that expired entries are no longer returned after the TTL."""
    session_id = "test-session-3"
    short_ttl = 1  # 1 second TTL

    memory.add_entry(
        session_id, "tool1", {"data": "old"},
        max_items=DEFAULT_MAX_ITEMS, ttl_seconds=short_ttl,
    )

    # Should be present immediately
    entries = memory.get_recent(session_id, ttl_seconds=short_ttl)
    assert len(entries) == 1

    # Wait just past the TTL (1.1 s) so the entry is expired.
    time.sleep(short_ttl + 0.1)

    # Should be expired now
    entries = memory.get_recent(session_id, ttl_seconds=short_ttl)
    assert len(entries) == 0


def test_session_isolation():
    """Test that different session_ids don't share memory."""
    session1 = "session-1"
    session2 = "session-2"

    memory.add_entry(
        session1, "tool1", {"data": "s1"},
        max_items=DEFAULT_MAX_ITEMS, ttl_seconds=DEFAULT_TTL_SECONDS,
    )
    memory.add_entry(
        session2, "tool2", {"data": "s2"},
        max_items=DEFAULT_MAX_ITEMS, ttl_seconds=DEFAULT_TTL_SECONDS,
    )

    entries1 = memory.get_recent(session1, ttl_seconds=DEFAULT_TTL_SECONDS)
    entries2 = memory.get_recent(session2, ttl_seconds=DEFAULT_TTL_SECONDS)

    assert len(entries1) == 1
    assert len(entries2) == 1
    assert entries1[0]["tool"] == "tool1"
    assert entries2[0]["tool"] == "tool2"


def test_clear_session():
    """Test that clear_session removes all memory for a session."""
    session_id = "test-session-4"

    memory.add_entry(
        session_id, "tool1", {"data": "d1"},
        max_items=DEFAULT_MAX_ITEMS, ttl_seconds=DEFAULT_TTL_SECONDS,
    )
    memory.add_entry(
        session_id, "tool2", {"data": "d2"},
        max_items=DEFAULT_MAX_ITEMS, ttl_seconds=DEFAULT_TTL_SECONDS,
    )

    assert len(memory.get_recent(session_id, ttl_seconds=DEFAULT_TTL_SECONDS)) == 2

    memory.clear_session(session_id)

    assert len(memory.get_recent(session_id, ttl_seconds=DEFAULT_TTL_SECONDS)) == 0


def test_memory_not_keyed_by_tenant():
    """Test that memory is keyed by session_id, NOT tenant_id."""
    session_id = "shared-session"
    tenant1 = "tenant-a"
    tenant2 = "tenant-b"

    # Simulate: tenant1 calls a tool, then tenant2 calls a tool with the same
    # session_id. The memory API used here takes no tenant argument, so both
    # entries land in the same session bucket and are visible together.
    # NOTE(review): callers presumably need to keep session_ids unique per
    # tenant; the memory module itself does not enforce this.

    # Add entry from tenant1 perspective
    memory.add_entry(
        session_id, "tool1", {"tenant": tenant1, "data": "from-tenant1"},
        max_items=DEFAULT_MAX_ITEMS, ttl_seconds=DEFAULT_TTL_SECONDS,
    )

    # Add entry from tenant2 perspective (same session_id)
    memory.add_entry(
        session_id, "tool2", {"tenant": tenant2, "data": "from-tenant2"},
        max_items=DEFAULT_MAX_ITEMS, ttl_seconds=DEFAULT_TTL_SECONDS,
    )

    # Both entries are returned (because same session_id)
    entries = memory.get_recent(session_id, ttl_seconds=DEFAULT_TTL_SECONDS)
    assert len(entries) == 2
    assert entries[0]["output"]["tenant"] == tenant1
    assert entries[1]["output"]["tenant"] == tenant2


def test_get_recent_with_limit():
    """Test that get_recent respects the limit parameter."""
    session_id = "test-session-5"

    # Add 5 entries
    for i in range(5):
        memory.add_entry(
            session_id, f"tool{i}", {"data": i},
            max_items=DEFAULT_MAX_ITEMS, ttl_seconds=DEFAULT_TTL_SECONDS,
        )

    # limit=None returns everything
    all_entries = memory.get_recent(
        session_id, limit=None, ttl_seconds=DEFAULT_TTL_SECONDS
    )
    assert len(all_entries) == 5

    # limit=2 returns the two most recent entries, oldest first
    recent_2 = memory.get_recent(
        session_id, limit=2, ttl_seconds=DEFAULT_TTL_SECONDS
    )
    assert len(recent_2) == 2
    assert recent_2[0]["tool"] == "tool3"
    assert recent_2[1]["tool"] == "tool4"


# =============================================================
# INTEGRATION TESTS: execute_tool with Memory
# =============================================================

@pytest.mark.asyncio
async def test_execute_tool_stores_memory(mock_tool_handler):
    """Test that execute_tool stores tool output in memory."""
    payload = {
        "tenant_id": "test-tenant",
        "session_id": "test-session-6",
        "query": "test query",
    }

    result = await execute_tool("test.tool", payload, mock_tool_handler)

    # Check that result is successful
    assert result["status"] == "ok"

    # Check that memory was stored under the payload's session_id
    entries = memory.get_recent("test-session-6", ttl_seconds=DEFAULT_TTL_SECONDS)
    assert len(entries) == 1
    assert entries[0]["tool"] == "test.tool"
    assert entries[0]["output"] == {"result": "success", "tool_output": "test_data"}


@pytest.mark.asyncio
async def test_execute_tool_injects_memory(mock_tool_handler):
    """Test that execute_tool injects recent memory into payload."""
    session_id = "test-session-7"

    # First call - no memory yet
    payload1 = {
        "tenant_id": "test-tenant",
        "session_id": session_id,
        "query": "first query",
    }
    await execute_tool("tool1", payload1, mock_tool_handler)

    # Second call - should have memory from first call
    payload2 = {
        "tenant_id": "test-tenant",
        "session_id": session_id,
        "query": "second query",
    }
    await execute_tool("tool2", payload2, mock_tool_handler)

    # Check that second call received memory
    assert len(mock_tool_handler.captured) == 2
    second_payload = mock_tool_handler.captured[1]
    assert "memory" in second_payload
    assert len(second_payload["memory"]) == 1
    assert second_payload["memory"][0]["tool"] == "tool1"


@pytest.mark.asyncio
async def test_execute_tool_clears_memory_on_end_session(mock_tool_handler):
    """Test that execute_tool clears memory when end_session is True."""
    session_id = "test-session-8"

    # First call - store memory
    payload1 = {
        "tenant_id": "test-tenant",
        "session_id": session_id,
        "query": "first query",
    }
    await execute_tool("tool1", payload1, mock_tool_handler)

    # Verify memory exists
    assert len(memory.get_recent(session_id, ttl_seconds=DEFAULT_TTL_SECONDS)) == 1

    # Second call with end_session=True
    payload2 = {
        "tenant_id": "test-tenant",
        "session_id": session_id,
        "end_session": True,
        "query": "closing",
    }
    await execute_tool("tool2", payload2, mock_tool_handler)

    # Memory should be cleared
    assert len(memory.get_recent(session_id, ttl_seconds=DEFAULT_TTL_SECONDS)) == 0

    # Third call - should have no memory
    payload3 = {
        "tenant_id": "test-tenant",
        "session_id": session_id,
        "query": "new query",
    }
    await execute_tool("tool3", payload3, mock_tool_handler)

    # Check that third call received an empty memory list
    third_payload = mock_tool_handler.captured[2]
    assert "memory" in third_payload
    assert len(third_payload["memory"]) == 0


@pytest.mark.asyncio
async def test_execute_tool_no_memory_without_session_id(mock_tool_handler):
    """Test that execute_tool doesn't store/inject memory if no session_id."""
    payload = {
        "tenant_id": "test-tenant",
        "query": "test query",
        # No session_id
    }

    await execute_tool("test.tool", payload, mock_tool_handler)

    # Storage cannot be checked without a session_id to look up, so only the
    # injection side is verified: the handler must not see a "memory" field.
    first_payload = mock_tool_handler.captured[0]
    assert "memory" not in first_payload


@pytest.mark.asyncio
async def test_execute_tool_multi_step_workflow(mock_tool_handler):
    """Test a multi-step workflow where each step sees previous tool outputs."""
    session_id = "test-session-9"

    # Step 1: RAG search
    payload1 = {
        "tenant_id": "test-tenant",
        "session_id": session_id,
        "query": "search for X",
    }
    await execute_tool("rag.search", payload1, mock_tool_handler)

    # Step 2: Web search (should see RAG results in memory)
    payload2 = {
        "tenant_id": "test-tenant",
        "session_id": session_id,
        "query": "search web for Y",
    }
    await execute_tool("web.search", payload2, mock_tool_handler)

    # Step 3: LLM synthesis (should see both RAG and Web results)
    payload3 = {
        "tenant_id": "test-tenant",
        "session_id": session_id,
        "query": "synthesize results",
    }
    await execute_tool("llm.synthesize", payload3, mock_tool_handler)

    # Verify all steps were captured
    captured = mock_tool_handler.captured
    assert len(captured) == 3

    # First call has no memory (field absent or empty)
    assert "memory" not in captured[0] or len(captured[0].get("memory", [])) == 0

    # Second call has memory from first
    assert len(captured[1].get("memory", [])) == 1
    assert captured[1]["memory"][0]["tool"] == "rag.search"

    # Third call has memory from both previous calls, in call order
    assert len(captured[2].get("memory", [])) == 2
    assert captured[2]["memory"][0]["tool"] == "rag.search"
    assert captured[2]["memory"][1]["tool"] == "web.search"


@pytest.mark.asyncio
async def test_execute_tool_end_session_variants(mock_tool_handler):
    """Test that both end_session and endSession flags work."""
    session_id = "test-session-10"

    # Store some memory
    payload1 = {
        "tenant_id": "test-tenant",
        "session_id": session_id,
        "query": "first",
    }
    await execute_tool("tool1", payload1, mock_tool_handler)
    assert len(memory.get_recent(session_id, ttl_seconds=DEFAULT_TTL_SECONDS)) == 1

    # Test end_session (snake_case)
    payload2 = {
        "tenant_id": "test-tenant",
        "session_id": session_id,
        "end_session": True,
        "query": "end",
    }
    await execute_tool("tool2", payload2, mock_tool_handler)
    assert len(memory.get_recent(session_id, ttl_seconds=DEFAULT_TTL_SECONDS)) == 0

    # Store memory again
    await execute_tool("tool3", payload1, mock_tool_handler)
    assert len(memory.get_recent(session_id, ttl_seconds=DEFAULT_TTL_SECONDS)) == 1

    # Test endSession (camelCase)
    payload3 = {
        "tenant_id": "test-tenant",
        "session_id": session_id,
        "endSession": True,
        "query": "end",
    }
    await execute_tool("tool4", payload3, mock_tool_handler)
    assert len(memory.get_recent(session_id, ttl_seconds=DEFAULT_TTL_SECONDS)) == 0


# =============================================================
# EDGE CASES
# =============================================================

def test_empty_session_id():
    """Test that empty session_id doesn't cause errors."""
    memory.add_entry(
        "", "tool1", {"data": "test"},
        max_items=DEFAULT_MAX_ITEMS, ttl_seconds=DEFAULT_TTL_SECONDS,
    )
    # Should not store anything
    assert len(memory.get_recent("", ttl_seconds=DEFAULT_TTL_SECONDS)) == 0


def test_none_session_id():
    """Test that None session_id doesn't cause errors."""
    # This shouldn't happen in practice, but test for safety
    entries = memory.get_recent(None, ttl_seconds=DEFAULT_TTL_SECONDS)  # type: ignore
    assert entries == []


@pytest.mark.asyncio
async def test_concurrent_sessions(mock_tool_handler):
    """Test that concurrent sessions don't interfere with each other."""
    session1 = "session-concurrent-1"
    session2 = "session-concurrent-2"

    # Execute tools in both sessions concurrently
    tasks = [
        execute_tool("tool1", {
            "tenant_id": "tenant1",
            "session_id": session1,
            "query": "q1",
        }, mock_tool_handler),
        execute_tool("tool2", {
            "tenant_id": "tenant2",
            "session_id": session2,
            "query": "q2",
        }, mock_tool_handler),
    ]

    await asyncio.gather(*tasks)

    # Each session should have its own memory
    entries1 = memory.get_recent(session1, ttl_seconds=DEFAULT_TTL_SECONDS)
    entries2 = memory.get_recent(session2, ttl_seconds=DEFAULT_TTL_SECONDS)

    assert len(entries1) == 1
    assert len(entries2) == 1
    assert entries1[0]["tool"] == "tool1"
    assert entries2[0]["tool"] == "tool2"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])