# IntegraChat / backend /tests /test_conversation_memory.py
# nothingworry's picture
# feat: Add short-term conversation memory with TTL for MCP tools
# b13e570
# raw
# history blame
# 16 kB
# =============================================================
# File: backend/tests/test_conversation_memory.py
# =============================================================
"""
Comprehensive tests for short-term conversation memory with expiration.
Tests:
1. Memory storage and retrieval
2. Memory injection into tool payloads
3. Session isolation (different session_ids don't share memory)
4. Memory expiration (TTL)
5. Memory bounded size (only last N items)
6. Session clearing (end_session flag)
7. Memory is NOT keyed by tenant_id (same session_id across tenants shares memory)
"""
import sys
from pathlib import Path
import pytest
import time
from unittest.mock import AsyncMock, MagicMock, patch
import asyncio
# Add backend directory to Python path
backend_dir = Path(__file__).parent.parent
sys.path.insert(0, str(backend_dir))
from mcp_server.common import memory
from mcp_server.common.utils import execute_tool, ToolHandler
from mcp_server.common.tenant import TenantContext
# =============================================================
# FIXTURES
# =============================================================
@pytest.fixture(autouse=True)
def clear_memory():
    """Empty the shared memory store around every test.

    Declared ``autouse`` so each test starts with ``memory._MEMORY`` cleared
    and leaves it cleared again once the test body has finished.
    """
    memory._MEMORY.clear()  # setup: nothing carried over from a prior test
    yield
    memory._MEMORY.clear()  # teardown: drop whatever this test stored
@pytest.fixture
def mock_tool_handler():
    """Provide an async tool handler that records the payloads it receives.

    Returns:
        A coroutine function ``handler(context, payload)`` that always
        returns the same canned result dict. The list of payloads it has
        been called with is exposed as ``handler.captured``, in call order.
    """
    seen_payloads = []

    async def handler(context: TenantContext, payload: dict) -> dict:
        # Record the payload object itself (not a copy) before answering.
        seen_payloads.append(payload)
        return {"result": "success", "tool_output": "test_data"}

    # Attach the log to the function so tests can inspect it afterwards.
    handler.captured = seen_payloads
    return handler
# =============================================================
# UNIT TESTS: Memory Module
# =============================================================
def test_extract_session_id():
    """Check which payload keys ``extract_session_id`` recognises."""
    # Each supported key spelling yields its own value.
    recognised = (
        ("session_id", "s1"),
        ("sessionId", "s2"),
        ("conversation_id", "s3"),
        ("conversationId", "s4"),
    )
    for key, expected in recognised:
        assert memory.extract_session_id({key: expected}) == expected

    # With both spellings present, ``session_id`` takes precedence.
    both_keys = {"session_id": "s1", "sessionId": "s2"}
    assert memory.extract_session_id(both_keys) == "s1"

    # Payloads without any session key give None.
    for payload in ({"tenant_id": "t1"}, {}):
        assert memory.extract_session_id(payload) is None

    # Empty or blank values are treated the same as a missing key.
    for blank in ("", " "):
        assert memory.extract_session_id({"session_id": blank}) is None
def test_add_and_get_entry():
    """Entries added to a session come back from ``get_recent`` in insertion order."""
    session_id = "test-session-1"

    # Store three tool outputs, one after another.
    additions = (("tool1", "data1"), ("tool2", "data2"), ("tool3", "data3"))
    for tool_name, data in additions:
        memory.add_entry(session_id, tool_name, {"output": data}, max_items=10, ttl_seconds=900)

    entries = memory.get_recent(session_id, ttl_seconds=900)

    assert len(entries) == 3
    assert [entry["tool"] for entry in entries] == ["tool1", "tool2", "tool3"]

    # The stored record carries the original output plus a timestamp field.
    oldest = entries[0]
    assert oldest["output"] == {"output": "data1"}
    assert "timestamp" in oldest
def test_memory_bounded_size():
    """Only the newest ``max_items`` entries are retained for a session."""
    session_id = "test-session-2"
    max_items = 3

    # Write five entries into a store capped at three.
    for n in range(5):
        memory.add_entry(session_id, f"tool{n}", {"data": n}, max_items=max_items, ttl_seconds=900)

    entries = memory.get_recent(session_id, ttl_seconds=900)

    # tool0 and tool1 were pushed out; the last three remain, oldest first.
    assert len(entries) == 3
    assert [entry["tool"] for entry in entries] == ["tool2", "tool3", "tool4"]
def test_memory_expiration():
    """An entry older than the TTL is no longer returned by ``get_recent``.

    Note: this test really sleeps for 1.1 seconds to outlive a 1-second TTL.
    """
    session_id = "test-session-3"
    short_ttl = 1  # seconds

    memory.add_entry(session_id, "tool1", {"data": "old"}, max_items=10, ttl_seconds=short_ttl)

    # Immediately after writing, the entry is visible.
    assert len(memory.get_recent(session_id, ttl_seconds=short_ttl)) == 1

    # Wait just past the TTL ...
    time.sleep(1.1)

    # ... after which nothing is returned for the session.
    assert len(memory.get_recent(session_id, ttl_seconds=short_ttl)) == 0
def test_session_isolation():
    """Entries written under one session_id are not visible under another."""
    # (session id, tool name, payload marker) for two independent sessions.
    fixtures = (
        ("session-1", "tool1", "s1"),
        ("session-2", "tool2", "s2"),
    )
    for sid, tool_name, marker in fixtures:
        memory.add_entry(sid, tool_name, {"data": marker}, max_items=10, ttl_seconds=900)

    # Read both sessions back before asserting anything.
    recalled = {sid: memory.get_recent(sid, ttl_seconds=900) for sid, _, _ in fixtures}

    # Each session holds exactly one entry ...
    for sid, _, _ in fixtures:
        assert len(recalled[sid]) == 1

    # ... and it is the one written under that session.
    for sid, tool_name, _ in fixtures:
        assert recalled[sid][0]["tool"] == tool_name
def test_clear_session():
    """``clear_session`` leaves no retrievable entries for the session."""
    session_id = "test-session-4"

    for tool_name, data in (("tool1", "d1"), ("tool2", "d2")):
        memory.add_entry(session_id, tool_name, {"data": data}, max_items=10, ttl_seconds=900)

    # Sanity check: both entries are there before clearing.
    before = memory.get_recent(session_id, ttl_seconds=900)
    assert len(before) == 2

    memory.clear_session(session_id)

    after = memory.get_recent(session_id, ttl_seconds=900)
    assert len(after) == 0
def test_memory_not_keyed_by_tenant():
    """Memory is looked up by session_id alone; tenant_id plays no part.

    Two entries are written under one session_id, each tagged (inside its
    output) with a different tenant. Because the memory API takes no tenant
    argument, both entries come back together. Callers that need tenant
    separation therefore have to keep session_ids unique per tenant; the
    memory module itself does not enforce it.
    """
    session_id = "shared-session"
    tenant1 = "tenant-a"
    tenant2 = "tenant-b"

    # First write, as if issued on behalf of tenant1.
    first_output = {"tenant": tenant1, "data": "from-tenant1"}
    memory.add_entry(session_id, "tool1", first_output, max_items=10, ttl_seconds=900)

    # Second write under the same session_id, as if issued by tenant2.
    second_output = {"tenant": tenant2, "data": "from-tenant2"}
    memory.add_entry(session_id, "tool2", second_output, max_items=10, ttl_seconds=900)

    # A single read of the session returns both tenants' entries, in order.
    entries = memory.get_recent(session_id, ttl_seconds=900)
    assert len(entries) == 2
    assert entries[0]["output"]["tenant"] == tenant1
    assert entries[1]["output"]["tenant"] == tenant2
def test_get_recent_with_limit():
    """``get_recent`` returns everything for ``limit=None`` and the newest N otherwise."""
    session_id = "test-session-5"

    # Populate the session with tool0 .. tool4.
    for n in range(5):
        memory.add_entry(session_id, f"tool{n}", {"data": n}, max_items=10, ttl_seconds=900)

    # No limit: all five entries.
    unlimited = memory.get_recent(session_id, limit=None, ttl_seconds=900)
    assert len(unlimited) == 5

    # limit=2: the two most recent entries, still oldest-first.
    newest_two = memory.get_recent(session_id, limit=2, ttl_seconds=900)
    assert len(newest_two) == 2
    assert [entry["tool"] for entry in newest_two] == ["tool3", "tool4"]
# =============================================================
# INTEGRATION TESTS: execute_tool with Memory
# =============================================================
@pytest.mark.asyncio
async def test_execute_tool_stores_memory(mock_tool_handler):
    """A successful ``execute_tool`` call records the handler output under the session."""
    session_id = "test-session-6"
    payload = {
        "tenant_id": "test-tenant",
        "session_id": session_id,
        "query": "test query",
    }

    result = await execute_tool("test.tool", payload, mock_tool_handler)

    # The call itself reports success.
    assert result["status"] == "ok"

    # Exactly one entry was stored: tool name plus the handler's return value.
    entries = memory.get_recent(session_id, ttl_seconds=900)
    assert len(entries) == 1
    stored = entries[0]
    assert stored["tool"] == "test.tool"
    assert stored["output"] == {"result": "success", "tool_output": "test_data"}
@pytest.mark.asyncio
async def test_execute_tool_injects_memory(mock_tool_handler):
    """The second call in a session sees the first call's output under ``memory``."""
    session_id = "test-session-7"

    # Two consecutive calls in the same session; the first starts with no memory.
    calls = (("tool1", "first query"), ("tool2", "second query"))
    for tool_name, query in calls:
        payload = {
            "tenant_id": "test-tenant",
            "session_id": session_id,
            "query": query,
        }
        await execute_tool(tool_name, payload, mock_tool_handler)

    captured = mock_tool_handler.captured
    assert len(captured) == 2

    # The payload handed to the second call carries one memory entry: tool1's.
    second_payload = captured[1]
    assert "memory" in second_payload
    injected = second_payload["memory"]
    assert len(injected) == 1
    assert injected[0]["tool"] == "tool1"
@pytest.mark.asyncio
async def test_execute_tool_clears_memory_on_end_session(mock_tool_handler):
    """A call with ``end_session=True`` leaves the session with no memory."""
    session_id = "test-session-8"
    # Fields shared by every payload in this test.
    common = {"tenant_id": "test-tenant", "session_id": session_id}

    def stored_count():
        # Number of entries currently retrievable for the session.
        return len(memory.get_recent(session_id, ttl_seconds=900))

    # Call 1: an ordinary call that populates memory.
    await execute_tool("tool1", {**common, "query": "first query"}, mock_tool_handler)
    assert stored_count() == 1

    # Call 2: same session, flagged as the end of the session.
    closing_payload = {**common, "end_session": True, "query": "closing"}
    await execute_tool("tool2", closing_payload, mock_tool_handler)
    assert stored_count() == 0

    # Call 3: a fresh call afterwards receives an empty memory list.
    await execute_tool("tool3", {**common, "query": "new query"}, mock_tool_handler)
    third_payload = mock_tool_handler.captured[2]
    assert "memory" in third_payload
    assert len(third_payload["memory"]) == 0
@pytest.mark.asyncio
async def test_execute_tool_no_memory_without_session_id(mock_tool_handler):
    """Without a session_id, the handler's payload gets no ``memory`` field.

    Storage cannot be inspected here because there is no session key to look
    up, so the test only checks that nothing was injected.
    """
    # Deliberately omits session_id.
    payload = {"tenant_id": "test-tenant", "query": "test query"}

    await execute_tool("test.tool", payload, mock_tool_handler)

    received = mock_tool_handler.captured[0]
    assert "memory" not in received
@pytest.mark.asyncio
async def test_execute_tool_multi_step_workflow(mock_tool_handler):
    """In a three-step workflow, each step receives the outputs of the steps before it."""
    session_id = "test-session-9"

    # RAG search -> web search -> LLM synthesis, all in one session.
    workflow = (
        ("rag.search", "search for X"),
        ("web.search", "search web for Y"),
        ("llm.synthesize", "synthesize results"),
    )
    for tool_name, query in workflow:
        step_payload = {
            "tenant_id": "test-tenant",
            "session_id": session_id,
            "query": query,
        }
        await execute_tool(tool_name, step_payload, mock_tool_handler)

    captured = mock_tool_handler.captured
    assert len(captured) == 3

    # Step 1: memory is either absent or empty.
    assert len(captured[0].get("memory", [])) == 0

    # Step 2: sees the RAG result only.
    second_memory = captured[1].get("memory", [])
    assert len(second_memory) == 1
    assert second_memory[0]["tool"] == "rag.search"

    # Step 3: sees the RAG result followed by the web result.
    third_memory = captured[2].get("memory", [])
    assert len(third_memory) == 2
    assert third_memory[0]["tool"] == "rag.search"
    assert third_memory[1]["tool"] == "web.search"
@pytest.mark.asyncio
async def test_execute_tool_end_session_variants(mock_tool_handler):
    """Both ``end_session`` and ``endSession`` flags clear the session's memory."""
    session_id = "test-session-10"

    def stored_count():
        # Number of entries currently retrievable for the session.
        return len(memory.get_recent(session_id, ttl_seconds=900))

    # This one payload object is sent twice: once before each flag variant.
    payload1 = {
        "tenant_id": "test-tenant",
        "session_id": session_id,
        "query": "first",
    }

    # Seed memory, then end the session with the snake_case flag.
    await execute_tool("tool1", payload1, mock_tool_handler)
    assert stored_count() == 1

    snake_case_end = {
        "tenant_id": "test-tenant",
        "session_id": session_id,
        "end_session": True,
        "query": "end",
    }
    await execute_tool("tool2", snake_case_end, mock_tool_handler)
    assert stored_count() == 0

    # Seed memory again, then end the session with the camelCase flag.
    await execute_tool("tool3", payload1, mock_tool_handler)
    assert stored_count() == 1

    camel_case_end = {
        "tenant_id": "test-tenant",
        "session_id": session_id,
        "endSession": True,
        "query": "end",
    }
    await execute_tool("tool4", camel_case_end, mock_tool_handler)
    assert stored_count() == 0
# =============================================================
# EDGE CASES
# =============================================================
def test_empty_session_id():
    """Adding under an empty session_id neither raises nor stores anything retrievable."""
    empty_id = ""
    memory.add_entry(empty_id, "tool1", {"data": "test"}, max_items=10, ttl_seconds=900)

    # Reading the empty session back yields nothing.
    recalled = memory.get_recent(empty_id, ttl_seconds=900)
    assert len(recalled) == 0
def test_none_session_id():
    """``get_recent(None)`` returns an empty list instead of raising."""
    # Not expected from real callers; this guards against a crash.
    assert memory.get_recent(None, ttl_seconds=900) == []  # type: ignore
@pytest.mark.asyncio
async def test_concurrent_sessions(mock_tool_handler):
    """Two sessions run through ``asyncio.gather`` keep separate memories."""
    session1 = "session-concurrent-1"
    session2 = "session-concurrent-2"

    first_payload = {
        "tenant_id": "tenant1",
        "session_id": session1,
        "query": "q1",
    }
    second_payload = {
        "tenant_id": "tenant2",
        "session_id": session2,
        "query": "q2",
    }

    # Schedule both tool calls together and wait for both to finish.
    await asyncio.gather(
        execute_tool("tool1", first_payload, mock_tool_handler),
        execute_tool("tool2", second_payload, mock_tool_handler),
    )

    entries1 = memory.get_recent(session1, ttl_seconds=900)
    entries2 = memory.get_recent(session2, ttl_seconds=900)

    # Each session ends up with exactly its own single entry.
    assert len(entries1) == 1
    assert len(entries2) == 1
    assert entries1[0]["tool"] == "tool1"
    assert entries2[0]["tool"] == "tool2"
if __name__ == "__main__":
    # Allow running this file directly. pytest.main() returns the run's exit
    # code; hand it to sys.exit so a failing run yields a non-zero process
    # status instead of always exiting 0.
    sys.exit(pytest.main([__file__, "-v"]))