File size: 12,250 Bytes
5374a2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 |
import json
import hashlib
import asyncio
from uuid import uuid4
from typing import Union, List, Dict, Optional, Tuple
from pydantic import Field
from .memory import BaseMemory
from evoagentx.rag import RAGConfig, RAGEngine
from evoagentx.rag.schema import Corpus, Chunk, ChunkMetadata, Query, RagResult
from evoagentx.storages.base import StorageHandler
from evoagentx.core.message import Message
from evoagentx.core.logging import logger
class LongTermMemory(BaseMemory):
    """
    Manages long-term storage and retrieval of memories, integrating with RAGEngine for indexing
    and StorageHandler for persistence.

    Each stored message becomes one ``Chunk`` whose ``chunk_id`` equals its ``memory_id``.
    The synchronous ``delete``, ``update`` and ``search`` methods drive the async API through
    ``asyncio.run`` and therefore must not be called from inside a running event loop.
    """

    storage_handler: StorageHandler = Field(..., description="Handler for persistent storage")
    rag_config: RAGConfig = Field(..., description="Configuration for RAG engine")
    rag_engine: Optional[RAGEngine] = Field(default=None, description="RAG engine for indexing and retrieval")
    memory_table: str = Field(default="memory", description="Database table for storing memories")
    default_corpus_id: Optional[str] = Field(default=None, description="Default corpus ID for memory indexing")

    def init_module(self):
        """Initialize the RAG engine and memory indices.

        Creates a ``RAGEngine`` from ``rag_config``/``storage_handler`` when none was supplied and
        generates a random corpus id when ``default_corpus_id`` is unset.
        """
        super().init_module()
        if self.rag_engine is None:
            self.rag_engine = RAGEngine(config=self.rag_config, storage_handler=self.storage_handler)
        if self.default_corpus_id is None:
            self.default_corpus_id = str(uuid4())
        logger.info(f"Initialized LongTermMemory with corpus_id {self.default_corpus_id}")

    @staticmethod
    def _content_hash(content) -> str:
        """Return the SHA-256 hex digest of ``str(content)``, used as the deduplication key."""
        return hashlib.sha256(str(content).encode()).hexdigest()

    def _create_memory_chunk(self, message: Message, memory_id: str) -> Chunk:
        """Convert a Message to a Chunk for RAG indexing.

        The chunk text is ``str(message.content)``; the original content is kept JSON-encoded in
        the metadata so that ``_chunk_to_message`` can restore it.
        """
        text = str(message.content)
        metadata = ChunkMetadata(
            corpus_id=self.default_corpus_id,
            memory_id=memory_id,
            timestamp=message.timestamp,
            action=message.action,
            wf_goal=message.wf_goal,
            agent=message.agent,
            msg_type=message.msg_type.value if message.msg_type else None,
            prompt=message.prompt,
            next_actions=message.next_actions,
            wf_task=message.wf_task,
            wf_task_desc=message.wf_task_desc,
            message_id=message.message_id,
            content=json.dumps(message.content),
        )
        return Chunk(
            chunk_id=memory_id,
            text=text,
            metadata=metadata,
            start_char_idx=0,
            end_char_idx=len(text),
        )

    def _chunk_to_message(self, chunk: Chunk) -> Message:
        """Convert a Chunk to a Message object.

        ``_create_memory_chunk`` stores the content JSON-encoded, so it is decoded here; a value
        that is not valid JSON is passed through unchanged.
        """
        content = chunk.metadata.content
        if isinstance(content, str):
            try:
                content = json.loads(content)
            except ValueError:
                # Not JSON (json.JSONDecodeError subclasses ValueError): keep the raw string.
                pass
        return Message(
            content=content,
            action=chunk.metadata.action,
            wf_goal=chunk.metadata.wf_goal,
            timestamp=chunk.metadata.timestamp,
            agent=chunk.metadata.agent,
            msg_type=chunk.metadata.msg_type,
            prompt=chunk.metadata.prompt,
            next_actions=chunk.metadata.next_actions,
            wf_task=chunk.metadata.wf_task,
            wf_task_desc=chunk.metadata.wf_task_desc,
            message_id=chunk.metadata.message_id,
        )

    def add(self, messages: Union[Message, str, List[Union[Message, str]]]) -> List[str]:
        """Store messages in memory and index them in RAGEngine, returning memory_ids.

        Strings are wrapped in ``Message`` objects and messages with empty content are dropped.
        A message whose content hash matches a record in ``memory_table`` (or an earlier message
        of the same call) is not stored again; the id of the existing memory is returned in its
        place when one is known.

        Returns:
            The memory_ids of the new and the already-known memories, in input order.
        """
        if not isinstance(messages, list):
            messages = [messages]
        messages = [Message(content=msg) if isinstance(msg, str) else msg for msg in messages]
        messages = [msg for msg in messages if msg.content]  # Filter out empty messages
        if not messages:
            logger.warning("No valid messages to add")
            return []

        # Hash-based deduplication: load the table once and map each hash to the first
        # memory_id recorded for it.
        records = self.storage_handler.load(tables=[self.memory_table]).get(self.memory_table, [])
        known_ids: Dict[str, Optional[str]] = {}
        for record in records:
            if "content_hash" in record:
                known_ids.setdefault(record["content_hash"], record.get("memory_id"))

        final_messages = []
        final_memory_ids = []
        final_chunks = []
        for msg in messages:
            content_hash = self._content_hash(msg.content)
            if content_hash in known_ids:
                # Content may be any type, so stringify before slicing for the log line.
                logger.info(f"Duplicate message found (hash): {str(msg.content)[:50]}...")
                existing_id = known_ids[content_hash]
                if existing_id:
                    final_memory_ids.append(existing_id)
                continue
            memory_id = str(uuid4())
            known_ids[content_hash] = memory_id  # also catches repeats later in this batch
            final_messages.append(msg)
            final_memory_ids.append(memory_id)
            chunk = self._create_memory_chunk(msg, memory_id)
            chunk.metadata.content_hash = content_hash
            final_chunks.append(chunk)

        if not final_chunks:
            logger.info("No messages added after deduplication")
            return final_memory_ids

        # Add to in-memory messages
        for msg in final_messages:
            super().add_message(msg)

        # Index in RAGEngine
        corpus = Corpus(chunks=final_chunks, corpus_id=self.default_corpus_id)
        chunk_ids = self.rag_engine.add(
            index_type=self.rag_config.index.index_type,
            nodes=corpus,
            corpus_id=self.default_corpus_id,
        )
        if not chunk_ids:
            logger.error("Failed to index memories")
        return final_memory_ids

    async def get(self, memory_ids: Union[str, List[str]], return_chunk: bool = True) -> List[Tuple[Union[Chunk, Message], str]]:
        """Retrieve memories by memory_ids, returning (Message/Chunk, memory_id) tuples.

        Args:
            memory_ids: One id or a list of ids to look up in the RAG index.
            return_chunk: Return the raw ``Chunk`` when True, a reconstructed ``Message`` otherwise.

        Returns:
            One tuple per chunk the engine returned; an empty list when no ids were given or the
            lookup raised (the error is logged, not propagated).
        """
        if not isinstance(memory_ids, list):
            memory_ids = [memory_ids]
        if not memory_ids:
            logger.warning("No memory_ids provided for get")
            return []
        try:
            chunks = await self.rag_engine.aget(
                corpus_id=self.default_corpus_id,
                index_type=self.rag_config.index.index_type,
                node_ids=memory_ids,
            )
            results = [
                (chunk if return_chunk else self._chunk_to_message(chunk), chunk.metadata.memory_id)
                for chunk in chunks
                if chunk
            ]
            logger.info(f"Retrieved {len(results)} memories for memory_ids: {memory_ids}")
            return results
        except Exception as e:
            logger.error(f"Failed to get memories: {str(e)}")
            return []

    def delete(self, memory_ids: Union[str, List[str]]) -> List[bool]:
        """Delete memories by memory_ids, returning success status for each.

        Returns:
            A list aligned with ``memory_ids``: True where the memory was found and removed from
            the in-memory messages and the RAG index, False where it was not found.
        """
        if not isinstance(memory_ids, list):
            memory_ids = [memory_ids]
        if not memory_ids:
            logger.warning("No memory_ids provided for deletion")
            return []

        existing_chunks = asyncio.run(self.get(memory_ids, return_chunk=True))
        valid_memory_ids = []
        for chunk, mid in existing_chunks:
            if chunk:
                valid_memory_ids.append(mid)
                super().remove_message(self._chunk_to_message(chunk))

        # `get` omits missing ids, so match by id rather than by position in its result.
        found = set(valid_memory_ids)
        successes = [mid in found for mid in memory_ids]
        if not valid_memory_ids:
            logger.info("No memories found for deletion")
            return successes

        # Remove from RAG index
        self.rag_engine.delete(
            corpus_id=self.default_corpus_id,
            index_type=self.rag_config.index.index_type,
            node_ids=valid_memory_ids,
        )
        return successes

    def update(self, updates: Union[Tuple[str, Union[Message, str]], List[Tuple[str, Union[Message, str]]]]) -> List[bool]:
        """Update memories with new content, returning success status for each.

        Args:
            updates: One ``(memory_id, new_message)`` pair or a list of them. Pairs with empty
                content are ignored; when a memory_id occurs more than once the last pair wins.

        Returns:
            An empty list when no pair has content. Otherwise a list aligned with ``updates``:
            True at the position of every pair that was applied. All False when re-indexing fails.
        """
        if not isinstance(updates, list):
            updates = [updates]
        updates = [(mid, Message(content=msg) if isinstance(msg, str) else msg) for mid, msg in updates]

        # Track the input position of the pair that is applied for each memory_id so the result
        # can be reported against the caller's original list.
        pending: Dict[str, Tuple[int, Message]] = {}
        for position, (mid, msg) in enumerate(updates):
            if msg.content:
                pending[mid] = (position, msg)
        if not pending:
            logger.warning("No valid updates provided")
            return []

        existing_memories = asyncio.run(self.get(list(pending), return_chunk=False))
        existing_dict = {mid: msg for msg, mid in existing_memories}
        successes = [False] * len(updates)
        final_updates = []
        final_memory_ids = []
        for mid, (position, msg) in pending.items():
            if mid not in existing_dict:
                logger.warning(f"No memory found with memory_id {mid}")
                continue
            final_updates.append((mid, msg))
            final_memory_ids.append(mid)
            successes[position] = True
            super().remove_message(existing_dict[mid])
        if not final_updates:
            logger.info("No memories updated")
            return successes

        chunks = []
        for mid, msg in final_updates:
            chunk = self._create_memory_chunk(msg, mid)
            # Keep the deduplication hash in step with the new content, as `add` does.
            chunk.metadata.content_hash = self._content_hash(msg.content)
            chunks.append(chunk)
        for _, msg in final_updates:
            super().add_message(msg)

        corpus = Corpus(chunks=chunks, corpus_id=self.default_corpus_id)
        chunk_ids = self.rag_engine.add(
            index_type=self.rag_config.index.index_type,
            nodes=corpus,
            corpus_id=self.default_corpus_id,
        )
        if not chunk_ids:
            logger.error(f"Failed to update memories in RAG index: {final_memory_ids}")
            return [False] * len(updates)
        return successes

    async def search_async(self, query: Union[str, Query], n: Optional[int] = None,
                           metadata_filters: Optional[Dict] = None,
                           return_chunk=False) -> List[Tuple[Union[Chunk, Message], str]]:
        """Retrieve messages from RAG index asynchronously based on a query, returning messages and memory_ids.

        Args:
            query: Query text, or a ``Query`` object (its ``top_k`` is overwritten and
                ``metadata_filters`` are merged into it in place).
            n: Maximum number of results; falls back to ``rag_config.retrieval.top_k``.
            metadata_filters: Extra filters; they take precedence over filters already on ``query``.
            return_chunk: Return raw ``Chunk`` objects instead of reconstructed ``Message`` objects.

        Returns:
            ``(Message or Chunk, memory_id)`` tuples; an empty list if the query raised (the error
            is logged, not propagated).
        """
        if isinstance(query, str):
            query_obj = Query(
                query_str=query,
                top_k=n or self.rag_config.retrieval.top_k,
                metadata_filters=metadata_filters or {},
            )
        else:
            query_obj = query
            query_obj.top_k = n or self.rag_config.retrieval.top_k
            if metadata_filters:
                if query_obj.metadata_filters:
                    query_obj.metadata_filters = {**query_obj.metadata_filters, **metadata_filters}
                else:
                    query_obj.metadata_filters = metadata_filters
        try:
            result: RagResult = await self.rag_engine.query_async(query_obj, corpus_id=self.default_corpus_id)
            if return_chunk:
                results = [(chunk, chunk.metadata.memory_id) for chunk in result.corpus.chunks]
            else:
                results = [(self._chunk_to_message(chunk), chunk.metadata.memory_id) for chunk in result.corpus.chunks]
            logger.info(f"Retrieved {len(results)} memories for query: {query_obj.query_str}")
            return results[:n] if n else results
        except Exception as e:
            logger.error(f"Failed to search memories: {str(e)}")
            return []

    def search(self, query: Union[str, Query], n: Optional[int] = None,
               metadata_filters: Optional[Dict] = None,
               return_chunk: bool = False) -> List[Tuple[Union[Chunk, Message], str]]:
        """Synchronous wrapper for searching memories.

        Runs ``search_async`` with ``asyncio.run``; see that method for the arguments.
        """
        return asyncio.run(self.search_async(query, n, metadata_filters, return_chunk=return_chunk))

    def clear(self) -> None:
        """Clear all messages and indices."""
        super().clear()
        self.rag_engine.clear(corpus_id=self.default_corpus_id)
        logger.info(f"Cleared LongTermMemory with corpus_id {self.default_corpus_id}")

    def save(self, save_path: Optional[str] = None) -> None:
        """Save all indices and memory data to database."""
        self.rag_engine.save(output_path=save_path, corpus_id=self.default_corpus_id, table=self.memory_table)

    def load(self, save_path: Optional[str] = None) -> List[str]:
        """Load memory data from database and reconstruct indices, returning memory_ids."""
        return self.rag_engine.load(source=save_path, corpus_id=self.default_corpus_id, table=self.memory_table)