Dokumentassistent / tests /test_memory.py
XQ
Code cleaning
db45c50
raw
history blame
8.82 kB
"""Tests for conversation memory."""
from src.agent.memory import ConversationMemory, Turn
from src.models import DocumentChunk, QueryResult
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _qr(chunk_id: str = "c1", doc_id: str = "doc.pdf", score: float = 0.8) -> QueryResult:
    """Build a minimal ``QueryResult`` that wraps one placeholder chunk.

    Args:
        chunk_id: Identifier given to the wrapped chunk.
        doc_id: Document id the chunk is attributed to.
        score: Retrieval score attached to the result.

    Returns:
        A ``QueryResult`` with ``source="test"`` whose chunk has the text
        ``"text"`` and metadata ``{"page_number": 1}``.
    """
    return QueryResult(
        chunk=DocumentChunk(
            chunk_id=chunk_id,
            document_id=doc_id,
            text="text",
            metadata={"page_number": 1},
        ),
        score=score,
        source="test",
    )
# ---------------------------------------------------------------------------
# Basic operations
# ---------------------------------------------------------------------------
class TestConversationMemory:
    """Core add / read / clear behaviour of ``ConversationMemory``."""

    def test_initially_empty(self) -> None:
        memory = ConversationMemory()

        assert memory.is_empty
        assert memory.turns == []
        assert memory.last_query() == ""
        assert memory.last_sources() == []

    def test_add_turn(self) -> None:
        memory = ConversationMemory()
        memory.add_turn("What is X?", "X is Y.", [_qr()])

        assert not memory.is_empty
        assert len(memory.turns) == 1
        assert memory.last_query() == "What is X?"

    def test_multiple_turns(self) -> None:
        memory = ConversationMemory()
        for question, answer in (("Q1", "A1"), ("Q2", "A2")):
            memory.add_turn(question, answer)

        assert len(memory.turns) == 2
        assert memory.last_query() == "Q2"

    def test_clear(self) -> None:
        memory = ConversationMemory()
        memory.add_turn("Q1", "A1")

        memory.clear()

        assert memory.is_empty

    def test_turns_returns_copy(self) -> None:
        memory = ConversationMemory()
        memory.add_turn("Q1", "A1")

        # Appending to the list handed out by ``turns`` must not leak back
        # into the memory's own storage.
        memory.turns.append(Turn(query="fake", answer="fake"))

        assert len(memory.turns) == 1
# ---------------------------------------------------------------------------
# Eviction
# ---------------------------------------------------------------------------
class TestEviction:
    """Oldest turns are discarded once ``max_turns`` is exceeded."""

    def test_max_turns_eviction(self) -> None:
        memory = ConversationMemory(max_turns=3)
        for n in range(5):
            memory.add_turn(f"Q{n}", f"A{n}")

        remaining = memory.turns
        assert len(remaining) == 3
        # Q0 and Q1 were pushed out, so the retained window begins at Q2.
        assert remaining[0].query == "Q2"

    def test_max_turns_one(self) -> None:
        memory = ConversationMemory(max_turns=1)
        for n in (1, 2):
            memory.add_turn(f"Q{n}", f"A{n}")

        remaining = memory.turns
        assert len(remaining) == 1
        assert remaining[0].query == "Q2"
# ---------------------------------------------------------------------------
# format_history
# ---------------------------------------------------------------------------
class TestFormatHistory:
    """Rendering of stored turns into the history text used in prompts."""

    def test_empty_history(self) -> None:
        """No turns means an empty history string."""
        mem = ConversationMemory()
        assert mem.format_history() == ""

    def test_includes_query_and_answer(self) -> None:
        """Both the user query and the answer appear in the rendered text."""
        mem = ConversationMemory()
        mem.add_turn("What is X?", "X is a policy.")
        text = mem.format_history()
        assert "What is X?" in text
        assert "X is a policy." in text

    def test_includes_source_doc_ids(self) -> None:
        """Document ids of every source attached to a turn are rendered."""
        mem = ConversationMemory()
        sources = [_qr(doc_id="policy.pdf"), _qr(chunk_id="c2", doc_id="rules.pdf")]
        mem.add_turn("Q", "A", sources)
        text = mem.format_history()
        assert "policy.pdf" in text
        assert "rules.pdf" in text

    def test_max_recent_limits_output(self) -> None:
        """Only the last ``max_recent`` turns are rendered."""
        mem = ConversationMemory()
        for i in range(10):
            mem.add_turn(f"Q{i}", f"A{i}")
        text = mem.format_history(max_recent=2)
        assert "Q8" in text
        assert "Q9" in text
        # Pin the window boundary: the turn just outside the last two (Q7)
        # must be excluded as well, not only the very first one.
        assert "Q7" not in text
        assert "Q0" not in text

    def test_long_answer_truncated(self) -> None:
        """Long answers are clipped instead of being rendered in full."""
        mem = ConversationMemory()
        mem.add_turn("Q", "x" * 1000)
        text = mem.format_history()
        # Answer should be truncated to 500 chars: no longer run of the answer
        # may survive, but a meaningful prefix of it must still be shown.
        assert "x" * 501 not in text
        assert "x" * 100 in text
        assert len(text) < 1000
# ---------------------------------------------------------------------------
# get_prior_sources
# ---------------------------------------------------------------------------
class TestGetPriorSources:
    """Aggregation of retrieval results across all stored turns."""

    def test_empty_returns_empty(self) -> None:
        """A fresh memory has no prior sources."""
        mem = ConversationMemory()
        assert mem.get_prior_sources() == []

    def test_collects_across_turns(self) -> None:
        """Sources from every turn are merged and ordered by score, highest first."""
        mem = ConversationMemory()
        mem.add_turn("Q1", "A1", [_qr(chunk_id="c1", score=0.8)])
        mem.add_turn("Q2", "A2", [_qr(chunk_id="c2", score=0.9)])
        sources = mem.get_prior_sources()
        assert len(sources) == 2
        # Sorted by score descending: check the whole order, not just the head.
        assert [s.score for s in sources] == [0.9, 0.8]
        assert [s.chunk.chunk_id for s in sources] == ["c2", "c1"]

    def test_deduplicates_by_chunk_id(self) -> None:
        """The same chunk seen in several turns is reported once."""
        mem = ConversationMemory()
        mem.add_turn("Q1", "A1", [_qr(chunk_id="c1", score=0.5)])
        mem.add_turn("Q2", "A2", [_qr(chunk_id="c1", score=0.9)])
        sources = mem.get_prior_sources()
        assert len(sources) == 1
        assert sources[0].score == 0.9  # keeps higher score

    def test_no_sources_turns(self) -> None:
        """Turns recorded without sources contribute nothing."""
        mem = ConversationMemory()
        mem.add_turn("Q1", "A1")  # no sources
        assert mem.get_prior_sources() == []
# ---------------------------------------------------------------------------
# Integration: memory in PlanAndExecuteRouter
# ---------------------------------------------------------------------------
class TestMemoryIntegration:
    """``PlanAndExecuteRouter`` reads from and writes to ``ConversationMemory``.

    Each test scripts ``llm.invoke`` with two return values: the planner's JSON
    plan first, then the final answer. ``create_react_agent`` is patched so the
    executor step returns a canned message instead of running a real agent.
    Router-related imports are kept local to each test so that only these
    tests depend on the agent stack.
    """

    def test_route_records_turn(self) -> None:
        """After route(), the conversation turn should be recorded in memory."""
        from unittest.mock import MagicMock, patch
        from langchain_core.messages import AIMessage
        from src.agent.plan_and_execute import PlanAndExecuteRouter

        llm = MagicMock()
        retriever = MagicMock()
        reranker = MagicMock()
        vector_store = MagicMock()
        memory = ConversationMemory()
        plan_json = '[{"action": "search", "detail": "test"}]'
        # First call -> planner output, second call -> final answer.
        llm.invoke.side_effect = [plan_json, "The answer."]
        mock_agent = MagicMock()
        mock_agent.invoke.return_value = {"messages": [AIMessage(content="Found info.")]}
        router = PlanAndExecuteRouter(
            llm, retriever, reranker, vector_store, memory=memory,
        )
        with patch("src.agent.plan_and_execute.create_react_agent", return_value=mock_agent):
            router.route("test question", top_k=5)
        assert not memory.is_empty
        assert memory.last_query() == "test question"
        assert memory.turns[0].answer == "The answer."

    def test_history_injected_into_planner(self) -> None:
        """On a follow-up query, conversation history should appear in the planner prompt."""
        from unittest.mock import MagicMock, patch
        from langchain_core.messages import AIMessage
        from src.agent.plan_and_execute import PlanAndExecuteRouter

        llm = MagicMock()
        memory = ConversationMemory()
        # Seed a prior turn so the router has history to inject.
        memory.add_turn("What is the exam policy?", "The exam policy says...")
        plan_json = '[{"action": "search", "detail": "follow-up"}]'
        llm.invoke.side_effect = [plan_json, "Follow-up answer."]
        mock_agent = MagicMock()
        mock_agent.invoke.return_value = {"messages": [AIMessage(content="More info.")]}
        router = PlanAndExecuteRouter(
            llm, MagicMock(), MagicMock(), MagicMock(), memory=memory,
        )
        with patch("src.agent.plan_and_execute.create_react_agent", return_value=mock_agent):
            router.route("What about the grading?", top_k=5)
        # The first LLM call is the planner — check it includes history.
        # call_args_list[0][0][0] is the first positional argument of that call.
        planner_prompt = llm.invoke.call_args_list[0][0][0]
        assert "exam policy" in planner_prompt
        assert "Conversation history" in planner_prompt

    def test_multi_turn_accumulates(self) -> None:
        """Multiple route() calls should accumulate turns in memory, in order."""
        from unittest.mock import MagicMock, patch
        from langchain_core.messages import AIMessage
        from src.agent.plan_and_execute import PlanAndExecuteRouter

        llm = MagicMock()
        memory = ConversationMemory()
        mock_agent = MagicMock()
        mock_agent.invoke.return_value = {"messages": [AIMessage(content="info")]}
        router = PlanAndExecuteRouter(
            llm, MagicMock(), MagicMock(), MagicMock(), memory=memory,
        )
        for i in range(3):
            plan_json = f'[{{"action": "search", "detail": "q{i}"}}]'
            # Re-arm the scripted responses for this round: plan, then answer.
            llm.invoke.side_effect = [plan_json, f"Answer {i}"]
            with patch("src.agent.plan_and_execute.create_react_agent", return_value=mock_agent):
                router.route(f"Question {i}", top_k=5)
        assert len(memory.turns) == 3
        # A bare count would not catch a turn being overwritten or recorded
        # out of order, so check that each round stored its own query/answer.
        assert [turn.query for turn in memory.turns] == [
            "Question 0", "Question 1", "Question 2",
        ]
        assert [turn.answer for turn in memory.turns] == [
            "Answer 0", "Answer 1", "Answer 2",
        ]
        assert memory.last_query() == "Question 2"