# iLOVE2D's picture
# Upload 2846 files
# 5374a2d verified
from collections import defaultdict
from typing import Union, Optional, List, Dict
from collections import deque
from pydantic import Field, PositiveInt, field_validator
from ..core.module import BaseModule
from ..core.module_utils import generate_id, get_timestamp
from ..core.message import Message
from ..utils.utils import safe_remove
class BaseMemory(BaseModule):
    """Base class for memory implementations in the EvoAgentX framework.

    BaseMemory keeps messages in insertion order in ``messages`` and maintains
    two secondary indices, keyed by ``message.action`` and ``message.wf_goal``,
    so that messages can be looked up by action or workflow goal without
    scanning the whole list.

    Attributes:
        messages: List of stored Message objects, oldest first.
        memory_id: Unique identifier for this memory instance.
        timestamp: Creation timestamp of this memory instance.
        capacity: Declared maximum number of messages, or None for unlimited.
            NOTE(review): no method in this class enforces the limit; it is
            only stored. Confirm whether subclasses or callers rely on it.
    """

    messages: List[Message] = Field(default_factory=list)
    memory_id: str = Field(default_factory=generate_id)
    timestamp: str = Field(default_factory=get_timestamp)
    capacity: Optional[PositiveInt] = Field(default=None, description="maximum of messages, None means there is no limit to the message number")

    def init_module(self):
        """Initialize memory indices.

        Creates the per-action and per-workflow-goal indices and indexes any
        messages that are already present in ``self.messages`` (for example,
        messages passed to the constructor), so that the indices never lag
        behind the main list.
        """
        self._by_action = defaultdict(list)
        self._by_wf_goal = defaultdict(list)
        # Assumes this hook runs after the pydantic fields are populated
        # (TODO confirm against BaseModule.__init__).
        for message in self.messages:
            self._index_message(message)

    def _index_message(self, message: Message) -> None:
        """Register ``message`` in the action and workflow-goal indices."""
        # Only messages that actually carry an action / workflow goal are
        # indexed; a falsy value (None, "") has no meaningful bucket.
        if message.action:
            self._by_action[message.action].append(message)
        if message.wf_goal:
            self._by_wf_goal[message.wf_goal].append(message)

    @staticmethod
    def _last_n(items: list, n: Optional[int]) -> list:
        """Return the last ``n`` entries of ``items`` (all of them if n is None).

        Raises:
            AssertionError: If ``n`` is negative.
        """
        if n is None:
            return items
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if n < 0:
            raise AssertionError("n must be None or a positive int")
        # ``items[-0:]`` would be the whole list, so n == 0 needs its own case.
        if n == 0:
            return []
        return items[-n:]

    @property
    def size(self) -> int:
        """Returns the current number of messages in memory.

        Returns:
            int: Number of messages currently stored.
        """
        return len(self.messages)

    def clear(self):
        """Clear all messages from memory.

        Removes all messages and resets both indices.
        """
        self.messages.clear()
        self._by_action.clear()
        self._by_wf_goal.clear()

    def remove_message(self, message: Message):
        """Remove a single message from memory.

        Removes the specified message from the main message list and from
        both indices. If the message is falsy or not stored, nothing happens.

        Args:
            message: The message to be removed from self.messages,
                self._by_action, and self._by_wf_goal.
        """
        if not message:
            return
        if message not in self.messages:
            return
        safe_remove(self.messages, message)
        # Membership is checked first so that looking up a missing key does
        # not create an empty bucket in the defaultdict.
        if message.action and message.action in self._by_action:
            safe_remove(self._by_action[message.action], message)
        if message.wf_goal and message.wf_goal in self._by_wf_goal:
            safe_remove(self._by_wf_goal[message.wf_goal], message)

    def add_message(self, message: Message):
        """Store a single message in memory.

        Appends the message to the main list and registers it in the action
        and workflow-goal indices. Falsy messages and messages already stored
        are ignored.

        Args:
            message (Message): the message to be stored.
        """
        if not message:
            return
        if message in self.messages:
            return
        self.messages.append(message)
        self._index_message(message)

    def add_messages(self, messages: Union[Message, List[Message]], **kwargs):
        """Store one or more messages in memory.

        Args:
            messages (Union[Message, List[Message]]): a single message or a
                list of messages; each one is passed to ``add_message``.
            **kwargs (Any): Additional parameters (unused in base implementation).
        """
        if not isinstance(messages, list):
            messages = [messages]
        for message in messages:
            self.add_message(message)

    def get(self, n: Optional[int] = None, **kwargs) -> List[Message]:
        """Retrieve recent messages from memory.

        Args:
            n: The maximum number of messages to return. If None, returns all
                messages (the internal list itself, not a copy — treat it as
                read-only). ``0`` returns an empty list.
            **kwargs (Any): Additional parameters (unused in base implementation).

        Returns:
            A list of Message objects, ordered from oldest to newest.

        Raises:
            AssertionError: If n is negative.
        """
        return self._last_n(self.messages, n)

    def get_by_type(self, data: Dict[str, list], key: str, n: Optional[int] = None, **kwargs) -> List[Message]:
        """Retrieve the messages stored under ``key`` in an index dictionary.

        Args:
            data (Dict[str, list]): A dictionary mapping type strings to lists
                of messages (e.g. one of this memory's indices).
            key (str): The key in ``data`` whose messages should be returned.
            n (int, optional): The maximum number of messages to return, taken
                from the end of the list. If None, the whole list is returned;
                ``0`` returns an empty list.
            **kwargs (Any): Additional parameters (unused in base implementation).

        Returns:
            List[Message]: The (possibly truncated) messages under ``key``, or
            an empty list if ``data`` is empty or does not contain ``key``.

        Raises:
            AssertionError: If ``key`` is present and n is negative.
        """
        if not data or key not in data:
            return []
        return self._last_n(data[key], n)

    def get_by_action(self, actions: Union[str, List[str]], n: Optional[int] = None, **kwargs) -> List[Message]:
        """Return messages triggered by ``actions`` in the memory.

        Args:
            actions: A single action name or list of action names to filter by.
            n: Maximum number of messages to return per action. If None, returns all matching messages.
            **kwargs (Any): Additional parameters forwarded to ``get_by_type``.

        Returns:
            The matching Message objects, ordered by ``Message.sort_by_timestamp``.
        """
        if isinstance(actions, str):
            actions = [actions]
        messages = []
        for action in actions:
            messages.extend(self.get_by_type(self._by_action, key=action, n=n, **kwargs))
        messages = Message.sort_by_timestamp(messages)
        return messages

    def get_by_wf_goal(self, wf_goals: Union[str, List[str]], n: Optional[int] = None, **kwargs) -> List[Message]:
        """Return messages related to ``wf_goals`` in the memory.

        Args:
            wf_goals: A single workflow goal or list of workflow goals to filter by.
            n: Maximum number of messages to return per workflow goal. If None, returns all matching messages.
            **kwargs (Any): Additional parameters forwarded to ``get_by_type``.

        Returns:
            The matching Message objects, ordered by ``Message.sort_by_timestamp``.
        """
        if isinstance(wf_goals, str):
            wf_goals = [wf_goals]
        messages = []
        for wf_goal in wf_goals:
            # ``extend`` keeps the result a flat list of messages; ``append``
            # would nest one list per goal and break the sort below.
            messages.extend(self.get_by_type(self._by_wf_goal, key=wf_goal, n=n, **kwargs))
        messages = Message.sort_by_timestamp(messages)
        return messages
class ShortTermMemory(BaseModule):
    """
    Short-term memory implementation.

    Stores only the most recent ``max_size`` messages (a sliding window).
    The buffer is excluded from serialization (``exclude=True``) and is held
    purely in memory by this class.

    Attributes:
        buffer: Declared as a list for validation; replaced by a deque capped
            at ``max_size`` once the model is initialized.
        max_size: Maximum number of messages to retain.
        memory_id: Unique identifier for this memory instance.
        timestamp: Creation timestamp.
    """

    buffer: List[Message] = Field(default_factory=list, exclude=True)
    max_size: PositiveInt = Field(default=5, description="Maximum number of messages to keep in short-term memory")
    memory_id: str = Field(default_factory=generate_id)
    timestamp: str = Field(default_factory=get_timestamp)

    @field_validator("buffer", mode="before")
    @classmethod
    def ensure_list(cls, v):
        """Ensure that the buffer is always a list, even if it is null in the JSON."""
        if v is None:
            return []
        return v

    # Convert to deque during initialization
    def model_post_init(self, __context=None):
        """
        Pydantic V2 hook after model initialization.

        Converts the validated buffer list into a deque bounded by
        ``max_size``; if more than ``max_size`` messages were supplied, only
        the most recent ones are kept.
        """
        self.buffer = deque(self.buffer, maxlen=self.max_size)

    @property
    def size(self) -> int:
        """Return current number of messages stored."""
        return len(self.buffer)

    def clear(self):
        """Clear all short-term memory."""
        self.buffer.clear()

    def add_message(self, message: Message):
        """Add a single message to short-term memory.

        Falsy messages are ignored. Once the buffer holds ``max_size``
        messages, appending drops the oldest one.
        """
        if not message:
            return
        self.buffer.append(message)

    def add_messages(self, messages: Union[Message, List[Message]]):
        """Add one or multiple messages, in the order given."""
        if not isinstance(messages, list):
            messages = [messages]
        for msg in messages:
            self.add_message(msg)

    def get(self, n: Optional[int] = None) -> List[Message]:
        """
        Retrieve the most recent n messages (default: all).

        Args:
            n: Number of messages to return. If None, return all. A value of
                zero or less returns an empty list.

        Returns:
            A new list of Message objects, oldest → newest.
        """
        if n is None:
            return list(self.buffer)
        # ``[-0:]`` would return the whole buffer and a negative n would slice
        # from the front, so non-positive requests are answered explicitly.
        if n <= 0:
            return []
        return list(self.buffer)[-n:]

    def get_last(self) -> Optional[Message]:
        """Return the latest message, or None if empty."""
        return self.buffer[-1] if self.buffer else None