File size: 10,534 Bytes
5374a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
from collections import defaultdict
from typing import Union, Optional, List, Dict
from collections import deque
from pydantic import Field, PositiveInt, field_validator

from ..core.module import BaseModule
from ..core.module_utils import generate_id, get_timestamp
from ..core.message import Message
from ..utils.utils import safe_remove

class BaseMemory(BaseModule):
    """Base class for memory implementations in the EvoAgentX framework.
    
    BaseMemory provides core functionality for storing, retrieving, and 
    filtering messages. It maintains a chronological list of messages while 
    also providing indices for efficient retrieval by action or workflow goal.
    
    Attributes:
        messages: List of stored Message objects, in insertion order (oldest first).
        memory_id: Unique identifier for this memory instance.
        timestamp: Creation timestamp of this memory instance.
        capacity: Maximum number of messages that can be stored, or None for unlimited.
    """

    messages: List[Message] = Field(default_factory=list)
    memory_id: str = Field(default_factory=generate_id)
    timestamp: str = Field(default_factory=get_timestamp)
    # NOTE(review): `capacity` is declared but not enforced anywhere in this base
    # class; confirm whether subclasses are expected to enforce it.
    capacity: Optional[PositiveInt] = Field(default=None, description="maximum of messages, None means there is no limit to the message number")

    def init_module(self):
        """Initialize memory indices.
        
        Creates default dictionaries for indexing messages by action and workflow
        goal, then indexes any messages already present in ``self.messages``
        (e.g. passed to the constructor), so that ``get_by_action`` and
        ``get_by_wf_goal`` also see them.
        """
        self._by_action = defaultdict(list)
        self._by_wf_goal = defaultdict(list)
        # `getattr` guards against this hook running before fields are populated.
        for message in getattr(self, "messages", None) or []:
            self._index_message(message)

    def _index_message(self, message: Message):
        """Add ``message`` to the action / workflow-goal index buckets it belongs to."""
        # Only messages that actually carry an action / wf_goal are indexed.
        if message.action:
            self._by_action[message.action].append(message)
        if message.wf_goal:
            self._by_wf_goal[message.wf_goal].append(message)

    def _unindex_message(self, message: Message):
        """Remove ``message`` from the action / workflow-goal index buckets, if present."""
        # `.get` avoids creating empty buckets in the defaultdicts as a side effect.
        if message.action:
            bucket = self._by_action.get(message.action)
            if bucket is not None:
                safe_remove(bucket, message)
        if message.wf_goal:
            bucket = self._by_wf_goal.get(message.wf_goal)
            if bucket is not None:
                safe_remove(bucket, message)

    @staticmethod
    def _last_n(items: List[Message], n: Optional[int]) -> List[Message]:
        """Return a new list holding the last ``n`` items (all items if ``n`` is None).

        Raises:
            AssertionError: If n is negative.
        """
        assert n is None or n >= 0, "n must be None or a non-negative int"
        if n is None:
            # Copy so callers cannot mutate the memory's internal storage.
            return list(items)
        # `items[-0:]` would return the WHOLE list, so n == 0 is handled explicitly.
        return items[-n:] if n else []

    @property
    def size(self) -> int:
        """Returns the current number of messages in memory.
        
        Returns:
            int: Number of messages currently stored.
        """
        return len(self.messages)
    
    def clear(self):
        """Clear all messages from memory.
        
        Removes all messages and resets all indices.
        """
        self.messages.clear()
        self._by_action.clear()
        self._by_wf_goal.clear()
    
    def remove_message(self, message: Message):
        """Remove a single message from memory.
        
        Removes the specified message from the main message list and all indices.
        If the message is falsy or not found in memory, no action is taken.
        
        Args:
            message: The message to be removed. The message will be removed from 
                   self.messages, self._by_action, and self._by_wf_goal.
        """
        if not message or message not in self.messages:
            return
        safe_remove(self.messages, message)
        self._unindex_message(message)

    def add_message(self, message: Message):
        """Store a single message in memory.
        
        Adds the message to the main list and to the action / workflow-goal
        indices if it is not already stored. Falsy messages are ignored.
        
        Args:
            message (Message): the message to be stored. 
        """
        if not message or message in self.messages:
            return
        self.messages.append(message)
        self._index_message(message)
    
    def add_messages(self, messages: Union[Message, List[Message]], **kwargs):
        """
        Store one or more messages in memory.

        Args:
            messages (Union[Message, List[Message]]): a single message or a list of messages.
            **kwargs (Any): Additional parameters (unused in base implementation).
        """
        if not isinstance(messages, list):
            messages = [messages]
        for message in messages:
            self.add_message(message)
    
    def get(self, n: int=None, **kwargs) -> List[Message]:
        """Retrieve recent messages from memory.
        
        Returns the most recent messages, up to the specified limit.
        
        Args: 
            n: The maximum number of messages to return. If None, returns all messages;
                if 0, returns an empty list.
            **kwargs (Any): Additional parameters (unused in base implementation).
            
        Returns:
            A new list of Message objects, ordered from oldest to newest.
            
        Raises:
            AssertionError: If n is negative.
        """
        return self._last_n(self.messages, n)

    def get_by_type(self, data: Dict[str, list], key: str, n: int = None, **kwargs) -> List[Message]:
        """
        Retrieve a list of Message objects from a given data dictionary `data` based on a specified type `key`.

        Looks up the list of messages stored under `key` in `data` and returns
        (a copy of) its last `n` entries, or all of them when `n` is None.

        Args:
            data (Dict[str, list]): A dictionary where keys are type strings and values are lists of messages.
            key (str): The key in `data` identifying the specific list of messages to retrieve.
            n (int, optional): The maximum number of messages to return. If None, all messages under
                `key` are returned; if 0, an empty list is returned.
            **kwargs (Any): Additional parameters (unused in base implementation).

        Returns:
            List[Message]: A new list of the messages stored under `key`, truncated to the last `n`.
            Returns an empty list if `data` is empty or `key` is absent.

        Raises:
            AssertionError: If n is negative (checked only when `key` is present).
        """
        if not data or key not in data:
            return []
        return self._last_n(data[key], n)
    
    def get_by_action(self, actions: Union[str, List[str]], n: int=None, **kwargs) -> List[Message]:
        """
        Return messages triggered by `actions` in the memory.

        Args:
            actions: A single action name or list of action names to filter by.
            n: Maximum number of messages to return per action. If None, returns all matching messages.
            **kwargs (Any): Additional parameters (unused in base implementation).
            
        Returns:
            A list of Message objects, sorted by timestamp.
        """
        if isinstance(actions, str):
            actions = [actions]
        messages = []
        for action in actions:
            messages.extend(self.get_by_type(self._by_action, key=action, n=n, **kwargs))
        return Message.sort_by_timestamp(messages)
    
    def get_by_wf_goal(self, wf_goals: Union[str, List[str]], n: int=None, **kwargs) -> List[Message]:
        """
        Return messages related to `wf_goals` in the memory.

        Args:
            wf_goals: A single workflow goal or list of workflow goals to filter by.
            n: Maximum number of messages to return per workflow goal. If None, returns all matching messages.
            **kwargs (Any): Additional parameters (unused in base implementation).
            
        Returns:
            A flat list of Message objects, sorted by timestamp.
        """
        if isinstance(wf_goals, str):
            wf_goals = [wf_goals]
        messages = []
        for wf_goal in wf_goals:
            # `extend`, not `append`: the result must be a flat list of messages.
            messages.extend(self.get_by_type(self._by_wf_goal, key=wf_goal, n=n, **kwargs))
        return Message.sort_by_timestamp(messages)


class ShortTermMemory(BaseModule):
    """
    Short-term memory implementation.
    
    Stores only the most recent N messages (like a sliding window).
    Unlike BaseMemory/LongTermMemory, this is purely in-memory cache 
    and does not persist to storage_handler or vector DB.

    Attributes:
        buffer: Internal deque holding Message objects, capped at max_size.
        max_size: Maximum number of messages to retain.
        memory_id: Unique identifier for this memory instance.
        timestamp: Creation timestamp.
    """

    buffer: List[Message] = Field(default_factory=list, exclude=True)
    max_size: PositiveInt = Field(default=5, description="Maximum number of messages to keep in short-term memory")
    memory_id: str = Field(default_factory=generate_id)
    timestamp: str = Field(default_factory=get_timestamp)

    @field_validator("buffer", mode="before")
    @classmethod
    def ensure_list(cls, v):
        """Ensure that the buffer is always a list, even if it is null in the JSON."""
        if v is None:
            return []
        return v

    def model_post_init(self, __context=None):
        """
        Pydantic V2 hook after model initialization.

        Runs the parent post-init hook first, then converts the buffer list into a
        bounded deque so that only the newest ``max_size`` messages are kept.
        """
        # Keep any post-init work done by parent classes (BaseModule / BaseModel).
        super().model_post_init(__context)
        # A bounded deque drops the oldest entries automatically on append.
        self.buffer = deque(self.buffer, maxlen=self.max_size)

    @property
    def size(self) -> int:
        """Return current number of messages stored."""
        return len(self.buffer)

    def clear(self):
        """Clear all short-term memory."""
        self.buffer.clear()

    def add_message(self, message: Message):
        """Add a single message to short-term memory; falsy messages are ignored."""
        if not message:
            return
        self.buffer.append(message)

    def add_messages(self, messages: Union[Message, List[Message]], **kwargs):
        """Add one or multiple messages.

        Args:
            messages: A single message or a list of messages.
            **kwargs (Any): Accepted for API parity with BaseMemory; unused.
        """
        if not isinstance(messages, list):
            messages = [messages]
        for msg in messages:
            self.add_message(msg)

    def get(self, n: Optional[int] = None, **kwargs) -> List[Message]:
        """
        Retrieve the most recent n messages (default: all).
        
        Args:
            n: Number of messages to return. If None, return all; if 0, return none.
            **kwargs (Any): Accepted for API parity with BaseMemory; unused.
        
        Returns:
            A new list of Message objects, oldest → newest.

        Raises:
            AssertionError: If n is negative.
        """
        assert n is None or n >= 0, "n must be None or a non-negative int"
        messages = list(self.buffer)
        if n is None:
            return messages
        # `messages[-0:]` would return everything, so n == 0 is handled explicitly.
        return messages[-n:] if n else []

    def get_last(self) -> Optional[Message]:
        """Return the latest message, or None if empty."""
        return self.buffer[-1] if self.buffer else None