# AgenticAI-RAG / src/memory/short_term_memory.py
# Hosting-page residue: uploaded by GreymanT ("Upload 80 files"), commit 8bf4d58, verified.
"""Short-term memory for conversation context."""
import logging
from typing import List, Dict, Optional, Any
from datetime import datetime
import tiktoken
from src.core.config import get_settings
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class Message:
    """Represents a single message in the conversation.

    Attributes:
        role: Speaker of the message ('user', 'assistant', 'system').
        content: Text of the message.
        timestamp: Creation time; ``datetime.now()`` when not supplied.
        metadata: Free-form extra data; an empty dict when not supplied.
    """

    def __init__(
        self,
        role: str,
        content: str,
        timestamp: Optional[datetime] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        """Initialize a message.

        Args:
            role: Message role ('user', 'assistant', 'system').
            content: Message content.
            timestamp: Creation time; defaults to the current local time.
            metadata: Optional metadata. A non-empty dict is stored as
                given (not copied); a falsy value becomes a new empty dict.
        """
        self.role = role  # 'user', 'assistant', 'system'
        self.content = content
        self.timestamp = timestamp or datetime.now()
        self.metadata = metadata or {}

    def __repr__(self) -> str:
        """Return a debugging representation of the message."""
        return (
            f"{type(self).__name__}(role={self.role!r}, "
            f"content={self.content!r}, timestamp={self.timestamp!r})"
        )

    def to_dict(self) -> Dict[str, Any]:
        """Convert message to dictionary.

        Returns:
            A dict with keys ``role``, ``content``, ``timestamp`` (ISO 8601
            string) and ``metadata``.
        """
        return {
            "role": self.role,
            "content": self.content,
            "timestamp": self.timestamp.isoformat(),
            # Shallow copy so edits to the serialized dict do not leak back
            # into the live message (and vice versa).
            "metadata": dict(self.metadata),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Message":
        """Create message from dictionary.

        Args:
            data: Mapping with required keys ``role`` and ``content`` and
                optional keys ``timestamp`` and ``metadata``. A string
                timestamp is parsed with ``datetime.fromisoformat``; any
                other value is passed through unchanged.

        Returns:
            A new ``Message``.

        Raises:
            KeyError: If ``role`` or ``content`` is missing.
            ValueError: If a string timestamp is not valid ISO format.
        """
        raw_timestamp = data.get("timestamp")
        if isinstance(raw_timestamp, str):
            timestamp = datetime.fromisoformat(raw_timestamp)
        else:
            timestamp = raw_timestamp
        return cls(
            role=data["role"],
            content=data["content"],
            timestamp=timestamp,
            # Copy so the message does not share state with the input dict;
            # a missing or None entry becomes an empty dict.
            metadata=dict(data.get("metadata") or {}),
        )
class ShortTermMemory:
    """Manages short-term conversation memory with token-aware windowing.

    Messages are kept in chronological order and bounded both by a maximum
    message count and by a token budget. Token counts are computed from
    message content only (role labels are not counted) using a tiktoken
    encoding selected for ``model``.
    """

    # How many trailing messages ``summarize`` lists, and how many characters
    # of each message it shows before truncating.
    _SUMMARY_MESSAGES = 5
    _SUMMARY_PREVIEW_CHARS = 100

    def __init__(
        self,
        max_messages: Optional[int] = None,
        max_tokens: Optional[int] = None,
        model: str = "gpt-4",
    ):
        """Initialize short-term memory.

        Args:
            max_messages: Maximum number of messages to keep. A falsy value
                (None or 0) falls back to ``settings.short_term_memory_size``.
            max_tokens: Token budget for stored messages. A falsy value
                (None or 0) falls back to ``settings.max_context_tokens``.
            model: Model name used to pick the tiktoken encoding.
        """
        self.settings = get_settings()
        self.max_messages = max_messages or self.settings.short_term_memory_size
        self.max_tokens = max_tokens or self.settings.max_context_tokens
        self.model = model
        try:
            self.encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            # Fallback to cl100k_base encoding
            self.encoding = tiktoken.get_encoding("cl100k_base")
        self.messages: List["Message"] = []

    def add_message(
        self,
        role: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Add a message to memory, then trim to the configured limits.

        Args:
            role: Message role ('user', 'assistant', 'system')
            content: Message content
            metadata: Optional metadata
        """
        self.messages.append(
            Message(role=role, content=content, metadata=metadata)
        )
        self._trim_if_needed()

    def get_messages(
        self,
        include_metadata: bool = False,
        format_for_llm: bool = True,
    ) -> List[Dict[str, Any]]:
        """Get messages in memory.

        Args:
            include_metadata: Whether to include metadata. Only consulted
                when ``format_for_llm`` is False.
            format_for_llm: Format as OpenAI chat format, i.e. dicts with
                only ``role`` and ``content``.

        Returns:
            List of message dicts, oldest first.
        """
        if format_for_llm:
            return [
                {"role": msg.role, "content": msg.content}
                for msg in self.messages
            ]
        if include_metadata:
            return [msg.to_dict() for msg in self.messages]
        return [
            {
                "role": msg.role,
                "content": msg.content,
                "timestamp": msg.timestamp.isoformat(),
            }
            for msg in self.messages
        ]

    def get_context(self, max_tokens: Optional[int] = None) -> str:
        """Get conversation context as a formatted string.

        Args:
            max_tokens: Maximum content tokens to include. A falsy value
                (None or 0) falls back to ``self.max_tokens``. The
                ``role: `` prefixes and newlines added here are not counted.

        Returns:
            The most recent messages that fit, one ``role: content`` line
            per message, oldest first.
        """
        budget = max_tokens or self.max_tokens
        return "\n".join(
            f"{msg.role}: {msg.content}"
            for msg in self._get_messages_within_token_limit(budget)
        )

    def count_tokens(self, text: str) -> int:
        """Count tokens in text using this memory's encoding."""
        return len(self.encoding.encode(text))

    def get_total_tokens(self) -> int:
        """Get total content tokens across the current messages."""
        return sum(self.count_tokens(msg.content) for msg in self.messages)

    def _get_messages_within_token_limit(
        self, max_tokens: int
    ) -> List["Message"]:
        """Get the most recent run of messages that fits within a token limit.

        Walks backwards from the newest message and stops at the first
        message that would exceed ``max_tokens``; older messages are never
        considered after that point, so the result is always a contiguous
        tail of ``self.messages``.

        Args:
            max_tokens: Content-token budget.

        Returns:
            The selected messages in chronological order (possibly empty).
        """
        used_tokens = 0
        newest_first: List["Message"] = []
        for msg in reversed(self.messages):
            msg_tokens = self.count_tokens(msg.content)
            if used_tokens + msg_tokens > max_tokens:
                break
            newest_first.append(msg)
            used_tokens += msg_tokens
        # Appending and reversing once avoids the quadratic cost of
        # repeatedly inserting at the front of a list.
        newest_first.reverse()
        return newest_first

    def _trim_if_needed(self) -> None:
        """Trim messages if they exceed the count or token limits.

        The count limit is applied first, then the token limit. If the
        newest message alone exceeds ``self.max_tokens``, no message fits
        and the memory ends up empty.
        """
        # Trim by message count
        if len(self.messages) > self.max_messages:
            self.messages = self.messages[-self.max_messages:]
        # Trim by token count
        if self.get_total_tokens() > self.max_tokens:
            self.messages = self._get_messages_within_token_limit(
                self.max_tokens
            )

    def clear(self) -> None:
        """Clear all messages."""
        self.messages = []

    def summarize(self) -> str:
        """Create a summary of the conversation.

        Returns:
            ``"No conversation history."`` when empty; otherwise a header
            line with the message count followed by one line per message for
            the last few messages. Each line shows at most the first 100
            characters of the content, with ``...`` appended only when the
            content was actually truncated.
        """
        if not self.messages:
            return "No conversation history."
        limit = self._SUMMARY_PREVIEW_CHARS
        summary_parts = [
            f"Conversation with {len(self.messages)} messages:",
        ]
        for msg in self.messages[-self._SUMMARY_MESSAGES:]:
            ellipsis = "..." if len(msg.content) > limit else ""
            summary_parts.append(
                f"- {msg.role}: {msg.content[:limit]}{ellipsis}"
            )
        return "\n".join(summary_parts)

    def to_dict(self) -> Dict[str, Any]:
        """Convert memory to dictionary.

        Returns:
            A dict with keys ``messages`` (list of message dicts),
            ``max_messages``, ``max_tokens`` and ``model``.
        """
        return {
            "messages": [msg.to_dict() for msg in self.messages],
            "max_messages": self.max_messages,
            "max_tokens": self.max_tokens,
            # Persist the model so a restored memory counts tokens with the
            # same encoding it was created with.
            "model": self.model,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ShortTermMemory":
        """Create memory from dictionary.

        Args:
            data: Mapping as produced by ``to_dict``. All keys are optional:
                missing or falsy limits fall back to the settings values, a
                missing or falsy ``model`` uses the constructor default, and
                missing ``messages`` yields an empty memory.

        Returns:
            A new ``ShortTermMemory``. Loaded messages are stored as-is; no
            trimming is applied until the next ``add_message`` call.
        """
        init_kwargs: Dict[str, Any] = {
            "max_messages": data.get("max_messages"),
            "max_tokens": data.get("max_tokens"),
        }
        model = data.get("model")
        if model:
            # Only override when present so payloads written before the
            # ``model`` key existed keep the constructor default.
            init_kwargs["model"] = model
        memory = cls(**init_kwargs)
        memory.messages = [
            Message.from_dict(msg_data)
            for msg_data in data.get("messages", [])
        ]
        return memory