File size: 10,534 Bytes
5374a2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 |
from collections import defaultdict
from typing import Union, Optional, List, Dict
from collections import deque
from pydantic import Field, PositiveInt, field_validator
from ..core.module import BaseModule
from ..core.module_utils import generate_id, get_timestamp
from ..core.message import Message
from ..utils.utils import safe_remove
class BaseMemory(BaseModule):
    """Base class for memory implementations in the EvoAgentX framework.

    BaseMemory provides core functionality for storing, retrieving, and
    filtering messages. It maintains a chronological list of messages while
    also providing indices for retrieval by action or workflow goal.

    Attributes:
        messages: List of stored Message objects, in insertion order.
        memory_id: Unique identifier for this memory instance.
        timestamp: Creation timestamp of this memory instance.
        capacity: Maximum number of messages that can be stored, or None for
            unlimited. NOTE(review): no method of this class reads
            ``capacity``; enforcement presumably lives in subclasses or
            callers -- confirm.
    """

    messages: List[Message] = Field(default_factory=list)
    memory_id: str = Field(default_factory=generate_id)
    timestamp: str = Field(default_factory=get_timestamp)
    capacity: Optional[PositiveInt] = Field(
        default=None,
        description="maximum of messages, None means there is no limit to the message number",
    )

    def init_module(self):
        """Initialize memory indices.

        Creates the (initially empty) dictionaries that index messages by
        action and by workflow goal.
        """
        self._by_action = defaultdict(list)
        self._by_wf_goal = defaultdict(list)

    @property
    def size(self) -> int:
        """Return the current number of messages in memory.

        Returns:
            int: Number of messages currently stored.
        """
        return len(self.messages)

    def clear(self):
        """Clear all messages from memory.

        Removes all messages and resets both indices.
        """
        self.messages.clear()
        self._by_action.clear()
        self._by_wf_goal.clear()

    def remove_message(self, message: Message):
        """Remove a single message from memory.

        Removes the specified message from the main message list and from
        every index it was filed under. If the message is falsy or is not
        found in memory, no action is taken.

        Args:
            message: The message to be removed from ``self.messages``,
                ``self._by_action`` and ``self._by_wf_goal``.
        """
        if not message:
            return
        if message not in self.messages:
            return

        # safe_remove is a project helper; presumably it tolerates a missing
        # item -- verify against its definition.
        safe_remove(self.messages, message)

        # Only touch a bucket that already exists: indexing the defaultdict
        # directly would create an empty list as a side effect.
        if message.action and message.action in self._by_action:
            safe_remove(self._by_action[message.action], message)
        if message.wf_goal and message.wf_goal in self._by_wf_goal:
            safe_remove(self._by_wf_goal[message.wf_goal], message)

    def add_message(self, message: Message):
        """Store a single message in memory.

        Appends the message to the main list and files it in the action and
        workflow-goal indices when the corresponding attribute is set.
        Falsy messages and messages already stored are ignored.

        Args:
            message (Message): The message to be stored.
        """
        if not message:
            return
        if message in self.messages:
            return

        self.messages.append(message)

        # Index whenever the message carries the attribute. The previous
        # guard (`self._by_action and not message.action`) could never be
        # true for a usable key, so the indices stayed empty forever.
        if message.action:
            self._by_action[message.action].append(message)
        if message.wf_goal:
            self._by_wf_goal[message.wf_goal].append(message)

    def add_messages(self, messages: Union[Message, List[Message]], **kwargs):
        """Store one or several messages in memory.

        Args:
            messages (Union[Message, List[Message]]): A single message or a
                list of messages. Anything that is not a ``list`` is treated
                as a single message.
            **kwargs (Any): Additional parameters (unused in base implementation).
        """
        if not isinstance(messages, list):
            messages = [messages]
        for message in messages:
            self.add_message(message)

    def get(self, n: Optional[int] = None, **kwargs) -> List[Message]:
        """Retrieve recent messages from memory.

        Returns the most recent messages, up to the specified limit.

        Args:
            n: The maximum number of messages to return. If None, returns all
                messages; if 0, returns an empty list.
            **kwargs (Any): Additional parameters (unused in base implementation).

        Returns:
            A list of Message objects, ordered from oldest to newest. When
            ``n`` is None this is the internal ``self.messages`` list itself,
            not a copy.

        Raises:
            AssertionError: If n is negative.
        """
        # Raised explicitly (rather than via `assert`) so the check survives
        # `python -O` while keeping the exception type callers may catch.
        if n is not None and n < 0:
            raise AssertionError("n must be None or a positive int")
        if n is None:
            return self.messages
        if n == 0:
            # `[-0:]` is `[0:]`, i.e. the whole list -- not "zero messages".
            return []
        return self.messages[-n:]

    def get_by_type(self, data: Dict[str, list], key: str, n: Optional[int] = None, **kwargs) -> List[Message]:
        """Retrieve the messages stored under ``key`` in the index ``data``.

        Args:
            data (Dict[str, list]): A dictionary mapping type strings to lists
                of messages.
            key (str): The key in ``data`` identifying the list to retrieve.
            n (int, optional): The maximum number of messages to return (the
                last ``n`` entries of the list). If None, the whole list is
                returned; if 0, an empty list is returned.
            **kwargs (Any): Additional parameters (unused in base implementation).

        Returns:
            List[Message]: The messages under ``key``, truncated to the last
            ``n``. An empty list if ``data`` is empty or ``key`` is absent.

        Raises:
            AssertionError: If n is negative (only checked when ``key`` is
                present in ``data``).
        """
        if not data or key not in data:
            return []
        if n is not None and n < 0:
            raise AssertionError("n must be None or a positive int")
        if n is None:
            return data[key]
        if n == 0:
            # Avoid the `[-0:]` pitfall, which would return every message.
            return []
        return data[key][-n:]

    def get_by_action(self, actions: Union[str, List[str]], n: Optional[int] = None, **kwargs) -> List[Message]:
        """Return messages triggered by ``actions`` in the memory.

        Args:
            actions: A single action name or list of action names to filter by.
            n: Maximum number of messages to return per action. If None,
                returns all matching messages.
            **kwargs (Any): Forwarded to ``get_by_type``.

        Returns:
            A list of Message objects, ordered by ``Message.sort_by_timestamp``.
        """
        if isinstance(actions, str):
            actions = [actions]
        messages = []
        for action in actions:
            messages.extend(self.get_by_type(self._by_action, key=action, n=n, **kwargs))
        return Message.sort_by_timestamp(messages)

    def get_by_wf_goal(self, wf_goals: Union[str, List[str]], n: Optional[int] = None, **kwargs) -> List[Message]:
        """Return messages related to ``wf_goals`` in the memory.

        Args:
            wf_goals: A single workflow goal or list of workflow goals to filter by.
            n: Maximum number of messages to return per workflow goal. If
                None, returns all matching messages.
            **kwargs (Any): Forwarded to ``get_by_type``.

        Returns:
            A list of Message objects, ordered by ``Message.sort_by_timestamp``.
        """
        if isinstance(wf_goals, str):
            wf_goals = [wf_goals]
        messages = []
        for wf_goal in wf_goals:
            # extend, not append: the result must be a flat list of messages,
            # matching get_by_action, before it is handed to the sorter.
            messages.extend(self.get_by_type(self._by_wf_goal, key=wf_goal, n=n, **kwargs))
        return Message.sort_by_timestamp(messages)
class ShortTermMemory(BaseModule):
    """
    Short-term memory implementation.

    Stores only the most recent N messages (like a sliding window).
    Unlike BaseMemory/LongTermMemory, this is purely in-memory cache
    and does not persist to storage_handler or vector DB.

    Attributes:
        buffer: Holds the Message objects. Declared as a list for validation
            and replaced by a ``deque`` capped at ``max_size`` in
            ``model_post_init``. Excluded from serialization.
        max_size: Maximum number of messages to retain.
        memory_id: Unique identifier for this memory instance.
        timestamp: Creation timestamp.
    """

    buffer: List[Message] = Field(default_factory=list, exclude=True)
    max_size: PositiveInt = Field(default=5, description="Maximum number of messages to keep in short-term memory")
    memory_id: str = Field(default_factory=generate_id)
    timestamp: str = Field(default_factory=get_timestamp)

    @field_validator("buffer", mode="before")
    @classmethod
    def ensure_list(cls, v):
        """Ensure that the buffer is always a list, even if it is null in the JSON."""
        if v is None:
            return []
        return v

    def model_post_init(self, __context=None):
        """
        Pydantic V2 hook after model initialization.

        Converts the validated buffer list into a ``deque`` with
        ``maxlen=max_size``; a bounded deque discards its oldest entries once
        full, which gives the sliding-window behavior.
        """
        self.buffer = deque(self.buffer, maxlen=self.max_size)

    @property
    def size(self) -> int:
        """Return current number of messages stored."""
        return len(self.buffer)

    def clear(self):
        """Clear all short-term memory."""
        self.buffer.clear()

    def add_message(self, message: Message):
        """Add a single message to short-term memory.

        Falsy messages are ignored. No de-duplication is performed.
        """
        if not message:
            return
        self.buffer.append(message)

    def add_messages(self, messages: Union[Message, List[Message]]):
        """Add one or multiple messages.

        Anything that is not a ``list`` is treated as a single message.
        """
        if not isinstance(messages, list):
            messages = [messages]
        for msg in messages:
            self.add_message(msg)

    def get(self, n: Optional[int] = None) -> List[Message]:
        """
        Retrieve the most recent n messages (default: all).

        Args:
            n: Number of messages to return. If None, return all; if 0,
                return an empty list.

        Returns:
            A new list of Message objects, oldest → newest.

        Raises:
            AssertionError: If n is negative (same contract as BaseMemory.get).
        """
        if n is None:
            return list(self.buffer)
        if n < 0:
            # A negative n used to slice from the front ([-n:] -> [k:]),
            # silently returning the wrong window.
            raise AssertionError("n must be None or a positive int")
        if n == 0:
            # `[-0:]` is `[0:]`, which would return the whole buffer.
            return []
        return list(self.buffer)[-n:]

    def get_last(self) -> Optional[Message]:
        """Return the latest message, or None if empty."""
        return self.buffer[-1] if self.buffer else None