Spaces:
Sleeping
Sleeping
| """Short-term memory for conversation context.""" | |
| import logging | |
| from typing import List, Dict, Optional, Any | |
| from datetime import datetime | |
| import tiktoken | |
| from src.core.config import get_settings | |
| logger = logging.getLogger(__name__) | |
class Message:
    """Represents a single message in the conversation."""

    def __init__(
        self,
        role: str,
        content: str,
        timestamp: Optional[datetime] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        """Initialize a message.

        Args:
            role: Speaker role, e.g. 'user', 'assistant' or 'system'.
            content: Message text.
            timestamp: Creation time; when falsy, ``datetime.now()`` (naive
                local time) is used.
            metadata: Optional free-form metadata; when falsy, an empty dict.
        """
        self.role = role  # 'user', 'assistant', 'system'
        self.content = content
        self.timestamp = timestamp or datetime.now()
        self.metadata = metadata or {}

    def to_dict(self) -> Dict[str, Any]:
        """Convert the message to a dictionary.

        The timestamp is serialized with ``isoformat()``. The metadata dict is
        shallow-copied so mutating the result does not mutate this message.
        """
        return {
            "role": self.role,
            "content": self.content,
            "timestamp": self.timestamp.isoformat(),
            "metadata": dict(self.metadata),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Message":
        """Create a message from a dictionary produced by ``to_dict``.

        Args:
            data: Mapping with required keys ``role`` and ``content``, and
                optional ``timestamp`` (ISO-format string or ``datetime``)
                and ``metadata``.

        Returns:
            A new ``Message``. A missing timestamp falls back to "now", and
            missing or ``None`` metadata becomes an empty dict.

        Raises:
            KeyError: If ``role`` or ``content`` is missing.
            ValueError: If ``timestamp`` is a string that is not valid ISO
                format.
        """
        raw_timestamp = data.get("timestamp")
        # Strings come from to_dict(); datetime objects (or None) pass through.
        timestamp = (
            datetime.fromisoformat(raw_timestamp)
            if isinstance(raw_timestamp, str)
            else raw_timestamp
        )
        return cls(
            role=data["role"],
            content=data["content"],
            timestamp=timestamp,
            # Copy so the new message does not share the caller's dict.
            metadata=dict(data.get("metadata") or {}),
        )
class ShortTermMemory:
    """Manages short-term conversation memory with token-aware windowing.

    Messages are kept in chronological order. After every insertion the
    window is trimmed to the ``max_messages`` most recent messages. It is
    then trimmed to the longest most-recent run whose contents fit in
    ``max_tokens`` tokens. Only message content is tokenized; role names and
    chat-format overhead are not counted.
    """

    def __init__(
        self,
        max_messages: Optional[int] = None,
        max_tokens: Optional[int] = None,
        model: str = "gpt-4",
    ):
        """Initialize short-term memory.

        Args:
            max_messages: Maximum number of messages to retain. Falsy values
                fall back to ``settings.short_term_memory_size``.
            max_tokens: Maximum total content tokens to retain. Falsy values
                fall back to ``settings.max_context_tokens``.
            model: Model name used to select the tiktoken encoding.
        """
        self.settings = get_settings()
        self.max_messages = max_messages or self.settings.short_term_memory_size
        self.max_tokens = max_tokens or self.settings.max_context_tokens
        self.model = model
        try:
            self.encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            # Unknown model name: fall back to cl100k_base encoding.
            self.encoding = tiktoken.get_encoding("cl100k_base")
        self.messages: List[Message] = []

    def add_message(
        self,
        role: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Add a message to memory, then trim to the configured limits.

        Args:
            role: Message role ('user', 'assistant', 'system')
            content: Message content
            metadata: Optional metadata
        """
        message = Message(role=role, content=content, metadata=metadata)
        self.messages.append(message)
        self._trim_if_needed()

    def get_messages(
        self,
        include_metadata: bool = False,
        format_for_llm: bool = True,
    ) -> List[Dict[str, Any]]:
        """
        Get messages in memory as dictionaries.

        Args:
            include_metadata: Only consulted when ``format_for_llm`` is False.
                If True, each entry is ``Message.to_dict()``. Otherwise each
                entry has role, content and ISO timestamp.
            format_for_llm: If True, return OpenAI chat-format
                ``{"role", "content"}`` entries only.

        Returns:
            List of messages, oldest first.
        """
        if format_for_llm:
            return [
                {"role": msg.role, "content": msg.content}
                for msg in self.messages
            ]
        if include_metadata:
            return [msg.to_dict() for msg in self.messages]
        return [
            {
                "role": msg.role,
                "content": msg.content,
                "timestamp": msg.timestamp.isoformat(),
            }
            for msg in self.messages
        ]

    def get_context(self, max_tokens: Optional[int] = None) -> str:
        """
        Get conversation context as a formatted string.

        Args:
            max_tokens: Token budget for message contents. Falsy values fall
                back to ``self.max_tokens``.

        Returns:
            The most recent messages that fit the budget, one per line as
            ``"role: content"``, oldest first.
        """
        max_tokens = max_tokens or self.max_tokens
        context_messages = self._get_messages_within_token_limit(max_tokens)
        return "\n".join(
            f"{msg.role}: {msg.content}" for msg in context_messages
        )

    def count_tokens(self, text: str) -> int:
        """Count tokens in text using this memory's encoding."""
        return len(self.encoding.encode(text))

    def get_total_tokens(self) -> int:
        """Get total content tokens across current messages."""
        return sum(self.count_tokens(msg.content) for msg in self.messages)

    def _get_messages_within_token_limit(
        self, max_tokens: int
    ) -> List["Message"]:
        """Return the longest run of most-recent messages within ``max_tokens``.

        Walks backwards from the newest message and stops at the first one
        that would exceed the budget, so the result is always a contiguous
        suffix of ``self.messages`` in chronological order.
        """
        total_tokens = 0
        selected = []
        for msg in reversed(self.messages):
            msg_tokens = self.count_tokens(msg.content)
            if total_tokens + msg_tokens > max_tokens:
                break
            selected.append(msg)
            total_tokens += msg_tokens
        # Collected newest-first; append + reverse avoids O(n^2) insert(0, ...).
        selected.reverse()
        return selected

    def _trim_if_needed(self) -> None:
        """Trim messages if they exceed the message-count or token limits."""
        # Trim by message count
        if len(self.messages) > self.max_messages:
            self.messages = self.messages[-self.max_messages:]
        # Trim by token count. A single windowing pass suffices: if everything
        # fits, the window is the whole list and nothing is replaced.
        kept = self._get_messages_within_token_limit(self.max_tokens)
        if len(kept) < len(self.messages):
            self.messages = kept

    def clear(self) -> None:
        """Clear all messages."""
        self.messages = []

    def summarize(self, max_messages: int = 5, preview_chars: int = 100) -> str:
        """
        Create a short textual summary of the conversation.

        Args:
            max_messages: Number of most recent messages to preview.
            preview_chars: Maximum characters of each message to show. An
                ellipsis is appended only when content is truncated.

        Returns:
            Summary string
        """
        if not self.messages:
            return "No conversation history."
        summary_parts = [
            f"Conversation with {len(self.messages)} messages:",
        ]
        recent = self.messages[-max_messages:] if max_messages > 0 else []
        for msg in recent:
            preview = msg.content[:preview_chars]
            if len(msg.content) > preview_chars:
                preview += "..."
            summary_parts.append(f"- {msg.role}: {preview}")
        return "\n".join(summary_parts)

    def to_dict(self) -> Dict[str, Any]:
        """Convert memory (messages, limits and model name) to a dictionary."""
        return {
            "messages": [msg.to_dict() for msg in self.messages],
            "max_messages": self.max_messages,
            "max_tokens": self.max_tokens,
            "model": self.model,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ShortTermMemory":
        """Create memory from a dictionary produced by ``to_dict``.

        Missing limits fall back to settings (via ``__init__``). A missing
        ``model`` (dicts written before it was serialized) defaults to
        "gpt-4".
        """
        memory = cls(
            max_messages=data.get("max_messages"),
            max_tokens=data.get("max_tokens"),
            model=data.get("model", "gpt-4"),
        )
        memory.messages = [
            Message.from_dict(msg_data)
            for msg_data in data.get("messages", [])
        ]
        return memory