Dokumentassistent / src /agent /memory.py
XQ
Refactor to Plan-and-Execute architecture
1441fa0
raw
history blame
4.12 kB
"""Conversation memory for multi-turn interactions.
Stores message history and retrieved sources across turns so that:
- Follow-up questions can reference prior context ("what about the other one?")
- The planner/synthesizer can see what was already discussed
- Previously retrieved sources are available without re-searching
"""
import logging
from dataclasses import dataclass, field
from src.models import QueryResult
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
# Default cap on retained conversation turns; ConversationMemory uses it as
# the default value of its ``max_turns`` constructor argument.
_MAX_TURNS = 20
@dataclass
class Turn:
    """A single completed conversation turn: one question and its answer.

    Attributes:
        query: The user's question.
        answer: The assistant's response.
        sources: Retrieved sources used to generate the answer. Defaults to
            a new empty list for each instance.
    """

    query: str
    answer: str
    sources: list[QueryResult] = field(default_factory=list)
class ConversationMemory:
    """Manage multi-turn conversation state.

    Keeps a rolling window of the most recent turns (the oldest turn is
    evicted first once the window is full) and renders them as text for
    the planner and synthesizer prompts.
    """

    def __init__(self, max_turns: int = _MAX_TURNS) -> None:
        """Initialize an empty conversation memory.

        Args:
            max_turns: Maximum number of turns to retain.
        """
        self._max_turns = max_turns
        self._turns: list[Turn] = []

    @property
    def turns(self) -> list[Turn]:
        """Return the retained turns, oldest first, as a shallow copy."""
        return list(self._turns)

    @property
    def is_empty(self) -> bool:
        """Return True if no conversation history exists."""
        return not self._turns

    def add_turn(self, query: str, answer: str, sources: list[QueryResult] | None = None) -> None:
        """Record a completed conversation turn.

        If the window already holds ``max_turns`` turns, the oldest one is
        evicted after the new turn is appended.

        Args:
            query: The user's question.
            answer: The assistant's response.
            sources: Retrieved sources (optional). The list is copied, so
                mutating it afterwards does not alter the stored turn.
        """
        # Copy the caller's list: the stored history must not change when the
        # caller later reuses or clears the list it passed in.
        stored_sources = list(sources) if sources else []
        self._turns.append(Turn(query=query, answer=answer, sources=stored_sources))
        if len(self._turns) > self._max_turns:
            removed = self._turns.pop(0)
            logger.debug("Evicted oldest turn: %s", removed.query[:50])
        logger.debug("Memory now has %d turns", len(self._turns))

    def clear(self) -> None:
        """Clear all conversation history."""
        self._turns.clear()
        logger.info("Conversation memory cleared")

    def format_history(self, max_recent: int = 5) -> str:
        """Format recent conversation history for inclusion in prompts.

        Turns are numbered from 1 within the returned window (oldest of the
        included turns first), and each answer is truncated to its first
        500 characters.

        Args:
            max_recent: Maximum number of recent turns to include. Values
                less than or equal to zero include no turns.

        Returns:
            Formatted string of recent Q&A pairs, or empty string if there
            is no history or ``max_recent`` is not positive.
        """
        # Guard against non-positive limits: ``self._turns[-0:]`` is the whole
        # list, and a negative limit would drop the *oldest* turns instead.
        if max_recent <= 0 or not self._turns:
            return ""

        recent = self._turns[-max_recent:]
        parts: list[str] = []
        for i, turn in enumerate(recent, 1):
            source_note = ""
            if turn.sources:
                # Deduplicate and sort document ids so the note is stable.
                # NOTE(review): assumes document_id values are strings, as
                # required by str.join — confirm against src.models.
                doc_ids = sorted({s.chunk.document_id for s in turn.sources})
                source_note = f" [sources: {', '.join(doc_ids)}]"
            parts.append(
                f"Turn {i}:\n"
                f"  User: {turn.query}\n"
                f"  Assistant: {turn.answer[:500]}{source_note}"
            )
        return "\n\n".join(parts)

    def get_prior_sources(self) -> list[QueryResult]:
        """Return all unique sources from prior turns, sorted by score.

        Sources are deduplicated by ``chunk.chunk_id``; when the same chunk
        appears more than once, the entry with the highest score is kept.

        Returns:
            Deduplicated list of QueryResult from all past turns, highest
            score first.
        """
        by_id: dict[str, QueryResult] = {}
        for turn in self._turns:
            for r in turn.sources:
                cid = r.chunk.chunk_id
                if cid not in by_id or r.score > by_id[cid].score:
                    by_id[cid] = r
        return sorted(by_id.values(), key=lambda r: r.score, reverse=True)

    def last_query(self) -> str:
        """Return the last user query, or empty string."""
        return self._turns[-1].query if self._turns else ""

    def last_sources(self) -> list[QueryResult]:
        """Return sources from the most recent turn.

        Returns:
            A copy of the most recent turn's source list, or an empty list
            if there is no history. Mutating the result does not affect the
            stored turn.
        """
        # Return a copy, matching ``turns``, so callers cannot alter history.
        return list(self._turns[-1].sources) if self._turns else []