File size: 12,164 Bytes
5374a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import json
from uuid import uuid4
from datetime import datetime
from typing import Union, List, Dict, Any, Optional, Tuple

from pydantic import Field

from .long_term_memory import LongTermMemory
from ..rag.schema import Query
from ..core.logging import logger
from ..core.module import BaseModule
from ..core.message import Message, MessageType
from ..models.base_model import BaseLLM
from ..prompts.memory.manager import MANAGER_PROMPT


class MemoryManager(BaseModule):
    """
    The Memory Manager organizes and manages LongTermMemory data at a higher level.
    It retrieves data, processes it with optional LLM-based action inference, and stores new or updated data.
    It creates Message objects for agent use, combining user prompts with memory context.

    Attributes:
        memory (LongTermMemory): The LongTermMemory instance for storing and retrieving messages.
        llm (Optional[BaseLLM]): LLM for deciding memory operations.
        use_llm_management (bool): Toggle LLM-based memory management.
    """
    memory: LongTermMemory = Field(..., description="Long-term memory instance")
    llm: Optional[BaseLLM] = Field(default=None, description="LLM for deciding memory operations")
    use_llm_management: bool = Field(default=True, description="Toggle LLM-based memory management")

    def init_module(self):
        """Perform no additional setup; the manager only relies on its declared fields."""
        pass

    @staticmethod
    def _validate_decision(decision: Any) -> None:
        """
        Check that a single LLM decision is a well-formed memory operation.

        Args:
            decision (Any): One parsed JSON value returned by the LLM.

        Raises:
            ValueError: If the decision is not a dict, names an unknown action,
                or lacks the `memory_id` / `message` its action requires.
        """
        if not isinstance(decision, dict):
            raise ValueError(f"Decision must be a JSON object, got {type(decision).__name__}")
        action = decision.get("action")
        if action not in ["add", "update", "delete"]:
            raise ValueError(f"Invalid action: {action}")
        if action in ["update", "delete"] and not decision.get("memory_id"):
            raise ValueError(f"memory_id required for {action}")
        if action in ["add", "update"] and not decision.get("message"):
            raise ValueError(f"message required for {action}")

    @staticmethod
    def _build_message(item: Union[Message, str], keep_timestamp: bool) -> Message:
        """
        Normalize a raw string or an existing Message into a fresh Message.

        Args:
            item (Union[Message, str]): Plain text, or a Message whose content,
                type, agent and id are copied.
            keep_timestamp (bool): When True and `item` is a Message, reuse its
                timestamp; otherwise stamp the result with the current time.

        Returns:
            Message: Strings become REQUEST messages from agent "user" with a new id.
        """
        now = datetime.now().isoformat()
        if isinstance(item, str):
            return Message(
                content=item,
                msg_type=MessageType.REQUEST,
                timestamp=now,
                agent="user",
                message_id=str(uuid4())
            )
        return Message(
            content=item.content,
            msg_type=item.msg_type,
            timestamp=item.timestamp if keep_timestamp else now,
            agent=item.agent,
            # A Message without an id still needs one to be stored.
            message_id=item.message_id or str(uuid4())
        )

    async def _prompt_llm_for_memory_operation(
        self,
        input_data: Union[Dict[str, Any], List[Dict[str, Any]]],
        relevant_data: Optional[List[Tuple[Message, str]]] = None
    ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """
        Prompt the LLM to decide memory operations (add, update, delete) and return structured JSON.

        The result always has the same shape as `input_data`: a dict in gives a
        dict out, a list in gives a list of the same length out. Callers pair
        decisions with their inputs positionally, so any reply that cannot be
        aligned is treated as invalid.

        Args:
            input_data: The proposed operation, or list of proposed operations.
            relevant_data: Existing (message, memory_id) pairs shown to the LLM as context.

        Returns:
            The validated LLM decision(s), or `input_data` unchanged when the LLM
            is disabled, missing, fails, or returns malformed output.
        """
        if not self.llm or not self.use_llm_management:
            return input_data  # Bypass LLM if disabled or no LLM provided

        # ensure_ascii=False on both JSON payloads keeps non-ASCII text readable in the prompt.
        relevant_data_str = "\n".join(
            json.dumps({"message": msg.to_dict(), "memory_id": mid}, ensure_ascii=False)
            for msg, mid in (relevant_data or [])
        )
        prompt = (
            MANAGER_PROMPT
            .replace("<<INPUT_DATA>>", json.dumps(input_data, ensure_ascii=False))
            .replace("<<RELEVANT_DATA>>", relevant_data_str)
        )

        logger.info(f"Memory Manager LLM Prompt: \n\n{prompt}")
        try:
            # NOTE(review): generate() is called synchronously inside a coroutine;
            # confirm whether BaseLLM offers an awaitable variant.
            response = self.llm.generate(prompt=prompt)
            # Strip Markdown code fences the LLM may wrap around its JSON.
            parsed = json.loads(response.content.replace("```json", "").replace("```", "").strip())

            if isinstance(input_data, list):
                # Accept a bare object as a one-element batch.
                decisions = parsed if isinstance(parsed, list) else [parsed]
                if len(decisions) != len(input_data):
                    raise ValueError(
                        f"Expected {len(input_data)} decisions, got {len(decisions)}"
                    )
                for decision in decisions:
                    self._validate_decision(decision)
                return decisions

            self._validate_decision(parsed)
            return parsed
        except Exception as e:
            # Deliberate best-effort: any LLM or parsing failure falls back to the caller's request.
            logger.error(f"LLM failed to generate valid memory operation: {str(e)}")
            return input_data  # Fallback to input data on error

    async def handle_memory(
        self,
        action: str,
        user_prompt: Optional[Union[str, Message, Query]] = None,
        data: Optional[Union[Message, str, List[Union[Message, str]], Dict, List[Tuple[str, Union[Message, str]]]]] = None,
        top_k: Optional[int] = None,
        metadata_filters: Optional[Dict] = None
    ) -> Union[List[str], List[Tuple[Message, str]], List[bool], Message, None]:
        """
        Handle memory operations based on the specified action, with optional LLM inference.

        Actions:
            add: Store `data` (a string, a Message, or a list of them).
            search: Query memory with `user_prompt`, `top_k` and `metadata_filters`.
            get: Fetch the memories whose IDs are given in `data`.
            update: Apply `data`, one or more (memory_id, message) pairs.
            delete: Remove the memory IDs given in `data`.
            clear / save / load: Forwarded to the memory instance (`data` is passed to save/load).
            create_message: Build a REQUEST Message from `user_prompt` plus retrieved memory content.

        For add, update and delete, when an LLM is configured and
        `use_llm_management` is True, each item is only applied if the LLM's
        decision for it carries the same action.

        Args:
            action (str): The memory operation ("add", "search", "get", "update", "delete", "clear", "save", "load", "create_message").
            user_prompt (Optional[Union[str, Message, Query]]): The user prompt or query to process with memory data.
            data (Optional): Input data for the operation (e.g., messages, memory IDs, updates).
            top_k (Optional[int]): Number of results to retrieve for search operations.
            metadata_filters (Optional[Dict]): Filters for memory retrieval.

        Returns:
            Union[List[str], List[Tuple[Message, str]], List[bool], Message, None]: Result of the operation.

        Raises:
            ValueError: If `action` is not one of the supported operations.
        """
        if action not in ["add", "search", "get", "update", "delete", "clear", "save", "load", "create_message"]:
            logger.error(f"Invalid action: {action}")
            raise ValueError(f"Invalid action: {action}")

        llm_enabled = bool(self.use_llm_management and self.llm)

        if action == "add":
            if not data:
                logger.warning("No data provided for add operation")
                return []
            items = data if isinstance(data, list) else [data]
            messages = [self._build_message(item, keep_timestamp=True) for item in items]
            if not llm_enabled:
                return self.memory.add(messages)

            input_data = [
                {
                    "action": "add",
                    "memory_id": str(uuid4()),
                    "message": msg.to_dict()
                } for msg in messages
            ]
            llm_decisions = await self._prompt_llm_for_memory_operation(input_data)
            approved_messages = []
            for decision, msg in zip(llm_decisions, messages):
                if decision.get("action") != "add":
                    logger.info(f"LLM rejected adding memory: {decision}")
                    continue
                approved_messages.append(msg)
            return self.memory.add(approved_messages) if approved_messages else []

        elif action == "search":
            if not user_prompt:
                logger.warning("No user_prompt provided for search operation")
                return []
            # Only Message prompts are unwrapped here; str and Query are passed through as-is.
            if isinstance(user_prompt, Message):
                user_prompt = user_prompt.content
            return await self.memory.search_async(user_prompt, top_k, metadata_filters)

        elif action == "get":
            if not data:
                logger.warning("No memory IDs provided for get operation")
                return []
            return await self.memory.get(data, return_chunk=False)

        elif action == "update":
            if not data:
                logger.warning("No updates provided for update operation")
                return []
            pairs = data if isinstance(data, list) else [data]
            # Updated messages always receive a fresh timestamp.
            updates = [
                (mid, self._build_message(item, keep_timestamp=False))
                for mid, item in pairs
            ]
            if not llm_enabled:
                return self.memory.update(updates)

            input_data = [
                {
                    "action": "update",
                    "memory_id": mid,
                    "message": msg.to_dict()
                } for mid, msg in updates
            ]
            existing_memories = await self.memory.get([mid for mid, _ in updates])
            llm_decisions = await self._prompt_llm_for_memory_operation(input_data, relevant_data=existing_memories)
            final_updates = []
            for decision, (mid, msg) in zip(llm_decisions, updates):
                if decision.get("action") != "update":
                    logger.info(f"LLM rejected updating memory {mid}: {decision}")
                    continue
                final_updates.append((mid, msg))
            return self.memory.update(final_updates) if final_updates else [False] * len(updates)

        elif action == "delete":
            if not data:
                logger.warning("No memory IDs provided for delete operation")
                return []
            memory_ids = data if isinstance(data, list) else [data]
            if not llm_enabled:
                return self.memory.delete(memory_ids)

            input_data = [{"action": "delete", "memory_id": mid} for mid in memory_ids]
            existing_memories = await self.memory.get(memory_ids)
            llm_decisions = await self._prompt_llm_for_memory_operation(input_data, relevant_data=existing_memories)
            # Only honour IDs the caller asked for, so an LLM-invented ID can
            # never delete an unrelated memory.
            valid_memory_ids = [
                decision.get("memory_id")
                for decision in llm_decisions
                if decision.get("action") == "delete" and decision.get("memory_id") in memory_ids
            ]
            return self.memory.delete(valid_memory_ids) if valid_memory_ids else [False] * len(memory_ids)

        elif action == "clear":
            self.memory.clear()
            return None

        elif action == "save":
            self.memory.save(data)
            return None

        elif action == "load":
            return self.memory.load(data)

        elif action == "create_message":
            if not user_prompt:
                logger.warning("No user_prompt provided for create_message operation")
                return None
            if isinstance(user_prompt, Query):
                user_prompt = user_prompt.query_str
            elif isinstance(user_prompt, Message):
                user_prompt = user_prompt.content
            memories = await self.memory.search_async(user_prompt, top_k, metadata_filters)
            context = "\n".join([msg.content for msg, _ in memories])
            memory_ids = [mid for _, mid in memories]
            # Without retrieved context the prompt is passed through untouched.
            combined_content = f"User Prompt: {user_prompt}\nContext: {context}" if context else user_prompt
            return Message(
                content=combined_content,
                msg_type=MessageType.REQUEST,
                timestamp=datetime.now().isoformat(),
                agent="user",
                memory_ids=memory_ids
            )

    async def create_conversation_message(
        self,
        user_prompt: Union[str, Message],
        conversation_id: str,
        top_k: Optional[int] = None,
        metadata_filters: Optional[Dict] = None
    ) -> Message:
        """
        Create a Message combining the user prompt with retrieved conversation history.

        History is retrieved by searching memory with the prompt text, filtered
        to `corpus_id == conversation_id` (plus any extra `metadata_filters`,
        which take precedence on key collisions).

        Args:
            user_prompt (Union[str, Message]): The user's input prompt or message.
            conversation_id (str): ID of the conversation thread.
            top_k (Optional[int]): Number of results to retrieve; defaults to 10 when not set.
            metadata_filters (Optional[Dict]): Filters for memory retrieval.

        Returns:
            Message: A new REQUEST Message containing the prompt and the history text.
                Its `memory_ids` is the source Message's id when one was supplied
                and has an id, otherwise a newly generated UUID string.
        """
        # Keep the original Message so its id is still available after the
        # prompt has been reduced to plain text.
        source_message = user_prompt if isinstance(user_prompt, Message) else None
        if source_message is not None:
            user_prompt = source_message.content

        # Retrieve conversation history
        history_filter = {"corpus_id": conversation_id}
        if metadata_filters:
            history_filter.update(metadata_filters)
        history_results = await self.memory.search_async(
            query=user_prompt, n=top_k or 10, metadata_filters=history_filter
        )
        history = "\n".join([f"{msg.content}" for msg, _ in history_results])

        # Combine prompt and history
        combined_content = (
            f"User Prompt: \n{user_prompt}\n"
            f"Conversation History: \n\n{history or 'No history available'}\n"
        )
        if source_message is not None and source_message.message_id:
            reference_id = source_message.message_id
        else:
            reference_id = str(uuid4())
        return Message(
            content=combined_content,
            msg_type=MessageType.REQUEST,
            timestamp=datetime.now().isoformat(),
            agent="user",
            memory_ids=reference_id
        )