|
|
import json |
|
|
import hashlib |
|
|
import asyncio |
|
|
from uuid import uuid4 |
|
|
from typing import Union, List, Dict, Optional, Tuple |
|
|
|
|
|
from pydantic import Field |
|
|
|
|
|
from .memory import BaseMemory |
|
|
from evoagentx.rag import RAGConfig, RAGEngine |
|
|
from evoagentx.rag.schema import Corpus, Chunk, ChunkMetadata, Query, RagResult |
|
|
from evoagentx.storages.base import StorageHandler |
|
|
from evoagentx.core.message import Message |
|
|
from evoagentx.core.logging import logger |
|
|
|
|
|
class LongTermMemory(BaseMemory):
    """
    Manages long-term storage and retrieval of memories, integrating with RAGEngine for indexing
    and StorageHandler for persistence.

    Each memory is stored as a Chunk in a dedicated RAG corpus (``default_corpus_id``)
    and deduplicated by a SHA-256 hash of its string content.
    """

    storage_handler: StorageHandler = Field(..., description="Handler for persistent storage")
    rag_config: RAGConfig = Field(..., description="Configuration for RAG engine")
    # Optional: lazily constructed in init_module() when not supplied by the caller.
    rag_engine: Optional[RAGEngine] = Field(default=None, description="RAG engine for indexing and retrieval")
    memory_table: str = Field(default="memory", description="Database table for storing memories")
    default_corpus_id: Optional[str] = Field(default=None, description="Default corpus ID for memory indexing")

    def init_module(self):
        """Initialize the RAG engine and generate a corpus id if none was provided."""
        super().init_module()
        if self.rag_engine is None:
            self.rag_engine = RAGEngine(config=self.rag_config, storage_handler=self.storage_handler)
        if self.default_corpus_id is None:
            self.default_corpus_id = str(uuid4())
        logger.info(f"Initialized LongTermMemory with corpus_id {self.default_corpus_id}")

    def _create_memory_chunk(self, message: Message, memory_id: str) -> Chunk:
        """Convert a Message to a Chunk for RAG indexing.

        The full message payload is mirrored into the chunk metadata so the
        original Message can be reconstructed later by ``_chunk_to_message``.
        """
        metadata = ChunkMetadata(
            corpus_id=self.default_corpus_id,
            memory_id=memory_id,
            timestamp=message.timestamp,
            action=message.action,
            wf_goal=message.wf_goal,
            agent=message.agent,
            msg_type=message.msg_type.value if message.msg_type else None,
            prompt=message.prompt,
            next_actions=message.next_actions,
            wf_task=message.wf_task,
            wf_task_desc=message.wf_task_desc,
            message_id=message.message_id,
            # JSON-encode so structured (non-string) content survives the
            # metadata round-trip; decoded again in _chunk_to_message.
            content=json.dumps(message.content),
        )
        text = str(message.content)
        return Chunk(
            chunk_id=memory_id,
            text=text,
            metadata=metadata,
            start_char_idx=0,
            end_char_idx=len(text),
        )

    def _chunk_to_message(self, chunk: Chunk) -> Message:
        """Convert a Chunk back into a Message.

        Fix: ``_create_memory_chunk`` stores content as ``json.dumps(...)``;
        decode it here so callers get the original object back instead of a
        JSON-encoded string. Falls back to the raw value for legacy records
        whose content was never JSON-encoded.
        """
        try:
            content = json.loads(chunk.metadata.content)
        except (TypeError, ValueError):
            content = chunk.metadata.content
        return Message(
            content=content,
            action=chunk.metadata.action,
            wf_goal=chunk.metadata.wf_goal,
            timestamp=chunk.metadata.timestamp,
            agent=chunk.metadata.agent,
            msg_type=chunk.metadata.msg_type,
            prompt=chunk.metadata.prompt,
            next_actions=chunk.metadata.next_actions,
            wf_task=chunk.metadata.wf_task,
            wf_task_desc=chunk.metadata.wf_task_desc,
            message_id=chunk.metadata.message_id,
        )

    def add(self, messages: Union[Message, str, List[Union[Message, str]]]) -> List[str]:
        """Store messages in memory and index them in RAGEngine, returning memory_ids.

        Duplicate contents (by SHA-256 hash, both against stored records and
        within the incoming batch) are not re-indexed; the existing memory_id
        is returned for them instead.
        """
        if not isinstance(messages, list):
            messages = [messages]
        messages = [Message(content=msg) if isinstance(msg, str) else msg for msg in messages]
        messages = [msg for msg in messages if msg.content]

        if not messages:
            logger.warning("No valid messages to add")
            return []

        # Load the stored records ONCE and build a content_hash -> memory_id map.
        # (The original reloaded the entire table inside the loop for every
        # duplicate hit, making add() O(batch * table_size).)
        records = self.storage_handler.load(tables=[self.memory_table]).get(self.memory_table, [])
        hash_to_id: Dict[str, Optional[str]] = {
            record["content_hash"]: record.get("memory_id")
            for record in records
            if "content_hash" in record
        }

        final_messages: List[Message] = []
        final_memory_ids: List[str] = []
        final_chunks: List[Chunk] = []

        for msg in messages:
            content_hash = hashlib.sha256(str(msg.content).encode()).hexdigest()
            existing_id = hash_to_id.get(content_hash)
            if existing_id:
                # str() guards non-string content before slicing for the log.
                logger.info(f"Duplicate message found (hash): {str(msg.content)[:50]}...")
                final_memory_ids.append(existing_id)
                continue
            memory_id = str(uuid4())
            final_messages.append(msg)
            final_memory_ids.append(memory_id)
            chunk = self._create_memory_chunk(msg, memory_id)
            chunk.metadata.content_hash = content_hash
            final_chunks.append(chunk)
            # Also dedupe within this batch, not only against stored records.
            hash_to_id[content_hash] = memory_id

        if not final_chunks:
            logger.info("No messages added after deduplication")
            return final_memory_ids

        for msg in final_messages:
            super().add_message(msg)

        corpus = Corpus(chunks=final_chunks, corpus_id=self.default_corpus_id)
        chunk_ids = self.rag_engine.add(
            index_type=self.rag_config.index.index_type,
            nodes=corpus,
            corpus_id=self.default_corpus_id,
        )
        if not chunk_ids:
            logger.error("Failed to index memories")
        return final_memory_ids

    async def get(self, memory_ids: Union[str, List[str]], return_chunk: bool = True) -> List[Tuple[Union[Chunk, Message], str]]:
        """Retrieve memories by memory_ids, returning (Chunk/Message, memory_id) tuples.

        NOTE: ids that are not found are silently omitted, so the result may be
        shorter than (and ordered differently from) the requested id list.
        """
        if not isinstance(memory_ids, list):
            memory_ids = [memory_ids]

        if not memory_ids:
            logger.warning("No memory_ids provided for get")
            return []

        try:
            chunks = await self.rag_engine.aget(
                corpus_id=self.default_corpus_id,
                index_type=self.rag_config.index.index_type,
                node_ids=memory_ids
            )
            results = [
                (chunk if return_chunk else self._chunk_to_message(chunk), chunk.metadata.memory_id)
                for chunk in chunks
                if chunk
            ]
            logger.info(f"Retrieved {len(results)} memories for memory_ids: {memory_ids}")
            return results
        except Exception as e:
            logger.error(f"Failed to get memories: {str(e)}")
            return []

    def delete(self, memory_ids: Union[str, List[str]]) -> List[bool]:
        """Delete memories by memory_ids, returning success status for each requested id."""
        if not isinstance(memory_ids, list):
            memory_ids = [memory_ids]

        if not memory_ids:
            logger.warning("No memory_ids provided for deletion")
            return []

        successes = [False] * len(memory_ids)
        # Map each requested id to its (first) position. get() drops missing
        # ids, so the index into its result list CANNOT be used to address
        # `successes` -- doing so was a misalignment bug in the original.
        position: Dict[str, int] = {}
        for idx, mid in enumerate(memory_ids):
            position.setdefault(mid, idx)

        valid_memory_ids = []
        existing_chunks = asyncio.run(self.get(memory_ids, return_chunk=True))
        for chunk, mid in existing_chunks:
            if not chunk:
                continue
            valid_memory_ids.append(mid)
            super().remove_message(self._chunk_to_message(chunk))
            if mid in position:
                successes[position[mid]] = True

        if not valid_memory_ids:
            logger.info("No memories found for deletion")
            return successes

        self.rag_engine.delete(
            corpus_id=self.default_corpus_id,
            index_type=self.rag_config.index.index_type,
            node_ids=valid_memory_ids
        )
        return successes

    def update(self, updates: Union[Tuple[str, Union[Message, str]], List[Tuple[str, Union[Message, str]]]]) -> List[bool]:
        """Update memories with new content, returning success status for each update entry."""
        if not isinstance(updates, list):
            updates = [updates]
        updates = [(mid, Message(content=msg) if isinstance(msg, str) else msg) for mid, msg in updates]
        # Deduplicate by memory_id (last write wins) and drop empty content.
        updates_dict = {mid: msg for mid, msg in updates if msg.content}

        if not updates_dict:
            logger.warning("No valid updates provided")
            return []

        memory_ids = list(updates_dict.keys())
        existing_memories = asyncio.run(self.get(memory_ids, return_chunk=False))
        existing_dict = {mid: msg for msg, mid in existing_memories}

        final_updates = []
        final_memory_ids = []
        updated_ids = set()

        for mid, msg in updates_dict.items():
            if mid not in existing_dict:
                logger.warning(f"No memory found with memory_id {mid}")
                continue
            final_updates.append((mid, msg))
            final_memory_ids.append(mid)
            updated_ids.add(mid)
            super().remove_message(existing_dict[mid])

        # Align statuses with the caller's original `updates` list. The
        # original indexed a len(updates)-sized list via positions in the
        # deduplicated key list, which misaligns on duplicate/invalid entries.
        successes = [mid in updated_ids for mid, _ in updates]

        if not final_updates:
            logger.info("No memories updated")
            return successes

        for _, msg in final_updates:
            super().add_message(msg)

        chunks = []
        for mid, msg in final_updates:
            chunk = self._create_memory_chunk(msg, mid)
            # Keep the dedup hash in sync with add() so a later add() of the
            # same content is recognized as a duplicate.
            chunk.metadata.content_hash = hashlib.sha256(str(msg.content).encode()).hexdigest()
            chunks.append(chunk)

        corpus = Corpus(chunks=chunks, corpus_id=self.default_corpus_id)
        chunk_ids = self.rag_engine.add(
            index_type=self.rag_config.index.index_type,
            nodes=corpus,
            corpus_id=self.default_corpus_id,
        )
        if not chunk_ids:
            logger.error(f"Failed to update memories in RAG index: {final_memory_ids}")
            return [False] * len(updates)

        return successes

    async def search_async(self, query: Union[str, Query], n: Optional[int] = None,
                           metadata_filters: Optional[Dict] = None, return_chunk=False) -> List[Tuple[Message, str]]:
        """Retrieve messages from the RAG index asynchronously for a query.

        Args:
            query: Raw query string or a pre-built Query object.
            n: Max number of results; defaults to the configured retrieval top_k.
            metadata_filters: Extra filters merged into the query's filters.
            return_chunk: When True, return raw Chunks instead of Messages.

        Returns:
            List of (Message-or-Chunk, memory_id) tuples; empty list on failure.
        """
        if isinstance(query, str):
            query_obj = Query(
                query_str=query,
                top_k=n or self.rag_config.retrieval.top_k,
                metadata_filters=metadata_filters or {}
            )
        else:
            query_obj = query
            query_obj.top_k = n or self.rag_config.retrieval.top_k
            if metadata_filters:
                # Merge, letting the explicit argument override query-level filters.
                query_obj.metadata_filters = {**query_obj.metadata_filters, **metadata_filters} if query_obj.metadata_filters else metadata_filters

        try:
            result: RagResult = await self.rag_engine.query_async(query_obj, corpus_id=self.default_corpus_id)
            if return_chunk:
                return [(chunk, chunk.metadata.memory_id) for chunk in result.corpus.chunks]
            messages = [(self._chunk_to_message(chunk), chunk.metadata.memory_id) for chunk in result.corpus.chunks]
            logger.info(f"Retrieved {len(messages)} memories for query: {query_obj.query_str}")
            # Defensive cap in case the engine returns more than top_k results.
            return messages[:n] if n else messages
        except Exception as e:
            logger.error(f"Failed to search memories: {str(e)}")
            return []

    def search(self, query: Union[str, Query], n: Optional[int] = None,
               metadata_filters: Optional[Dict] = None) -> List[Tuple[Message, str]]:
        """Synchronous wrapper for searching memories.

        NOTE: uses asyncio.run(), so it must not be called from inside a
        running event loop -- use search_async there instead.
        """
        return asyncio.run(self.search_async(query, n, metadata_filters))

    def clear(self) -> None:
        """Clear all in-memory messages and the RAG indices for this corpus."""
        super().clear()
        self.rag_engine.clear(corpus_id=self.default_corpus_id)
        logger.info(f"Cleared LongTermMemory with corpus_id {self.default_corpus_id}")

    def save(self, save_path: Optional[str] = None) -> None:
        """Save all indices and memory data to the database (or to save_path)."""
        self.rag_engine.save(output_path=save_path, corpus_id=self.default_corpus_id, table=self.memory_table)

    def load(self, save_path: Optional[str] = None) -> List[str]:
        """Load memory data from the database (or save_path) and rebuild indices, returning memory_ids."""
        return self.rag_engine.load(source=save_path, corpus_id=self.default_corpus_id, table=self.memory_table)