File size: 8,848 Bytes
9874438
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
"""Tests for conversation memory store."""

from __future__ import annotations

import pytest

from agent_bench.memory.store import ConversationStore


@pytest.fixture
def store(tmp_path) -> ConversationStore:
    """Provide a ConversationStore whose database file sits under ``tmp_path``."""
    database_file = tmp_path / "test.db"
    return ConversationStore(db_path=str(database_file))


class TestConversationStore:
    """Unit tests exercising ConversationStore through the ``store`` fixture."""

    def test_append_and_retrieve(self, store: ConversationStore):
        """Three appended messages come back in the order they were written."""
        expected = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there"},
            {"role": "user", "content": "How are you?"},
        ]
        for message in expected:
            store.append("s1", message["role"], message["content"])

        retrieved = store.get_history("s1")

        assert len(retrieved) == 3
        # Compare position by position so a mismatch points at one message.
        for position in range(3):
            assert retrieved[position] == expected[position]

    def test_max_turns(self, store: ConversationStore):
        """With ten stored turns, max_turns=2 yields exactly four messages."""
        for turn in range(10):
            store.append("s1", "user", f"Q{turn}")
            store.append("s1", "assistant", f"A{turn}")

        limited = store.get_history("s1", max_turns=2)

        # One turn is a user message plus an assistant message.
        assert len(limited) == 4

    def test_separate_sessions(self, store: ConversationStore):
        """Messages written under one session_id do not appear in another."""
        store.append("s1", "user", "Session 1 message")
        store.append("s2", "user", "Session 2 message")

        first_history = store.get_history("s1")
        second_history = store.get_history("s2")

        assert len(first_history) == 1
        assert len(second_history) == 1
        assert first_history[0]["content"] == "Session 1 message"
        assert second_history[0]["content"] == "Session 2 message"

    def test_empty_session(self, store: ConversationStore):
        """A session that was never written to yields an empty list."""
        unknown_history = store.get_history("nonexistent")
        assert unknown_history == []

    def test_list_sessions(self, store: ConversationStore):
        """list_sessions reports every session id that has a message."""
        writes = (("alpha", "msg"), ("beta", "msg"), ("alpha", "msg2"))
        for session_id, text in writes:
            store.append(session_id, "user", text)

        listed = store.list_sessions()

        # Compared as a set: ordering and duplicates are not checked here.
        assert set(listed) == {"alpha", "beta"}

    def test_delete_session(self, store: ConversationStore):
        """Deleting one session empties it and leaves the other untouched."""
        store.append("s1", "user", "keep")
        store.append("s2", "user", "delete me")

        store.delete_session("s2")

        surviving = store.get_history("s1")
        removed = store.get_history("s2")
        assert surviving == [{"role": "user", "content": "keep"}]
        assert removed == []

    def test_metadata_stored(self, store: ConversationStore):
        """append accepts a metadata mapping; the message is still retrievable."""
        extra = {"sources": ["doc.md"]}
        store.append("s1", "user", "test", metadata=extra)

        stored = store.get_history("s1")
        assert len(stored) == 1


def _make_session_app(tmp_path):
    """Assemble a FastAPI test app that has a conversation store on its state.

    The app gets a mock-provider orchestrator with two registered tools, a
    memory-enabled config whose database file lives under ``tmp_path``, the
    request middleware, and the serving router.

    Returns:
        A ``(app, conversation_store)`` tuple so callers can inspect the
        store directly.
    """
    import time as clock

    from fastapi import FastAPI

    from agent_bench.agents.orchestrator import Orchestrator
    from agent_bench.core.config import AppConfig, MemoryConfig, ProviderConfig
    from agent_bench.core.provider import MockProvider
    from agent_bench.memory.store import ConversationStore
    from agent_bench.rag.store import HybridStore
    from agent_bench.serving.middleware import MetricsCollector, RequestMiddleware
    from agent_bench.tools.calculator import CalculatorTool
    from agent_bench.tools.registry import ToolRegistry
    from tests.test_agent import FakeSearchTool

    app = FastAPI(title="agent-bench-session-test")

    # Tools are registered in the same order as before: search, then calculator.
    tool_registry = ToolRegistry()
    for tool_cls in (FakeSearchTool, CalculatorTool):
        tool_registry.register(tool_cls())

    orchestrator = Orchestrator(
        provider=MockProvider(),
        registry=tool_registry,
        max_iterations=3,
    )

    memory_settings = MemoryConfig(
        enabled=True,
        db_path=str(tmp_path / "test_sessions.db"),
        max_turns=10,
    )
    config = AppConfig(
        provider=ProviderConfig(default="mock"),
        memory=memory_settings,
    )
    conversation_store = ConversationStore(db_path=config.memory.db_path)

    # Everything attached to app.state, in the original assignment order.
    state_attrs = {
        "orchestrator": orchestrator,
        "store": HybridStore(dimension=384),
        "conversation_store": conversation_store,
        "config": config,
        "system_prompt": "You are a test assistant.",
        "start_time": clock.time(),
        "metrics": MetricsCollector(),
    }
    for attr_name, attr_value in state_attrs.items():
        setattr(app.state, attr_name, attr_value)

    app.add_middleware(RequestMiddleware)

    # The router import stays after the app is configured, as in the original.
    from agent_bench.serving.routes import router

    app.include_router(router)
    return app, conversation_store


class TestSessionIntegration:
    """Tests that POST to ``/ask`` on the app built by ``_make_session_app``."""

    @staticmethod
    def _client_for(app):
        """Build an httpx AsyncClient routed through ASGITransport to ``app``."""
        from httpx import ASGITransport, AsyncClient

        return AsyncClient(transport=ASGITransport(app=app), base_url="http://test")

    @pytest.mark.asyncio
    async def test_stateless_without_session_id(self, tmp_path):
        """A request with no session_id succeeds and stores no session."""
        app, conv_store = _make_session_app(tmp_path)

        async with self._client_for(app) as client:
            response = await client.post("/ask", json={"question": "test"})

        assert response.status_code == 200
        assert "answer" in response.json()
        # Without a session_id the store must still list no sessions.
        assert conv_store.list_sessions() == []

    @pytest.mark.asyncio
    async def test_session_stores_and_loads_history(self, tmp_path):
        """Each request sharing a session_id adds two messages to the store."""
        app, conv_store = _make_session_app(tmp_path)
        session = "sess-1"

        async with self._client_for(app) as client:
            # Turn one: the question and its answer should both be stored.
            first = await client.post(
                "/ask",
                json={"question": "What is FastAPI?", "session_id": session},
            )
            assert first.status_code == 200

            after_first = conv_store.get_history(session)
            assert len(after_first) == 2
            user_msg, assistant_msg = after_first[0], after_first[1]
            assert user_msg["role"] == "user"
            assert user_msg["content"] == "What is FastAPI?"
            assert assistant_msg["role"] == "assistant"

            # Turn two in the same session.
            second = await client.post(
                "/ask",
                json={"question": "Tell me more about it", "session_id": session},
            )
            assert second.status_code == 200

            # Two turns -> four stored messages.
            after_second = conv_store.get_history(session)
            assert len(after_second) == 4

    @pytest.mark.asyncio
    async def test_history_passed_to_orchestrator(self, tmp_path):
        """A follow-up request hands the stored turn to ``orchestrator.run``."""
        from agent_bench.agents.orchestrator import AgentResponse
        from agent_bench.core.types import TokenUsage

        app, conv_store = _make_session_app(tmp_path)

        # Pre-populate one completed turn for the session under test.
        conv_store.append("sess-2", "user", "What is FastAPI?")
        conv_store.append("sess-2", "assistant", "FastAPI is a web framework.")

        canned_usage = TokenUsage(
            input_tokens=100,
            output_tokens=20,
            estimated_cost_usd=0.0001,
        )
        canned_response = AgentResponse(
            answer="Follow-up answer.",
            sources=[],
            iterations=1,
            tools_used=[],
            usage=canned_usage,
            provider="mock",
            model="mock-1",
            latency_ms=1.0,
        )
        seen: dict = {}

        async def recording_run(**kwargs):
            # Record the keyword arguments, then return the canned response.
            seen.update(kwargs)
            return canned_response

        # Replace run on this app's orchestrator instance with the recorder.
        app.state.orchestrator.run = recording_run

        payload = {"question": "Tell me more", "session_id": "sess-2"}
        async with self._client_for(app) as client:
            reply = await client.post("/ask", json=payload)
        assert reply.status_code == 200

        # The recorded call must carry the two seeded messages as history.
        assert "history" in seen
        passed_history = seen["history"]
        assert passed_history is not None
        assert len(passed_history) == 2
        assert passed_history[0]["content"] == "What is FastAPI?"
        assert passed_history[1]["content"] == "FastAPI is a web framework."