File size: 2,758 Bytes
2b63102 0b87551 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 | """Per-session conversation memory. Recent turns are kept verbatim; older turns are
folded into a running summary so long chats stay within the model's context window."""
from dataclasses import dataclass, field
from typing import Callable
@dataclass
class ConversationMemory:
    """Rolling memory for one chat session.

    The most recent exchanges are stored verbatim in ``turns``; once the window
    grows past ``max_turns``, everything except the last ``keep_recent``
    exchanges can be condensed into ``summary`` by a caller-supplied summarizer.
    """

    max_turns: int = 6  # window size above which summarizing is due
    keep_recent: int = 4  # number of newest turns left verbatim after a fold
    summary: str = ""  # running summary of turns already folded away
    turns: list[tuple[str, str]] = field(default_factory=list)  # (user, assistant) pairs

    def add_turn(self, user: str, assistant: str) -> None:
        """Append one completed (user, assistant) exchange to the window."""
        exchange = (user, assistant)
        self.turns.append(exchange)

    def needs_summary(self) -> bool:
        """Return True when the window holds more than ``max_turns`` turns."""
        turn_count = len(self.turns)
        return turn_count > self.max_turns

    def build_messages(self, new_query: str) -> list[tuple[str, str]]:
        """Build the (role, text) list for the next agent call.

        Order: an optional "system" message carrying the summary (only when
        the summary is non-empty), then every stored turn as a "user" /
        "assistant" pair, then ``new_query`` as the final "user" message.
        """
        preamble: list[tuple[str, str]] = []
        if self.summary:
            preamble = [("system", f"Summary of the earlier conversation:\n{self.summary}")]
        history = [
            message
            for user_text, assistant_text in self.turns
            for message in (("user", user_text), ("assistant", assistant_text))
        ]
        return [*preamble, *history, ("user", new_query)]

    def summarize_if_needed(self, summarizer: Callable[[str], str]) -> bool:
        """Fold older turns into the summary when the window is over the limit.

        ``summarizer`` receives a prompt containing the previous summary (if
        any) followed by a transcript of the turns being folded; its stripped
        return value becomes the new summary. Afterwards only the last
        ``keep_recent`` turns remain (none when ``keep_recent`` is 0).

        Returns True if the summarizer ran, False when the window is within
        the limit or there is nothing older than the recent turns to fold.
        """
        if not self.needs_summary():
            return False

        # A zero keep_recent must be special-cased: turns[:-0] would be empty.
        if self.keep_recent:
            older = self.turns[: -self.keep_recent]
        else:
            older = self.turns
        if not older:
            return False

        transcript = self._render(older)
        prompt = f"New exchanges to fold in:\n{transcript}"
        if self.summary:
            prompt = f"Previous summary:\n{self.summary}\n\n" + prompt
        self.summary = summarizer(prompt).strip()

        # Trim the window only after the summarizer has returned.
        if self.keep_recent:
            self.turns = self.turns[-self.keep_recent :]
        else:
            self.turns = []
        return True

    @staticmethod
    def _render(turns: list[tuple[str, str]]) -> str:
        """Format turns as newline-separated "User: ..." / "Assistant: ..." lines."""
        exchanges = [
            f"User: {user_text}\nAssistant: {assistant_text}"
            for user_text, assistant_text in turns
        ]
        return "\n".join(exchanges)
class SessionStore:
    """Dictionary-backed registry mapping a session id to its ConversationMemory."""

    def __init__(self, max_turns: int = 6, keep_recent: int = 4):
        """Store the limits that every newly created memory will be given."""
        self._max_turns = max_turns
        self._keep_recent = keep_recent
        self._sessions: dict[str, ConversationMemory] = {}

    def get(self, session_id: str) -> ConversationMemory:
        """Return the memory for ``session_id``, creating it on first access."""
        try:
            return self._sessions[session_id]
        except KeyError:
            # First access for this id: create a memory with the store's limits.
            memory = ConversationMemory(
                max_turns=self._max_turns,
                keep_recent=self._keep_recent,
            )
            self._sessions[session_id] = memory
            return memory

    def clear(self, session_id: str) -> None:
        """Drop the memory for ``session_id``; unknown ids are ignored."""
        self._sessions.pop(session_id, None)
|