Spaces:
Paused
Paused
| """ | |
| Conversation memory manager with context window overflow handling. | |
| Keeps a sliding window of recent messages and folds older ones into a | |
| bullet-point summary when the estimated token budget is exceeded or more than six messages are held. | |
| """ | |
| import logging | |
| from typing import List, Dict, Any, Optional | |
| from dataclasses import dataclass, field | |
logger = logging.getLogger(__name__)

# Estimated token budget for the conversation history only; the remainder of
# the model's context window is left for the system prompt and other context.
MAX_HISTORY_TOKENS = 1200

# Heuristic characters-per-token ratio used by estimate_tokens().
AVG_CHARS_PER_TOKEN = 4

# Upper bound on the characters kept from any single message; longer messages
# are truncated before being stored so one message cannot bloat the history.
MAX_STORED_CHARS = 800
def estimate_tokens(text: str) -> int:
    """Roughly estimate how many tokens *text* occupies.

    Divides the character count by AVG_CHARS_PER_TOKEN (rounding down) and
    never reports fewer than one token, even for an empty string.
    """
    approx = len(text) // AVG_CHARS_PER_TOKEN
    return approx if approx > 1 else 1
@dataclass
class Message:
    """A single chat message as stored in conversation memory.

    The ``@dataclass`` decorator generates the ``__init__`` that callers rely
    on (``Message(role=..., content=...)``); without it the class only has
    annotations and construction with arguments raises ``TypeError``.
    """

    role: str  # "user" or "assistant"
    content: str

    def to_dict(self) -> Dict[str, str]:
        """Return the message as a ``{"role": ..., "content": ...}`` dict."""
        return {"role": self.role, "content": self.content}

    def token_count(self) -> int:
        """Return the estimated token count of ``content`` (see estimate_tokens)."""
        return estimate_tokens(self.content)
class ConversationMemory:
    """Chat history: a sliding window of recent messages plus a digest of older ones.

    Recent messages are kept verbatim in ``messages``. When the estimated token
    total exceeds ``max_tokens`` or more than six messages are held, older
    messages are folded into ``summary`` as one truncated bullet line each.
    The summary is then trimmed oldest-line-first so that, together with the
    retained messages, the estimate stays within ``max_tokens``. Retained
    messages are never trimmed, so they take priority over the summary.
    """

    # Message count above which compression is forced (3 user/assistant exchanges).
    _MAX_MESSAGES = 6
    # Minimum number of most recent messages always kept verbatim.
    _MIN_KEEP = 4
    # Maximum characters of a message copied into its summary bullet.
    _SNIPPET_CHARS = 200

    def __init__(self, max_tokens: int = MAX_HISTORY_TOKENS):
        """Create an empty memory.

        Args:
            max_tokens: Estimated token budget shared by ``summary`` and ``messages``.
        """
        self.max_tokens = max_tokens
        self.messages: List[Message] = []
        # Compressed older history: one "- [ROLE]: snippet" line per message.
        self.summary: Optional[str] = None

    def add(self, role: str, content: str):
        """Store a message, truncating it to MAX_STORED_CHARS, then compress if needed.

        Args:
            role: Message role, e.g. "user" or "assistant".
            content: Message text; anything past MAX_STORED_CHARS is replaced by " …".
        """
        if len(content) <= MAX_STORED_CHARS:
            stored = content
        else:
            stored = content[:MAX_STORED_CHARS] + " …"
        self.messages.append(Message(role=role, content=stored))
        self._maybe_compress()

    def _total_tokens(self) -> int:
        """Return the estimated tokens of all retained messages plus the summary."""
        total = sum(m.token_count() for m in self.messages)
        if self.summary:
            total += estimate_tokens(self.summary)
        return total

    def _maybe_compress(self):
        """Compress when over the token budget or when more than 6 messages are held.

        The most recent ``max(4, len(messages) // 2)`` messages are kept
        verbatim; earlier ones are appended to ``summary`` as bullet lines.
        The summary is then trimmed so the total respects ``max_tokens``.
        """
        over_budget = self._total_tokens() > self.max_tokens
        if not over_budget and len(self.messages) <= self._MAX_MESSAGES:
            return

        # NOTE: keep_n is not aligned to exchange boundaries, so the kept
        # window may start with an assistant message.
        keep_n = max(self._MIN_KEEP, len(self.messages) // 2)
        to_compress = self.messages[:-keep_n]
        if to_compress:
            self.messages = self.messages[-keep_n:]
            # Plain text digest (no LLM call, to avoid circular deps).
            new_summary = "\n".join(self._summary_line(m) for m in to_compress)
            self.summary = f"{self.summary}\n{new_summary}" if self.summary else new_summary
            logger.info(
                "Compressed %d messages. History size: %d",
                len(to_compress),
                len(self.messages),
            )

        # Without this, the summary grows forever and the budget is never met.
        self._trim_summary()

    def _summary_line(self, m: Message) -> str:
        """Render one message as a single-line bullet, truncated to _SNIPPET_CHARS."""
        snippet = m.content[: self._SNIPPET_CHARS].replace("\n", " ")
        ellipsis = "..." if len(m.content) > self._SNIPPET_CHARS else ""
        return f"- [{m.role.upper()}]: {snippet}{ellipsis}"

    def _trim_summary(self) -> None:
        """Drop the oldest summary lines until summary + messages fit ``max_tokens``.

        Sets ``summary`` to None when no line can be kept (i.e. the retained
        messages alone already use up the budget).
        """
        if not self.summary:
            return
        message_tokens = sum(m.token_count() for m in self.messages)
        # Each summary line is one message; newlines are stripped from snippets.
        lines = self.summary.split("\n")
        start = 0
        while start < len(lines) and (
            message_tokens + estimate_tokens("\n".join(lines[start:])) > self.max_tokens
        ):
            start += 1
        kept = lines[start:]
        self.summary = "\n".join(kept) if kept else None

    def get_history_for_prompt(self) -> List[Dict[str, str]]:
        """Return message list suitable for Ollama chat API.

        If a summary exists, it is emitted first as a synthetic user message,
        followed by a fixed assistant acknowledgement, then the retained
        messages in order. NOTE: the summary header and acknowledgement text
        are not counted against ``max_tokens``.
        """
        result: List[Dict[str, str]] = []
        if self.summary:
            result.append({
                "role": "user",
                "content": f"[Previous conversation summary]\n{self.summary}",
            })
            result.append({
                "role": "assistant",
                "content": "I understand the previous context.",
            })
        result.extend(m.to_dict() for m in self.messages)
        return result

    def clear(self):
        """Forget all retained messages and the summary."""
        self.messages = []
        self.summary = None

    def to_gradio_format(self) -> List[List[Optional[str]]]:
        """Convert retained messages (not the summary) to Gradio chatbot pairs.

        Returns ``[[user, bot], ...]``. A user message without a following
        assistant reply yields ``[user, None]``; a non-user message not
        consumed as a reply yields ``[None, content]``.
        """
        pairs: List[List[Optional[str]]] = []
        i = 0
        msgs = self.messages
        while i < len(msgs):
            if msgs[i].role == "user":
                user_msg = msgs[i].content
                has_reply = i + 1 < len(msgs) and msgs[i + 1].role == "assistant"
                bot_msg = msgs[i + 1].content if has_reply else None
                pairs.append([user_msg, bot_msg])
                i += 2 if bot_msg is not None else 1
            else:
                pairs.append([None, msgs[i].content])
                i += 1
        return pairs