File size: 8,819 Bytes
1441fa0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
"""Tests for conversation memory."""

from src.agent.memory import ConversationMemory, Turn
from src.models import DocumentChunk, QueryResult


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _qr(chunk_id: str = "c1", doc_id: str = "doc.pdf", score: float = 0.8) -> QueryResult:
    """Build a minimal ``QueryResult`` fixture wrapping a single chunk.

    The chunk always carries the placeholder text ``"text"`` and page number
    1; only the chunk id, document id and score vary between tests.
    """
    return QueryResult(
        chunk=DocumentChunk(
            chunk_id=chunk_id,
            document_id=doc_id,
            text="text",
            metadata={"page_number": 1},
        ),
        score=score,
        source="test",
    )


# ---------------------------------------------------------------------------
# Basic operations
# ---------------------------------------------------------------------------

class TestConversationMemory:
    """Basic add / read / clear behaviour of ``ConversationMemory``."""

    def test_initially_empty(self) -> None:
        """A freshly constructed memory reports empty values everywhere."""
        memory = ConversationMemory()
        assert memory.is_empty
        assert memory.turns == []
        assert memory.last_query() == ""
        assert memory.last_sources() == []

    def test_add_turn(self) -> None:
        """Recording one turn makes it the only turn and the last query."""
        question = "What is X?"
        memory = ConversationMemory()
        memory.add_turn(question, "X is Y.", [_qr()])
        assert not memory.is_empty
        assert len(memory.turns) == 1
        assert memory.last_query() == question

    def test_multiple_turns(self) -> None:
        """Two recorded turns are both kept; the later one is the last query."""
        memory = ConversationMemory()
        for question, answer in (("Q1", "A1"), ("Q2", "A2")):
            memory.add_turn(question, answer)
        assert len(memory.turns) == 2
        assert memory.last_query() == "Q2"

    def test_clear(self) -> None:
        """clear() brings a non-empty memory back to the empty state."""
        memory = ConversationMemory()
        memory.add_turn("Q1", "A1")
        memory.clear()
        assert memory.is_empty

    def test_turns_returns_copy(self) -> None:
        """Mutating the list handed out by ``turns`` must not touch the memory."""
        memory = ConversationMemory()
        memory.add_turn("Q1", "A1")
        snapshot = memory.turns
        intruder = Turn(query="fake", answer="fake")
        snapshot.append(intruder)
        # Re-read the property: the stored history still has a single turn.
        assert len(memory.turns) == 1


# ---------------------------------------------------------------------------
# Eviction
# ---------------------------------------------------------------------------

class TestEviction:
    """Old turns are dropped once the ``max_turns`` limit is exceeded."""

    def test_max_turns_eviction(self) -> None:
        """Five turns added to a three-turn memory leave three, starting at Q2."""
        memory = ConversationMemory(max_turns=3)
        for i in range(5):
            memory.add_turn(f"Q{i}", f"A{i}")
        assert len(memory.turns) == 3
        # Q0 and Q1 were pushed out, so Q2 is now the oldest surviving turn.
        oldest = memory.turns[0]
        assert oldest.query == "Q2"

    def test_max_turns_one(self) -> None:
        """With a limit of one, only the most recent turn is retained."""
        memory = ConversationMemory(max_turns=1)
        for question, answer in (("Q1", "A1"), ("Q2", "A2")):
            memory.add_turn(question, answer)
        assert len(memory.turns) == 1
        survivor = memory.turns[0]
        assert survivor.query == "Q2"


# ---------------------------------------------------------------------------
# format_history
# ---------------------------------------------------------------------------

class TestFormatHistory:
    """Text rendering produced by ``ConversationMemory.format_history``."""

    def test_empty_history(self) -> None:
        """No turns renders as the empty string."""
        memory = ConversationMemory()
        rendered = memory.format_history()
        assert rendered == ""

    def test_includes_query_and_answer(self) -> None:
        """Both the question and its answer appear in the rendered text."""
        memory = ConversationMemory()
        memory.add_turn("What is X?", "X is a policy.")
        rendered = memory.format_history()
        for fragment in ("What is X?", "X is a policy."):
            assert fragment in rendered

    def test_includes_source_doc_ids(self) -> None:
        """Document ids of the sources attached to a turn are rendered."""
        memory = ConversationMemory()
        first = _qr(doc_id="policy.pdf")
        second = _qr(chunk_id="c2", doc_id="rules.pdf")
        memory.add_turn("Q", "A", [first, second])
        rendered = memory.format_history()
        for doc_id in ("policy.pdf", "rules.pdf"):
            assert doc_id in rendered

    def test_max_recent_limits_output(self) -> None:
        """``max_recent=2`` keeps the two newest turns and omits the oldest."""
        memory = ConversationMemory()
        for i in range(10):
            memory.add_turn(f"Q{i}", f"A{i}")
        rendered = memory.format_history(max_recent=2)
        for recent_query in ("Q8", "Q9"):
            assert recent_query in rendered
        assert "Q0" not in rendered

    def test_long_answer_truncated(self) -> None:
        """A 1000-character answer does not survive in full in the output."""
        memory = ConversationMemory()
        oversized_answer = "x" * 1000
        memory.add_turn("Q", oversized_answer)
        rendered = memory.format_history()
        # The answer is expected to be cut to 500 chars; this check only
        # requires the whole rendering to be shorter than the raw answer.
        assert len(rendered) < 1000


# ---------------------------------------------------------------------------
# get_prior_sources
# ---------------------------------------------------------------------------

class TestGetPriorSources:
    """Aggregation of sources across turns via ``get_prior_sources``."""

    def test_empty_returns_empty(self) -> None:
        """A memory without turns yields no prior sources."""
        memory = ConversationMemory()
        collected = memory.get_prior_sources()
        assert collected == []

    def test_collects_across_turns(self) -> None:
        """Sources from separate turns are merged, best score first."""
        memory = ConversationMemory()
        weaker = _qr(chunk_id="c1", score=0.8)
        stronger = _qr(chunk_id="c2", score=0.9)
        memory.add_turn("Q1", "A1", [weaker])
        memory.add_turn("Q2", "A2", [stronger])
        collected = memory.get_prior_sources()
        assert len(collected) == 2
        # Descending score order puts the 0.9 result at the front.
        top = collected[0]
        assert top.score == 0.9

    def test_deduplicates_by_chunk_id(self) -> None:
        """The same chunk seen twice is returned once, with the higher score."""
        memory = ConversationMemory()
        low = _qr(chunk_id="c1", score=0.5)
        high = _qr(chunk_id="c1", score=0.9)
        memory.add_turn("Q1", "A1", [low])
        memory.add_turn("Q2", "A2", [high])
        collected = memory.get_prior_sources()
        assert len(collected) == 1
        only = collected[0]
        assert only.score == 0.9

    def test_no_sources_turns(self) -> None:
        """Turns recorded without sources contribute nothing."""
        memory = ConversationMemory()
        memory.add_turn("Q1", "A1")  # sources argument omitted
        collected = memory.get_prior_sources()
        assert collected == []


# ---------------------------------------------------------------------------
# Integration: memory in PlanAndExecuteRouter
# ---------------------------------------------------------------------------

class TestMemoryIntegration:
    """``PlanAndExecuteRouter`` reading from and writing to a shared memory."""

    # Dotted path of the agent factory that every test swaps for a mock.
    _AGENT_FACTORY = "src.agent.plan_and_execute.create_react_agent"

    def test_route_records_turn(self) -> None:
        """A single route() call stores the query and its answer as a turn."""
        from unittest.mock import MagicMock, patch
        from langchain_core.messages import AIMessage
        from src.agent.plan_and_execute import PlanAndExecuteRouter

        llm, retriever, reranker, vector_store = (MagicMock() for _ in range(4))
        memory = ConversationMemory()

        # Scripted LLM replies: the plan JSON first, then the text this test
        # expects to be stored as the answer.
        llm.invoke.side_effect = [
            '[{"action": "search", "detail": "test"}]',
            "The answer.",
        ]

        agent = MagicMock()
        agent.invoke.return_value = {"messages": [AIMessage(content="Found info.")]}

        router = PlanAndExecuteRouter(
            llm, retriever, reranker, vector_store, memory=memory,
        )

        with patch(self._AGENT_FACTORY, return_value=agent):
            router.route("test question", top_k=5)

        assert not memory.is_empty
        assert memory.last_query() == "test question"
        recorded = memory.turns[0]
        assert recorded.answer == "The answer."

    def test_history_injected_into_planner(self) -> None:
        """A follow-up query's planner prompt carries the earlier conversation."""
        from unittest.mock import MagicMock, patch
        from langchain_core.messages import AIMessage
        from src.agent.plan_and_execute import PlanAndExecuteRouter

        llm = MagicMock()
        memory = ConversationMemory()
        # Seed one prior turn so there is history to inject.
        memory.add_turn("What is the exam policy?", "The exam policy says...")

        llm.invoke.side_effect = [
            '[{"action": "search", "detail": "follow-up"}]',
            "Follow-up answer.",
        ]

        agent = MagicMock()
        agent.invoke.return_value = {"messages": [AIMessage(content="More info.")]}

        router = PlanAndExecuteRouter(
            llm, MagicMock(), MagicMock(), MagicMock(), memory=memory,
        )

        with patch(self._AGENT_FACTORY, return_value=agent):
            router.route("What about the grading?", top_k=5)

        # The first recorded LLM call is the planner; its first positional
        # argument is the prompt text under inspection.
        first_call = llm.invoke.call_args_list[0]
        planner_prompt = first_call[0][0]
        for expected in ("exam policy", "Conversation history"):
            assert expected in planner_prompt

    def test_multi_turn_accumulates(self) -> None:
        """Three consecutive route() calls leave three turns in memory."""
        from unittest.mock import MagicMock, patch
        from langchain_core.messages import AIMessage
        from src.agent.plan_and_execute import PlanAndExecuteRouter

        llm = MagicMock()
        memory = ConversationMemory()

        agent = MagicMock()
        agent.invoke.return_value = {"messages": [AIMessage(content="info")]}

        router = PlanAndExecuteRouter(
            llm, MagicMock(), MagicMock(), MagicMock(), memory=memory,
        )

        for i in range(3):
            # Re-arm the scripted replies for each round: plan, then answer.
            llm.invoke.side_effect = [
                f'[{{"action": "search", "detail": "q{i}"}}]',
                f"Answer {i}",
            ]
            with patch(self._AGENT_FACTORY, return_value=agent):
                router.route(f"Question {i}", top_k=5)

        assert len(memory.turns) == 3