File size: 4,119 Bytes
1441fa0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""Conversation memory for multi-turn interactions.

Stores message history and retrieved sources across turns so that:
- Follow-up questions can reference prior context ("what about the other one?")
- The planner/synthesizer can see what was already discussed
- Previously retrieved sources are available without re-searching
"""

import logging
from dataclasses import dataclass, field

from src.models import QueryResult

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)

# Default cap on retained turns; used as ConversationMemory's default max_turns.
_MAX_TURNS = 20


@dataclass
class Turn:
    """A single completed conversation turn (one question and its answer).

    Attributes:
        query: The user's question.
        answer: The assistant's response.
        sources: Retrieved sources used to generate the answer. Defaults to
            a new empty list for each instance.
    """

    query: str
    answer: str
    sources: list[QueryResult] = field(default_factory=list)


class ConversationMemory:
    """Manages multi-turn conversation state.

    Stores a rolling window of recent turns and provides formatted
    context for the planner and synthesizer prompts. Once the window
    holds more than ``max_turns`` entries, the oldest turn is evicted.
    """

    def __init__(self, max_turns: int = _MAX_TURNS) -> None:
        """Initialize conversation memory.

        Args:
            max_turns: Maximum number of turns to retain.
        """
        self._max_turns = max_turns
        self._turns: list[Turn] = []

    @property
    def turns(self) -> list[Turn]:
        """Return the list of conversation turns (read-only copy).

        The returned list is a shallow copy: adding or removing entries does
        not affect the memory, but the ``Turn`` objects themselves are shared.
        """
        return list(self._turns)

    @property
    def is_empty(self) -> bool:
        """Return True if no conversation history exists."""
        return not self._turns

    def add_turn(self, query: str, answer: str, sources: list[QueryResult] | None = None) -> None:
        """Record a completed conversation turn.

        If recording the turn pushes the history past ``max_turns``, the
        oldest turn is dropped.

        Args:
            query: The user's question.
            answer: The assistant's response.
            sources: Retrieved sources (optional). The list is copied, so
                later changes to the caller's list do not alter history.
        """
        # Copy rather than alias the caller's list: callers commonly reuse or
        # clear their result buffers between queries.
        stored_sources = list(sources) if sources else []
        self._turns.append(Turn(query=query, answer=answer, sources=stored_sources))
        if len(self._turns) > self._max_turns:
            removed = self._turns.pop(0)
            logger.debug("Evicted oldest turn: %s", removed.query[:50])
        logger.debug("Memory now has %d turns", len(self._turns))

    def clear(self) -> None:
        """Clear all conversation history."""
        self._turns.clear()
        logger.info("Conversation memory cleared")

    def format_history(self, max_recent: int = 5) -> str:
        """Format recent conversation history for inclusion in prompts.

        Each included turn is rendered as a numbered block (numbering starts
        at 1 for the oldest included turn) with the user's query, the first
        500 characters of the answer, and the sorted unique document ids of
        its sources, if any.

        Args:
            max_recent: Maximum number of recent turns to include. Values
                less than or equal to zero yield an empty string.

        Returns:
            Formatted string of recent Q&A pairs, or empty string if no history.
        """
        # Guard non-positive limits explicitly: ``turns[-0:]`` is ``turns[0:]``
        # and would return the *entire* history rather than nothing.
        if max_recent <= 0 or not self._turns:
            return ""

        recent = self._turns[-max_recent:]
        parts: list[str] = []
        for i, turn in enumerate(recent, 1):
            source_note = ""
            if turn.sources:
                # Several chunks may come from one document; list each id once.
                doc_ids = sorted({s.chunk.document_id for s in turn.sources})
                source_note = f" [sources: {', '.join(doc_ids)}]"
            parts.append(
                f"Turn {i}:\n"
                f"  User: {turn.query}\n"
                # Answers are truncated to keep the prompt compact.
                f"  Assistant: {turn.answer[:500]}{source_note}"
            )
        return "\n\n".join(parts)

    def get_prior_sources(self) -> list[QueryResult]:
        """Return all unique sources from prior turns, sorted by score.

        Sources are deduplicated by chunk id; when the same chunk appears in
        several turns, the occurrence with the highest score is kept (the
        earliest one on ties).

        Returns:
            Deduplicated list of QueryResult from all past turns, highest
            score first.
        """
        by_id: dict[str, QueryResult] = {}
        for turn in self._turns:
            for r in turn.sources:
                cid = r.chunk.chunk_id
                if cid not in by_id or r.score > by_id[cid].score:
                    by_id[cid] = r
        return sorted(by_id.values(), key=lambda r: r.score, reverse=True)

    def last_query(self) -> str:
        """Return the last user query, or empty string."""
        return self._turns[-1].query if self._turns else ""

    def last_sources(self) -> list[QueryResult]:
        """Return sources from the most recent turn.

        Returns:
            A new list holding the most recent turn's sources, or an empty
            list if there is no history. Mutating the returned list does not
            affect the stored turn.
        """
        return list(self._turns[-1].sources) if self._turns else []