MultiModalRag / utils /memory.py
irajkoohi's picture
chore: update app [space deploy]
6c21523
Raw
History Blame Contribute Delete
3.89 kB
"""
Conversation memory manager with context window overflow handling.
Keeps a sliding window of recent messages and summarizes older ones
when the token budget is exceeded.
"""
import logging
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field
logger = logging.getLogger(__name__)

# Approximate token budget reserved for conversation history; the rest of the
# prompt is left for the system message and retrieved context.
MAX_HISTORY_TOKENS = 1200
# Heuristic characters-per-token ratio used by estimate_tokens (rough approximation).
AVG_CHARS_PER_TOKEN = 4
# Individual messages longer than this are truncated before being stored,
# to keep a single long turn from bloating the context.
MAX_STORED_CHARS = 800


def estimate_tokens(text: str) -> int:
    """Return a rough token count for *text*, never less than 1.

    The estimate is the character length floor-divided by
    ``AVG_CHARS_PER_TOKEN``; an empty or very short string counts as 1 token.
    """
    approx = len(text) // AVG_CHARS_PER_TOKEN
    return approx if approx > 0 else 1
@dataclass
class Message:
    """One chat turn: who said it and what was said."""

    role: str  # "user" or "assistant" (not validated here)
    content: str

    def token_count(self) -> int:
        """Approximate number of tokens in ``content`` (via ``estimate_tokens``)."""
        return estimate_tokens(self.content)

    def to_dict(self) -> Dict[str, str]:
        """Serialize as a ``{"role": ..., "content": ...}`` chat-message mapping."""
        return dict(role=self.role, content=self.content)
class ConversationMemory:
    """Bounded chat history for building prompts.

    Recent messages are kept verbatim in ``messages``; older ones are folded
    into ``summary``, a plain-text list with one bullet line per compressed
    message. The summary is trimmed (oldest bullets first) so that the
    estimated size of summary + kept messages stays within ``max_tokens``
    whenever the kept messages alone fit.
    """

    # Characters of each compressed message copied into its summary bullet.
    _SNIPPET_CHARS = 200

    def __init__(self, max_tokens: int = MAX_HISTORY_TOKENS):
        self.max_tokens = max_tokens
        self.messages: List[Message] = []
        self.summary: Optional[str] = None  # compressed older history, one bullet per line

    def add(self, role: str, content: str) -> None:
        """Append a message and compress the history if it grew too large.

        Content longer than ``MAX_STORED_CHARS`` is truncated and marked with
        a trailing " …" before being stored.
        """
        stored = content if len(content) <= MAX_STORED_CHARS else content[:MAX_STORED_CHARS] + " …"
        self.messages.append(Message(role=role, content=stored))
        self._maybe_compress()

    def _total_tokens(self) -> int:
        """Estimated tokens of all kept messages plus the summary (if any)."""
        t = sum(m.token_count() for m in self.messages)
        if self.summary:
            t += estimate_tokens(self.summary)
        return t

    def _maybe_compress(self) -> None:
        """Compress when over the token budget or when more than 6 messages are kept.

        Always keeps at least the 4 most recent messages (or the newer half
        when there are more), folds the rest into the summary, then trims the
        summary so the whole history fits ``max_tokens``.
        """
        if self._total_tokens() <= self.max_tokens and len(self.messages) <= 6:
            return
        # Keep at least the last 4 messages (roughly the last two exchanges).
        keep_n = max(4, len(self.messages) // 2)
        to_compress = self.messages[:-keep_n]
        self.messages = self.messages[-keep_n:]
        if to_compress:
            # Simple bullet-point summary (no LLM call, to avoid circular deps).
            new_summary = "\n".join(self._summary_line(m) for m in to_compress)
            self.summary = f"{self.summary}\n{new_summary}" if self.summary else new_summary
            logger.info(
                "Compressed %d messages. History size: %d",
                len(to_compress),
                len(self.messages),
            )
        # Run even when nothing was folded: an oversized summary can be the
        # sole reason the budget is exceeded.
        self._trim_summary()

    @classmethod
    def _summary_line(cls, m: Message) -> str:
        """Render one message as a single-line summary bullet."""
        snippet = m.content[:cls._SNIPPET_CHARS].replace("\n", " ")
        ellipsis = "..." if len(m.content) > cls._SNIPPET_CHARS else ""
        return f"- [{m.role.upper()}]: {snippet}{ellipsis}"

    def _trim_summary(self) -> None:
        """Drop the oldest summary lines until summary + messages fit ``max_tokens``.

        Without this the summary gains one line per compressed message forever,
        so long conversations would eventually exceed the budget this class is
        meant to enforce. If not even one line fits next to the kept messages,
        the summary is discarded (set to None).
        """
        if not self.summary:
            return
        budget = self.max_tokens - sum(m.token_count() for m in self.messages)
        lines = self.summary.split("\n")
        dropped = 0
        while lines and estimate_tokens("\n".join(lines)) > budget:
            lines.pop(0)  # oldest bullets come first
            dropped += 1
        if dropped:
            logger.debug("Dropped %d oldest summary lines to stay within token budget", dropped)
        self.summary = "\n".join(lines) or None

    def get_history_for_prompt(self) -> List[Dict[str, str]]:
        """Return message list suitable for Ollama chat API.

        When a summary exists it is injected first as a synthetic user message
        followed by a fixed assistant acknowledgement, then the kept messages.
        """
        result = []
        if self.summary:
            result.append({
                "role": "user",
                "content": f"[Previous conversation summary]\n{self.summary}",
            })
            result.append({
                "role": "assistant",
                "content": "I understand the previous context.",
            })
        result.extend(m.to_dict() for m in self.messages)
        return result

    def clear(self) -> None:
        """Forget all kept messages and the summary."""
        self.messages = []
        self.summary = None

    def to_gradio_format(self) -> List[List[Optional[str]]]:
        """Convert to Gradio chatbot format [[user, bot], ...].

        Only kept messages are rendered (the summary is not). A user message
        is paired with the immediately following assistant message, if any;
        unpaired messages get ``None`` in the missing slot.
        """
        pairs: List[List[Optional[str]]] = []
        msgs = self.messages
        i = 0
        while i < len(msgs):
            if msgs[i].role == "user":
                has_reply = i + 1 < len(msgs) and msgs[i + 1].role == "assistant"
                bot_msg = msgs[i + 1].content if has_reply else None
                pairs.append([msgs[i].content, bot_msg])
                i += 2 if has_reply else 1
            else:
                pairs.append([None, msgs[i].content])
                i += 1
        return pairs