Spaces:
Running
Running
| """Tests for conversation memory.""" | |
| from src.agent.memory import ConversationMemory, Turn | |
| from src.models import DocumentChunk, QueryResult | |
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
| def _qr(chunk_id: str = "c1", doc_id: str = "doc.pdf", score: float = 0.8) -> QueryResult: | |
| chunk = DocumentChunk( | |
| chunk_id=chunk_id, document_id=doc_id, text="text", | |
| metadata={"page_number": 1}, | |
| ) | |
| return QueryResult(chunk=chunk, score=score, source="test") | |
| # --------------------------------------------------------------------------- | |
| # Basic operations | |
| # --------------------------------------------------------------------------- | |
| class TestConversationMemory: | |
| def test_initially_empty(self) -> None: | |
| mem = ConversationMemory() | |
| assert mem.is_empty | |
| assert mem.turns == [] | |
| assert mem.last_query() == "" | |
| assert mem.last_sources() == [] | |
| def test_add_turn(self) -> None: | |
| mem = ConversationMemory() | |
| mem.add_turn("What is X?", "X is Y.", [_qr()]) | |
| assert not mem.is_empty | |
| assert len(mem.turns) == 1 | |
| assert mem.last_query() == "What is X?" | |
| def test_multiple_turns(self) -> None: | |
| mem = ConversationMemory() | |
| mem.add_turn("Q1", "A1") | |
| mem.add_turn("Q2", "A2") | |
| assert len(mem.turns) == 2 | |
| assert mem.last_query() == "Q2" | |
| def test_clear(self) -> None: | |
| mem = ConversationMemory() | |
| mem.add_turn("Q1", "A1") | |
| mem.clear() | |
| assert mem.is_empty | |
| def test_turns_returns_copy(self) -> None: | |
| mem = ConversationMemory() | |
| mem.add_turn("Q1", "A1") | |
| turns = mem.turns | |
| turns.append(Turn(query="fake", answer="fake")) | |
| assert len(mem.turns) == 1 # original unaffected | |
| # --------------------------------------------------------------------------- | |
| # Eviction | |
| # --------------------------------------------------------------------------- | |
| class TestEviction: | |
| def test_max_turns_eviction(self) -> None: | |
| mem = ConversationMemory(max_turns=3) | |
| for i in range(5): | |
| mem.add_turn(f"Q{i}", f"A{i}") | |
| assert len(mem.turns) == 3 | |
| # Oldest should be Q2 (Q0 and Q1 evicted) | |
| assert mem.turns[0].query == "Q2" | |
| def test_max_turns_one(self) -> None: | |
| mem = ConversationMemory(max_turns=1) | |
| mem.add_turn("Q1", "A1") | |
| mem.add_turn("Q2", "A2") | |
| assert len(mem.turns) == 1 | |
| assert mem.turns[0].query == "Q2" | |
| # --------------------------------------------------------------------------- | |
| # format_history | |
| # --------------------------------------------------------------------------- | |
| class TestFormatHistory: | |
| def test_empty_history(self) -> None: | |
| mem = ConversationMemory() | |
| assert mem.format_history() == "" | |
| def test_includes_query_and_answer(self) -> None: | |
| mem = ConversationMemory() | |
| mem.add_turn("What is X?", "X is a policy.") | |
| text = mem.format_history() | |
| assert "What is X?" in text | |
| assert "X is a policy." in text | |
| def test_includes_source_doc_ids(self) -> None: | |
| mem = ConversationMemory() | |
| sources = [_qr(doc_id="policy.pdf"), _qr(chunk_id="c2", doc_id="rules.pdf")] | |
| mem.add_turn("Q", "A", sources) | |
| text = mem.format_history() | |
| assert "policy.pdf" in text | |
| assert "rules.pdf" in text | |
| def test_max_recent_limits_output(self) -> None: | |
| mem = ConversationMemory() | |
| for i in range(10): | |
| mem.add_turn(f"Q{i}", f"A{i}") | |
| text = mem.format_history(max_recent=2) | |
| assert "Q8" in text | |
| assert "Q9" in text | |
| assert "Q0" not in text | |
| def test_long_answer_truncated(self) -> None: | |
| mem = ConversationMemory() | |
| mem.add_turn("Q", "x" * 1000) | |
| text = mem.format_history() | |
| # Answer should be truncated to 500 chars | |
| assert len(text) < 1000 | |
| # --------------------------------------------------------------------------- | |
| # get_prior_sources | |
| # --------------------------------------------------------------------------- | |
| class TestGetPriorSources: | |
| def test_empty_returns_empty(self) -> None: | |
| mem = ConversationMemory() | |
| assert mem.get_prior_sources() == [] | |
| def test_collects_across_turns(self) -> None: | |
| mem = ConversationMemory() | |
| mem.add_turn("Q1", "A1", [_qr(chunk_id="c1", score=0.8)]) | |
| mem.add_turn("Q2", "A2", [_qr(chunk_id="c2", score=0.9)]) | |
| sources = mem.get_prior_sources() | |
| assert len(sources) == 2 | |
| # Sorted by score descending | |
| assert sources[0].score == 0.9 | |
| def test_deduplicates_by_chunk_id(self) -> None: | |
| mem = ConversationMemory() | |
| mem.add_turn("Q1", "A1", [_qr(chunk_id="c1", score=0.5)]) | |
| mem.add_turn("Q2", "A2", [_qr(chunk_id="c1", score=0.9)]) | |
| sources = mem.get_prior_sources() | |
| assert len(sources) == 1 | |
| assert sources[0].score == 0.9 # keeps higher score | |
| def test_no_sources_turns(self) -> None: | |
| mem = ConversationMemory() | |
| mem.add_turn("Q1", "A1") # no sources | |
| assert mem.get_prior_sources() == [] | |
| # --------------------------------------------------------------------------- | |
| # Integration: memory in PlanAndExecuteRouter | |
| # --------------------------------------------------------------------------- | |
| class TestMemoryIntegration: | |
| def test_route_records_turn(self) -> None: | |
| """After route(), the conversation turn should be recorded in memory.""" | |
| from unittest.mock import MagicMock, patch | |
| from langchain_core.messages import AIMessage | |
| from src.agent.plan_and_execute import PlanAndExecuteRouter | |
| llm = MagicMock() | |
| retriever = MagicMock() | |
| reranker = MagicMock() | |
| vector_store = MagicMock() | |
| memory = ConversationMemory() | |
| plan_json = '[{"action": "search", "detail": "test"}]' | |
| llm.invoke.side_effect = [plan_json, "The answer."] | |
| mock_agent = MagicMock() | |
| mock_agent.invoke.return_value = {"messages": [AIMessage(content="Found info.")]} | |
| router = PlanAndExecuteRouter( | |
| llm, retriever, reranker, vector_store, memory=memory, | |
| ) | |
| with patch("src.agent.plan_and_execute.create_react_agent", return_value=mock_agent): | |
| router.route("test question", top_k=5) | |
| assert not memory.is_empty | |
| assert memory.last_query() == "test question" | |
| assert memory.turns[0].answer == "The answer." | |
| def test_history_injected_into_planner(self) -> None: | |
| """On a follow-up query, conversation history should appear in the planner prompt.""" | |
| from unittest.mock import MagicMock, patch | |
| from langchain_core.messages import AIMessage | |
| from src.agent.plan_and_execute import PlanAndExecuteRouter | |
| llm = MagicMock() | |
| memory = ConversationMemory() | |
| memory.add_turn("What is the exam policy?", "The exam policy says...") | |
| plan_json = '[{"action": "search", "detail": "follow-up"}]' | |
| llm.invoke.side_effect = [plan_json, "Follow-up answer."] | |
| mock_agent = MagicMock() | |
| mock_agent.invoke.return_value = {"messages": [AIMessage(content="More info.")]} | |
| router = PlanAndExecuteRouter( | |
| llm, MagicMock(), MagicMock(), MagicMock(), memory=memory, | |
| ) | |
| with patch("src.agent.plan_and_execute.create_react_agent", return_value=mock_agent): | |
| router.route("What about the grading?", top_k=5) | |
| # The first LLM call is the planner — check it includes history | |
| planner_prompt = llm.invoke.call_args_list[0][0][0] | |
| assert "exam policy" in planner_prompt | |
| assert "Conversation history" in planner_prompt | |
| def test_multi_turn_accumulates(self) -> None: | |
| """Multiple route() calls should accumulate turns in memory.""" | |
| from unittest.mock import MagicMock, patch | |
| from langchain_core.messages import AIMessage | |
| from src.agent.plan_and_execute import PlanAndExecuteRouter | |
| llm = MagicMock() | |
| memory = ConversationMemory() | |
| mock_agent = MagicMock() | |
| mock_agent.invoke.return_value = {"messages": [AIMessage(content="info")]} | |
| router = PlanAndExecuteRouter( | |
| llm, MagicMock(), MagicMock(), MagicMock(), memory=memory, | |
| ) | |
| for i in range(3): | |
| plan_json = f'[{{"action": "search", "detail": "q{i}"}}]' | |
| llm.invoke.side_effect = [plan_json, f"Answer {i}"] | |
| with patch("src.agent.plan_and_execute.create_react_agent", return_value=mock_agent): | |
| router.route(f"Question {i}", top_k=5) | |
| assert len(memory.turns) == 3 | |