Spaces:
Running
Running
File size: 8,848 Bytes
9874438 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 | """Tests for conversation memory store."""
from __future__ import annotations
import pytest
from agent_bench.memory.store import ConversationStore
@pytest.fixture
def store(tmp_path) -> ConversationStore:
    """Return a ConversationStore whose database file lives in pytest's tmp_path.

    Each test gets its own ``test.db`` file, so no state leaks between tests.
    """
    db_file = tmp_path / "test.db"
    return ConversationStore(db_path=str(db_file))
class TestConversationStore:
    """Unit tests for ConversationStore append/read/list/delete behavior."""

    def test_append_and_retrieve(self, store: ConversationStore):
        """Write 3 messages, read back in chronological order."""
        store.append("s1", "user", "Hello")
        store.append("s1", "assistant", "Hi there")
        store.append("s1", "user", "How are you?")
        history = store.get_history("s1")
        assert len(history) == 3
        assert history[0] == {"role": "user", "content": "Hello"}
        assert history[1] == {"role": "assistant", "content": "Hi there"}
        assert history[2] == {"role": "user", "content": "How are you?"}

    def test_max_turns(self, store: ConversationStore):
        """max_turns=2 returns at most 4 messages (2 user + 2 assistant)."""
        for i in range(10):
            store.append("s1", "user", f"Q{i}")
            store.append("s1", "assistant", f"A{i}")
        history = store.get_history("s1", max_turns=2)
        assert len(history) == 4  # 2 turns * 2 messages each
        # The truncated window must still be chronological and made of whole
        # turns: a reversed or mid-turn-split window would break this pattern.
        assert [message["role"] for message in history] == [
            "user",
            "assistant",
            "user",
            "assistant",
        ]

    def test_separate_sessions(self, store: ConversationStore):
        """Two session_ids don't cross-contaminate."""
        store.append("s1", "user", "Session 1 message")
        store.append("s2", "user", "Session 2 message")
        h1 = store.get_history("s1")
        h2 = store.get_history("s2")
        assert len(h1) == 1
        assert len(h2) == 1
        assert h1[0]["content"] == "Session 1 message"
        assert h2[0]["content"] == "Session 2 message"

    def test_empty_session(self, store: ConversationStore):
        """Non-existent session returns empty list."""
        assert store.get_history("nonexistent") == []

    def test_list_sessions(self, store: ConversationStore):
        """List all session IDs."""
        store.append("alpha", "user", "msg")
        store.append("beta", "user", "msg")
        store.append("alpha", "user", "msg2")
        sessions = store.list_sessions()
        assert set(sessions) == {"alpha", "beta"}

    def test_delete_session(self, store: ConversationStore):
        """Delete removes all messages for a session."""
        store.append("s1", "user", "keep")
        store.append("s2", "user", "delete me")
        store.delete_session("s2")
        assert store.get_history("s1") == [{"role": "user", "content": "keep"}]
        assert store.get_history("s2") == []

    def test_metadata_stored(self, store: ConversationStore):
        """Metadata is accepted without error (not exposed in get_history)."""
        store.append("s1", "user", "test", metadata={"sources": ["doc.md"]})
        history = store.get_history("s1")
        # Exact equality checks that the message round-trips AND that no
        # metadata key leaks into the get_history() payload.
        assert history == [{"role": "user", "content": "test"}]
def _make_session_app(tmp_path):
    """Build a FastAPI test app with a ConversationStore attached to its state.

    The app is wired with:

    * an Orchestrator over a MockProvider (max 3 iterations),
    * a ToolRegistry holding FakeSearchTool and CalculatorTool,
    * an AppConfig with memory enabled (max_turns=10),
    * RequestMiddleware and the project's router.

    Args:
        tmp_path: pytest temporary directory; the session DB file
            ``test_sessions.db`` is placed inside it.

    Returns:
        ``(app, conversation_store)`` so tests can drive the HTTP API and
        inspect what the routes persisted.
    """
    import time

    from fastapi import FastAPI

    from agent_bench.agents.orchestrator import Orchestrator
    from agent_bench.core.config import AppConfig, MemoryConfig, ProviderConfig
    from agent_bench.core.provider import MockProvider
    from agent_bench.memory.store import ConversationStore
    from agent_bench.rag.store import HybridStore
    from agent_bench.serving.middleware import MetricsCollector, RequestMiddleware
    from agent_bench.tools.calculator import CalculatorTool
    from agent_bench.tools.registry import ToolRegistry
    from tests.test_agent import FakeSearchTool

    app = FastAPI(title="agent-bench-session-test")

    tools = ToolRegistry()
    for tool in (FakeSearchTool(), CalculatorTool()):
        tools.register(tool)

    agent = Orchestrator(provider=MockProvider(), registry=tools, max_iterations=3)

    app_config = AppConfig(
        provider=ProviderConfig(default="mock"),
        memory=MemoryConfig(
            enabled=True,
            db_path=str(tmp_path / "test_sessions.db"),
            max_turns=10,
        ),
    )
    conv_store = ConversationStore(db_path=app_config.memory.db_path)

    state_values = {
        "orchestrator": agent,
        "store": HybridStore(dimension=384),
        "conversation_store": conv_store,
        "config": app_config,
        "system_prompt": "You are a test assistant.",
        "start_time": time.time(),
        "metrics": MetricsCollector(),
    }
    for attr_name, attr_value in state_values.items():
        setattr(app.state, attr_name, attr_value)

    app.add_middleware(RequestMiddleware)

    # NOTE(review): the router is imported only after state/middleware setup;
    # presumably deliberate (import-time side effects?) — confirm before hoisting.
    from agent_bench.serving.routes import router

    app.include_router(router)
    return app, conv_store
class TestSessionIntegration:
    """End-to-end tests of session memory through the ``/ask`` endpoint."""

    @pytest.mark.asyncio
    async def test_stateless_without_session_id(self, tmp_path):
        """session_id=None suppresses DB interaction even when store exists."""
        from httpx import ASGITransport, AsyncClient

        app, conv_store = _make_session_app(tmp_path)
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as client:
            response = await client.post(
                "/ask", json={"question": "test"}
            )
            assert response.status_code == 200
            assert "answer" in response.json()
            # No session_id → nothing stored
            assert conv_store.list_sessions() == []

    @pytest.mark.asyncio
    async def test_session_stores_and_loads_history(self, tmp_path):
        """Two requests with same session_id: second uses stored history."""
        from httpx import ASGITransport, AsyncClient

        app, conv_store = _make_session_app(tmp_path)
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as client:
            # First request with session_id
            r1 = await client.post(
                "/ask",
                json={"question": "What is FastAPI?", "session_id": "sess-1"},
            )
            assert r1.status_code == 200
            # Verify Q+A was stored
            history = conv_store.get_history("sess-1")
            assert len(history) == 2
            assert history[0]["role"] == "user"
            assert history[0]["content"] == "What is FastAPI?"
            assert history[1]["role"] == "assistant"
            # Second request in same session
            r2 = await client.post(
                "/ask",
                json={
                    "question": "Tell me more about it",
                    "session_id": "sess-1",
                },
            )
            assert r2.status_code == 200
            # Now 4 messages stored (2 turns), and the new turn must be
            # appended AFTER the first one, in user → assistant order.
            history = conv_store.get_history("sess-1")
            assert len(history) == 4
            assert history[0]["content"] == "What is FastAPI?"
            assert history[2] == {
                "role": "user",
                "content": "Tell me more about it",
            }
            assert history[3]["role"] == "assistant"

    @pytest.mark.asyncio
    async def test_history_passed_to_orchestrator(self, tmp_path):
        """Verify the orchestrator actually receives history on follow-up."""
        from httpx import ASGITransport, AsyncClient

        from agent_bench.agents.orchestrator import AgentResponse
        from agent_bench.core.types import TokenUsage

        app, conv_store = _make_session_app(tmp_path)
        # Seed a prior conversation turn in the store
        conv_store.append("sess-2", "user", "What is FastAPI?")
        conv_store.append("sess-2", "assistant", "FastAPI is a web framework.")
        # Patch orchestrator.run to capture the history argument
        captured_kwargs: dict = {}
        fake_response = AgentResponse(
            answer="Follow-up answer.",
            sources=[],
            iterations=1,
            tools_used=[],
            usage=TokenUsage(
                input_tokens=100,
                output_tokens=20,
                estimated_cost_usd=0.0001,
            ),
            provider="mock",
            model="mock-1",
            latency_ms=1.0,
        )

        # The spy accepts keyword arguments only, so this test also pins the
        # route calling orchestrator.run with keywords (including history=).
        async def spy_run(**kwargs):
            captured_kwargs.update(kwargs)
            return fake_response

        app.state.orchestrator.run = spy_run
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as client:
            r = await client.post(
                "/ask",
                json={
                    "question": "Tell me more",
                    "session_id": "sess-2",
                },
            )
            assert r.status_code == 200
            # The orchestrator must have received the prior history
            assert "history" in captured_kwargs
            assert captured_kwargs["history"] is not None
            assert len(captured_kwargs["history"]) == 2
            assert captured_kwargs["history"][0]["content"] == "What is FastAPI?"
            assert captured_kwargs["history"][1]["content"] == "FastAPI is a web framework."
|