File size: 12,250 Bytes
5374a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
import json
import hashlib
import asyncio
from uuid import uuid4
from typing import Union, List, Dict, Optional, Tuple

from pydantic import Field

from .memory import BaseMemory
from evoagentx.rag import RAGConfig, RAGEngine
from evoagentx.rag.schema import Corpus, Chunk, ChunkMetadata, Query, RagResult
from evoagentx.storages.base import StorageHandler
from evoagentx.core.message import Message
from evoagentx.core.logging import logger

class LongTermMemory(BaseMemory):
    """
    Manages long-term storage and retrieval of memories, integrating with RAGEngine for indexing
    and StorageHandler for persistence.

    Messages are converted to chunks (one chunk per message, keyed by a generated
    memory_id), deduplicated by a SHA-256 hash of their content, indexed through
    the RAG engine under a single corpus, and mirrored in the in-memory
    BaseMemory message list.
    """
    storage_handler: StorageHandler = Field(..., description="Handler for persistent storage")
    rag_config: RAGConfig = Field(..., description="Configuration for RAG engine")
    # Optional because it is lazily constructed in init_module when not supplied.
    rag_engine: Optional[RAGEngine] = Field(default=None, description="RAG engine for indexing and retrieval")
    memory_table: str = Field(default="memory", description="Database table for storing memories")
    default_corpus_id: Optional[str] = Field(default=None, description="Default corpus ID for memory indexing")

    def init_module(self):
        """Initialize the RAG engine and memory indices.

        Creates a RAGEngine from rag_config if none was injected, and assigns a
        fresh UUID corpus id when none was configured.
        """
        super().init_module()
        if self.rag_engine is None:
            self.rag_engine = RAGEngine(config=self.rag_config, storage_handler=self.storage_handler)
        if self.default_corpus_id is None:
            self.default_corpus_id = str(uuid4())
        logger.info(f"Initialized LongTermMemory with corpus_id {self.default_corpus_id}")

    def _create_memory_chunk(self, message: Message, memory_id: str) -> Chunk:
        """Convert a Message to a Chunk for RAG indexing.

        The full message content is JSON-encoded into the chunk metadata (so
        structured content survives storage) while the chunk text used for
        retrieval is the plain string form.
        """
        metadata = ChunkMetadata(
            corpus_id=self.default_corpus_id,
            memory_id=memory_id,
            timestamp=message.timestamp,
            action=message.action,
            wf_goal=message.wf_goal,
            agent=message.agent,
            msg_type=message.msg_type.value if message.msg_type else None,
            prompt=message.prompt,
            next_actions=message.next_actions,
            wf_task=message.wf_task,
            wf_task_desc=message.wf_task_desc,
            message_id=message.message_id,
            content=json.dumps(message.content),
        )
        return Chunk(
            chunk_id=memory_id,
            text=str(message.content),
            metadata=metadata,
            start_char_idx=0,
            end_char_idx=len(str(message.content)),
        )

    def _chunk_to_message(self, chunk: Chunk) -> Message:
        """Convert a Chunk back to a Message object.

        Inverse of _create_memory_chunk: the metadata content was stored with
        json.dumps, so decode it here; fall back to the raw value for legacy
        chunks whose content was not JSON-encoded.
        """
        try:
            content = json.loads(chunk.metadata.content)
        except (TypeError, ValueError):
            # Not valid JSON (e.g. written by an older version) — use as-is.
            content = chunk.metadata.content
        return Message(
            content=content,
            action=chunk.metadata.action,
            wf_goal=chunk.metadata.wf_goal,
            timestamp=chunk.metadata.timestamp,
            agent=chunk.metadata.agent,
            msg_type=chunk.metadata.msg_type,
            prompt=chunk.metadata.prompt,
            next_actions=chunk.metadata.next_actions,
            wf_task=chunk.metadata.wf_task,
            wf_task_desc=chunk.metadata.wf_task_desc,
            message_id=chunk.metadata.message_id,
        )

    def add(self, messages: Union[Message, str, List[Union[Message, str]]]) -> List[str]:
        """Store messages in memory and index them in RAGEngine, returning memory_ids.

        Duplicate content (detected via SHA-256 hash against the persisted
        memory table) is not re-indexed; the existing memory_id is returned in
        its place. Empty messages are dropped.
        """
        if not isinstance(messages, list):
            messages = [messages]
        messages = [Message(content=msg) if isinstance(msg, str) else msg for msg in messages]
        messages = [msg for msg in messages if msg.content]  # Filter out empty messages

        if not messages:
            logger.warning("No valid messages to add")
            return []

        # Load existing records ONCE; the previous implementation reloaded the
        # whole table from storage for every duplicate it encountered.
        records = self.storage_handler.load(tables=[self.memory_table]).get(self.memory_table, [])
        hash_to_id: Dict[str, Optional[str]] = {
            record["content_hash"]: record.get("memory_id")
            for record in records
            if "content_hash" in record
        }

        memory_ids = [str(uuid4()) for _ in messages]
        final_messages = []
        final_memory_ids = []
        final_chunks = []

        for msg, memory_id in zip(messages, memory_ids):
            content_hash = hashlib.sha256(str(msg.content).encode()).hexdigest()
            if content_hash in hash_to_id:
                logger.info(f"Duplicate message found (hash): {msg.content[:50]}...")
                existing_id = hash_to_id[content_hash]
                if existing_id:
                    # Reuse the stored id so callers still get one id per input.
                    final_memory_ids.append(existing_id)
                    continue
                # Record exists but has no memory_id — fall through and re-add.
            final_messages.append(msg)
            final_memory_ids.append(memory_id)
            chunk = self._create_memory_chunk(msg, memory_id)
            chunk.metadata.content_hash = content_hash
            final_chunks.append(chunk)

        if not final_chunks:
            logger.info("No messages added after deduplication")
            return final_memory_ids

        # Add to in-memory messages
        for msg in final_messages:
            super().add_message(msg)

        # Index in RAGEngine
        corpus = Corpus(chunks=final_chunks, corpus_id=self.default_corpus_id)
        chunk_ids = self.rag_engine.add(index_type=self.rag_config.index.index_type, nodes=corpus, corpus_id=self.default_corpus_id)
        if not chunk_ids:
            logger.error("Failed to index memories")
            return final_memory_ids

        return final_memory_ids

    async def get(self, memory_ids: Union[str, List[str]], return_chunk: bool = True) -> List[Tuple[Union[Chunk, Message], str]]:
        """Retrieve memories by memory_ids, returning (Message/Chunk, memory_id) tuples.

        Ids that cannot be found are silently omitted from the result, so the
        returned list may be shorter than the input.
        """
        if not isinstance(memory_ids, list):
            memory_ids = [memory_ids]

        if not memory_ids:
            logger.warning("No memory_ids provided for get")
            return []

        try:
            chunks = await self.rag_engine.aget(
                corpus_id=self.default_corpus_id,
                index_type=self.rag_config.index.index_type,
                node_ids=memory_ids
            )
            results = [(self._chunk_to_message(chunk), chunk.metadata.memory_id) if not return_chunk else (chunk, chunk.metadata.memory_id)
                       for chunk in chunks if chunk]
            logger.info(f"Retrieved {len(results)} memories for memory_ids: {memory_ids}")
            return results
        except Exception as e:
            logger.error(f"Failed to get memories: {str(e)}")
            return []

    def delete(self, memory_ids: Union[str, List[str]]) -> List[bool]:
        """Delete memories by memory_ids, returning success status for each.

        The returned list is positionally aligned with the input ids: entry i is
        True iff memory_ids[i] was found and removed.

        NOTE(review): uses asyncio.run internally, so this must not be called
        from inside a running event loop.
        """
        if not isinstance(memory_ids, list):
            memory_ids = [memory_ids]

        if not memory_ids:
            logger.warning("No memory_ids provided for deletion")
            return []

        successes = [False] * len(memory_ids)
        # Map each id to its position in the input; get() only returns FOUND
        # chunks, so enumerating its results (as the old code did) misaligns
        # the success flags whenever some ids are missing.
        id_positions = {mid: idx for idx, mid in enumerate(memory_ids)}
        valid_memory_ids = []

        existing_chunks = asyncio.run(self.get(memory_ids, return_chunk=True))
        for chunk, mid in existing_chunks:
            if chunk:
                valid_memory_ids.append(mid)
                super().remove_message(self._chunk_to_message(chunk))
                pos = id_positions.get(mid)
                if pos is not None:
                    successes[pos] = True

        if not valid_memory_ids:
            logger.info("No memories found for deletion")
            return successes

        # Remove from RAG index
        self.rag_engine.delete(
            corpus_id=self.default_corpus_id,
            index_type=self.rag_config.index.index_type,
            node_ids=valid_memory_ids
        )

        return successes

    def update(self, updates: Union[Tuple[str, Union[Message, str]], List[Tuple[str, Union[Message, str]]]]) -> List[bool]:
        """Update memories with new content, returning success status for each.

        The returned list is positionally aligned with the input updates list.
        Updates whose content is empty or whose memory_id does not exist are
        reported as False.

        NOTE(review): uses asyncio.run internally, so this must not be called
        from inside a running event loop.
        """
        if not isinstance(updates, list):
            updates = [updates]
        updates = [(mid, Message(content=msg) if isinstance(msg, str) else msg) for mid, msg in updates]
        # Position of each memory_id in the ORIGINAL updates list; the old code
        # indexed `successes` via the filtered key list, which drifts out of
        # alignment once empty-content updates are dropped.
        orig_positions = {mid: idx for idx, (mid, _) in enumerate(updates)}
        updates_dict = {mid: msg for mid, msg in updates if msg.content}

        if not updates_dict:
            logger.warning("No valid updates provided")
            return []

        memory_ids = list(updates_dict.keys())
        existing_memories = asyncio.run(self.get(memory_ids, return_chunk=False))
        existing_dict = {mid: msg for msg, mid in existing_memories}

        successes = [False] * len(updates)
        final_updates = []
        final_memory_ids = []

        for mid, msg in updates_dict.items():
            if mid not in existing_dict:
                logger.warning(f"No memory found with memory_id {mid}")
                continue
            final_updates.append((mid, msg))
            final_memory_ids.append(mid)
            successes[orig_positions[mid]] = True
            super().remove_message(existing_dict[mid])

        if not final_updates:
            logger.info("No memories updated")
            return successes

        chunks = [self._create_memory_chunk(msg, mid) for mid, msg in final_updates]
        for _, msg in final_updates:
            super().add_message(msg)

        # Re-adding chunks with the same chunk_id replaces them in the index.
        corpus = Corpus(chunks=chunks, corpus_id=self.default_corpus_id)
        chunk_ids = self.rag_engine.add(index_type=self.rag_config.index.index_type, nodes=corpus, corpus_id=self.default_corpus_id)
        if not chunk_ids:
            logger.error(f"Failed to update memories in RAG index: {final_memory_ids}")
            return [False] * len(updates)

        return successes

    async def search_async(self, query: Union[str, Query], n: Optional[int] = None,
                          metadata_filters: Optional[Dict] = None, return_chunk: bool = False) -> List[Tuple[Union[Chunk, Message], str]]:
        """Retrieve memories asynchronously based on a query.

        Returns (Message, memory_id) tuples, or (Chunk, memory_id) tuples when
        return_chunk is True. At most n results are returned (falling back to
        the configured retrieval top_k).
        """
        if isinstance(query, str):
            query_obj = Query(
                query_str=query,
                top_k=n or self.rag_config.retrieval.top_k,
                metadata_filters=metadata_filters or {}
            )
        else:
            query_obj = query
            query_obj.top_k = n or self.rag_config.retrieval.top_k
            if metadata_filters:
                # Merge, with explicit filters taking precedence over the query's own.
                query_obj.metadata_filters = {**query_obj.metadata_filters, **metadata_filters} if query_obj.metadata_filters else metadata_filters

        try:
            result: RagResult = await self.rag_engine.query_async(query_obj, corpus_id=self.default_corpus_id)
            if return_chunk:
                results = [(chunk, chunk.metadata.memory_id) for chunk in result.corpus.chunks]
            else:
                results = [(self._chunk_to_message(chunk), chunk.metadata.memory_id) for chunk in result.corpus.chunks]
            # Log and trim both branches consistently (the chunk branch
            # previously bypassed both the log line and the [:n] cap).
            logger.info(f"Retrieved {len(results)} memories for query: {query_obj.query_str}")
            return results[:n] if n else results
        except Exception as e:
            logger.error(f"Failed to search memories: {str(e)}")
            return []

    def search(self, query: Union[str, Query], n: Optional[int] = None,
               metadata_filters: Optional[Dict] = None, return_chunk: bool = False) -> List[Tuple[Union[Chunk, Message], str]]:
        """Synchronous wrapper for searching memories.

        Must not be called from inside a running event loop (uses asyncio.run).
        """
        return asyncio.run(self.search_async(query, n, metadata_filters, return_chunk))

    def clear(self) -> None:
        """Clear all in-memory messages and the RAG indices for this corpus."""
        super().clear()
        self.rag_engine.clear(corpus_id=self.default_corpus_id)
        logger.info(f"Cleared LongTermMemory with corpus_id {self.default_corpus_id}")

    def save(self, save_path: Optional[str] = None) -> None:
        """Save all indices and memory data to the database (or save_path if given)."""
        self.rag_engine.save(output_path=save_path, corpus_id=self.default_corpus_id, table=self.memory_table)

    def load(self, save_path: Optional[str] = None) -> List[str]:
        """Load memory data from the database (or save_path) and reconstruct indices, returning memory_ids."""
        return self.rag_engine.load(source=save_path, corpus_id=self.default_corpus_id, table=self.memory_table)