# agentbench / tests/test_memory.py
# Author: Nomearod
# Commit 9874438: feat: add SQLite conversation sessions with session_id
"""Tests for conversation memory store."""
from __future__ import annotations
import pytest
from agent_bench.memory.store import ConversationStore
@pytest.fixture
def store(tmp_path) -> ConversationStore:
    """Provide a ConversationStore backed by a database file under ``tmp_path``.

    ``tmp_path`` is pytest's per-test temporary directory, so every test
    that requests this fixture starts from its own database file.
    """
    db_file = tmp_path / "test.db"
    return ConversationStore(db_path=str(db_file))
class TestConversationStore:
    """Unit tests for ConversationStore's append / read / list / delete API."""

    def test_append_and_retrieve(self, store: ConversationStore):
        """Write 3 messages, read back in chronological order."""
        store.append("s1", "user", "Hello")
        store.append("s1", "assistant", "Hi there")
        store.append("s1", "user", "How are you?")
        history = store.get_history("s1")
        assert len(history) == 3
        assert history[0] == {"role": "user", "content": "Hello"}
        assert history[1] == {"role": "assistant", "content": "Hi there"}
        assert history[2] == {"role": "user", "content": "How are you?"}

    def test_max_turns(self, store: ConversationStore):
        """max_turns=2 caps history at 4 messages (2 user + 2 assistant)."""
        for i in range(10):
            store.append("s1", "user", f"Q{i}")
            store.append("s1", "assistant", f"A{i}")
        history = store.get_history("s1", max_turns=2)
        assert len(history) == 4  # 2 turns * 2 messages each
        # A turn is a user/assistant pair, so the capped window must hold
        # exactly two of each role, still alternating.
        assert [m["role"] for m in history] == ["user", "assistant"] * 2

    def test_separate_sessions(self, store: ConversationStore):
        """Two session_ids don't cross-contaminate."""
        store.append("s1", "user", "Session 1 message")
        store.append("s2", "user", "Session 2 message")
        h1 = store.get_history("s1")
        h2 = store.get_history("s2")
        assert len(h1) == 1
        assert len(h2) == 1
        assert h1[0]["content"] == "Session 1 message"
        assert h2[0]["content"] == "Session 2 message"

    def test_empty_session(self, store: ConversationStore):
        """Non-existent session returns empty list."""
        assert store.get_history("nonexistent") == []

    def test_list_sessions(self, store: ConversationStore):
        """List all session IDs, each exactly once."""
        store.append("alpha", "user", "msg")
        store.append("beta", "user", "msg")
        store.append("alpha", "user", "msg2")
        sessions = store.list_sessions()
        # "alpha" has two messages but must appear once; comparing via
        # set() alone would silently hide a duplicate entry.
        assert sorted(sessions) == ["alpha", "beta"]

    def test_delete_session(self, store: ConversationStore):
        """Delete removes all messages for a session."""
        store.append("s1", "user", "keep")
        store.append("s2", "user", "delete me")
        store.delete_session("s2")
        assert store.get_history("s1") == [{"role": "user", "content": "keep"}]
        assert store.get_history("s2") == []

    def test_metadata_stored(self, store: ConversationStore):
        """Metadata is accepted without error (not exposed in get_history)."""
        store.append("s1", "user", "test", metadata={"sources": ["doc.md"]})
        history = store.get_history("s1")
        # Exact equality proves no metadata key leaks into the history entry.
        assert history == [{"role": "user", "content": "test"}]
def _make_session_app(tmp_path):
    """Build a FastAPI test app that has a ConversationStore attached.

    The app is wired with:
    - an Orchestrator over a MockProvider, with FakeSearchTool and
      CalculatorTool registered and max_iterations=3;
    - a HybridStore of dimension 384;
    - an AppConfig using the "mock" provider and enabled memory whose
      database file is ``tmp_path / "test_sessions.db"``.

    Returns:
        ``(app, conversation_store)`` so tests can inspect persisted
        history directly.
    """
    import time

    from fastapi import FastAPI

    from agent_bench.agents.orchestrator import Orchestrator
    from agent_bench.core.config import AppConfig, MemoryConfig, ProviderConfig
    from agent_bench.core.provider import MockProvider
    from agent_bench.memory.store import ConversationStore
    from agent_bench.rag.store import HybridStore
    from agent_bench.serving.middleware import MetricsCollector, RequestMiddleware
    from agent_bench.tools.calculator import CalculatorTool
    from agent_bench.tools.registry import ToolRegistry
    from tests.test_agent import FakeSearchTool

    app = FastAPI(title="agent-bench-session-test")

    tools = ToolRegistry()
    for tool in (FakeSearchTool(), CalculatorTool()):
        tools.register(tool)

    agent = Orchestrator(provider=MockProvider(), registry=tools, max_iterations=3)

    memory_cfg = MemoryConfig(
        enabled=True,
        db_path=str(tmp_path / "test_sessions.db"),
        max_turns=10,
    )
    config = AppConfig(provider=ProviderConfig(default="mock"), memory=memory_cfg)
    conv_store = ConversationStore(db_path=config.memory.db_path)

    # Everything the routes read from app.state, assigned in one pass.
    state_values = {
        "orchestrator": agent,
        "store": HybridStore(dimension=384),
        "conversation_store": conv_store,
        "config": config,
        "system_prompt": "You are a test assistant.",
        "start_time": time.time(),
        "metrics": MetricsCollector(),
    }
    for attr, value in state_values.items():
        setattr(app.state, attr, value)

    app.add_middleware(RequestMiddleware)

    # NOTE(review): the router is imported only after app.state is populated;
    # presumably this defers any import-time side effects — confirm.
    from agent_bench.serving.routes import router

    app.include_router(router)
    return app, conv_store
class TestSessionIntegration:
    """End-to-end /ask tests against an app that has a ConversationStore."""

    @pytest.mark.asyncio
    async def test_stateless_without_session_id(self, tmp_path):
        """session_id=None suppresses DB interaction even when store exists."""
        from httpx import ASGITransport, AsyncClient

        app, conv_store = _make_session_app(tmp_path)
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as client:
            response = await client.post("/ask", json={"question": "test"})
            assert response.status_code == 200
            assert "answer" in response.json()
            # No session_id → nothing stored
            assert conv_store.list_sessions() == []

    @pytest.mark.asyncio
    async def test_session_stores_and_loads_history(self, tmp_path):
        """Two requests with the same session_id accumulate both turns in the store.

        That the stored history is actually handed to the orchestrator is
        checked by test_history_passed_to_orchestrator.
        """
        from httpx import ASGITransport, AsyncClient

        app, conv_store = _make_session_app(tmp_path)
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as client:
            # First request with session_id
            r1 = await client.post(
                "/ask",
                json={"question": "What is FastAPI?", "session_id": "sess-1"},
            )
            assert r1.status_code == 200
            # Verify Q+A was stored
            history = conv_store.get_history("sess-1")
            assert len(history) == 2
            assert history[0]["role"] == "user"
            assert history[0]["content"] == "What is FastAPI?"
            assert history[1]["role"] == "assistant"

            # Second request in same session
            r2 = await client.post(
                "/ask",
                json={
                    "question": "Tell me more about it",
                    "session_id": "sess-1",
                },
            )
            assert r2.status_code == 200
            # Now 4 messages stored (2 turns): the first turn is kept and
            # the follow-up Q+A is appended after it, in order.
            history = conv_store.get_history("sess-1")
            assert len(history) == 4
            assert history[0]["content"] == "What is FastAPI?"
            assert history[2]["role"] == "user"
            assert history[2]["content"] == "Tell me more about it"
            assert history[3]["role"] == "assistant"

    @pytest.mark.asyncio
    async def test_history_passed_to_orchestrator(self, tmp_path):
        """Verify the orchestrator actually receives history on follow-up."""
        from httpx import ASGITransport, AsyncClient

        from agent_bench.agents.orchestrator import AgentResponse
        from agent_bench.core.types import TokenUsage

        app, conv_store = _make_session_app(tmp_path)
        # Seed a prior conversation turn in the store
        conv_store.append("sess-2", "user", "What is FastAPI?")
        conv_store.append("sess-2", "assistant", "FastAPI is a web framework.")

        # Replace orchestrator.run with a spy that records its keyword
        # arguments and returns a canned response.
        captured_kwargs: dict = {}
        fake_response = AgentResponse(
            answer="Follow-up answer.",
            sources=[],
            iterations=1,
            tools_used=[],
            usage=TokenUsage(
                input_tokens=100,
                output_tokens=20,
                estimated_cost_usd=0.0001,
            ),
            provider="mock",
            model="mock-1",
            latency_ms=1.0,
        )

        async def spy_run(**kwargs):
            captured_kwargs.update(kwargs)
            return fake_response

        app.state.orchestrator.run = spy_run

        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as client:
            r = await client.post(
                "/ask",
                json={
                    "question": "Tell me more",
                    "session_id": "sess-2",
                },
            )
            assert r.status_code == 200
            # The orchestrator must have received the prior history
            assert "history" in captured_kwargs
            assert captured_kwargs["history"] is not None
            assert len(captured_kwargs["history"]) == 2
            assert captured_kwargs["history"][0]["content"] == "What is FastAPI?"
            assert captured_kwargs["history"][1]["content"] == "FastAPI is a web framework."