"""Short-term memory for conversation context."""
import logging
from typing import List, Dict, Optional, Any
from datetime import datetime
import tiktoken
from src.core.config import get_settings
logger = logging.getLogger(__name__)
class Message:
    """A single conversational turn: role, content, timestamp, and metadata."""

    def __init__(
        self,
        role: str,
        content: str,
        timestamp: Optional[datetime] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        """Store the message fields, defaulting timestamp to now and metadata to {}."""
        # Expected roles: 'user', 'assistant', 'system'.
        self.role = role
        self.content = content
        # NOTE(review): naive local time, not timezone-aware — confirm callers expect this.
        self.timestamp = timestamp if timestamp else datetime.now()
        self.metadata = metadata if metadata else {}

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict; the timestamp becomes an ISO-8601 string."""
        return {
            "role": self.role,
            "content": self.content,
            "timestamp": self.timestamp.isoformat(),
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Message":
        """Rebuild a Message from a dict shaped like the output of to_dict.

        Accepts the timestamp either as an ISO-8601 string or as an
        already-parsed datetime (or None, which defaults to now).
        """
        raw_ts = data.get("timestamp")
        ts = datetime.fromisoformat(raw_ts) if isinstance(raw_ts, str) else raw_ts
        return cls(
            role=data["role"],
            content=data["content"],
            timestamp=ts,
            metadata=data.get("metadata", {}),
        )
class ShortTermMemory:
    """Manages short-term conversation memory with token-aware windowing.

    Messages are kept in chronological order and trimmed after each append:
    first by message count (``max_messages``), then by total content tokens
    (``max_tokens``), always dropping the oldest messages first.
    """

    def __init__(
        self,
        max_messages: Optional[int] = None,
        max_tokens: Optional[int] = None,
        model: str = "gpt-4",
    ):
        """Initialize short-term memory.

        Args:
            max_messages: Maximum number of stored messages; defaults to
                ``settings.short_term_memory_size``.
            max_tokens: Maximum total content tokens; defaults to
                ``settings.max_context_tokens``.
            model: Model name used to select the tiktoken encoding.
        """
        self.settings = get_settings()
        self.max_messages = max_messages or self.settings.short_term_memory_size
        self.max_tokens = max_tokens or self.settings.max_context_tokens
        self.model = model
        try:
            self.encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            # Unknown model name: fall back to the general-purpose cl100k_base encoding.
            self.encoding = tiktoken.get_encoding("cl100k_base")
        self.messages: List[Message] = []

    def add_message(
        self,
        role: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Append a message, then trim memory back within configured limits.

        Args:
            role: Message role ('user', 'assistant', 'system')
            content: Message content
            metadata: Optional metadata
        """
        message = Message(role=role, content=content, metadata=metadata)
        self.messages.append(message)
        self._trim_if_needed()

    def get_messages(
        self,
        include_metadata: bool = False,
        format_for_llm: bool = True,
    ) -> List[Dict[str, Any]]:
        """Return stored messages as dictionaries, oldest first.

        Args:
            include_metadata: Include the metadata field. Only honored when
                ``format_for_llm`` is False — the LLM format always omits it.
            format_for_llm: Return OpenAI chat format ({"role", "content"}
                only); takes precedence over ``include_metadata``.

        Returns:
            List of message dictionaries.
        """
        if format_for_llm:
            return [
                {"role": msg.role, "content": msg.content}
                for msg in self.messages
            ]
        if include_metadata:
            return [msg.to_dict() for msg in self.messages]
        return [
            {
                "role": msg.role,
                "content": msg.content,
                "timestamp": msg.timestamp.isoformat(),
            }
            for msg in self.messages
        ]

    def get_context(self, max_tokens: Optional[int] = None) -> str:
        """Return the conversation as "role: content" lines within a token budget.

        Args:
            max_tokens: Maximum tokens to include; defaults to ``self.max_tokens``.

        Returns:
            Newline-joined context string built from the most recent messages
            that fit the budget, in chronological order.
        """
        max_tokens = max_tokens or self.max_tokens
        context_messages = self._get_messages_within_token_limit(max_tokens)
        return "\n".join(
            f"{msg.role}: {msg.content}" for msg in context_messages
        )

    def count_tokens(self, text: str) -> int:
        """Return the number of tokens in ``text`` under the active encoding."""
        return len(self.encoding.encode(text))

    def get_total_tokens(self) -> int:
        """Return the total token count over all stored message contents."""
        return sum(self.count_tokens(msg.content) for msg in self.messages)

    def _get_messages_within_token_limit(
        self, max_tokens: int
    ) -> List[Message]:
        """Return the longest recent run of messages whose contents fit ``max_tokens``.

        Walks from the most recent message backwards and stops at the first
        message that would overflow the budget; the result preserves
        chronological order.
        """
        total_tokens = 0
        selected_messages: List[Message] = []
        for msg in reversed(self.messages):
            msg_tokens = self.count_tokens(msg.content)
            if total_tokens + msg_tokens > max_tokens:
                break
            selected_messages.insert(0, msg)
            total_tokens += msg_tokens
        return selected_messages

    def _trim_if_needed(self) -> None:
        """Drop the oldest messages until count and token limits are satisfied."""
        # Trim by message count first (cheap), then by token count.
        if len(self.messages) > self.max_messages:
            self.messages = self.messages[-self.max_messages:]
        if self.get_total_tokens() > self.max_tokens:
            self.messages = self._get_messages_within_token_limit(self.max_tokens)

    def clear(self) -> None:
        """Remove all messages from memory."""
        self.messages = []

    def summarize(self) -> str:
        """Create a short textual summary of the conversation.

        Returns:
            A header line with the message count, followed by previews (first
            100 characters) of up to the last five messages.
        """
        if not self.messages:
            return "No conversation history."
        summary_parts = [
            f"Conversation with {len(self.messages)} messages:",
        ]
        for msg in self.messages[-5:]:  # Last 5 messages
            preview = msg.content[:100]
            # Fix: append the ellipsis only when the content was actually
            # truncated; the original marked every preview as truncated.
            ellipsis = "..." if len(msg.content) > 100 else ""
            summary_parts.append(f"- {msg.role}: {preview}{ellipsis}")
        return "\n".join(summary_parts)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize memory (messages and limits) to a dictionary."""
        return {
            "messages": [msg.to_dict() for msg in self.messages],
            "max_messages": self.max_messages,
            "max_tokens": self.max_tokens,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ShortTermMemory":
        """Rebuild memory from a ``to_dict``-style dictionary.

        Missing limits fall back to the settings defaults via ``__init__``.
        """
        memory = cls(
            max_messages=data.get("max_messages"),
            max_tokens=data.get("max_tokens"),
        )
        memory.messages = [
            Message.from_dict(msg_data)
            for msg_data in data.get("messages", [])
        ]
        return memory