from collections import defaultdict
from collections import deque
from typing import Union, Optional, List, Dict

from pydantic import Field, PositiveInt, field_validator

from ..core.module import BaseModule
from ..core.module_utils import generate_id, get_timestamp
from ..core.message import Message
from ..utils.utils import safe_remove


def _tail(items: list, n: Optional[int]) -> list:
    """Return the last ``n`` items of ``items``.

    ``None`` returns ``items`` itself (no copy), ``0`` returns an empty list.
    A plain ``items[-n:]`` is not used directly because ``items[-0:]`` is the
    whole list rather than an empty one.

    Raises:
        AssertionError: If ``n`` is negative.
    """
    assert n is None or n >= 0, "n must be None or a positive int"
    if n is None:
        return items
    return items[-n:] if n else []


class BaseMemory(BaseModule):
    """Base class for memory implementations in the EvoAgentX framework.

    BaseMemory provides core functionality for storing, retrieving, and
    filtering messages. It maintains a chronological list of messages while
    also providing indices for efficient retrieval by action or workflow goal.

    Attributes:
        messages: List of stored Message objects.
        memory_id: Unique identifier for this memory instance.
        timestamp: Creation timestamp of this memory instance.
        capacity: Maximum number of messages that can be stored, or None for
            unlimited. The limit is declared here but not enforced by this
            base class.
    """

    messages: List[Message] = Field(default_factory=list)
    memory_id: str = Field(default_factory=generate_id)
    timestamp: str = Field(default_factory=get_timestamp)
    capacity: Optional[PositiveInt] = Field(
        default=None,
        description="maximum of messages, None means there is no limit to the message number",
    )

    def init_module(self):
        """Initialize memory indices.

        Creates default dictionaries for indexing messages by action and
        workflow goal, then indexes any messages already held in
        ``self.messages`` so that lookups can see them.
        """
        self._by_action = defaultdict(list)
        self._by_wf_goal = defaultdict(list)
        # NOTE(review): assumes init_module runs after the pydantic fields are
        # populated; getattr keeps this harmless if `messages` is not set yet.
        for message in getattr(self, "messages", None) or []:
            self._index_message(message)

    def _index_message(self, message: Message):
        """Register ``message`` in the action / workflow-goal indices."""
        if message.action:
            self._by_action[message.action].append(message)
        if message.wf_goal:
            self._by_wf_goal[message.wf_goal].append(message)

    def _unindex_message(self, message: Message):
        """Drop ``message`` from the action / workflow-goal indices."""
        for index, key in (
            (self._by_action, message.action),
            (self._by_wf_goal, message.wf_goal),
        ):
            # Membership test first: indexing a defaultdict with a missing key
            # would silently create an empty bucket.
            if key and key in index:
                safe_remove(index[key], message)
                if not index[key]:
                    del index[key]

    @property
    def size(self) -> int:
        """Returns the current number of messages in memory.

        Returns:
            int: Number of messages currently stored.
        """
        return len(self.messages)

    def clear(self):
        """Clear all messages from memory.

        Removes all messages and resets all indices.
        """
        self.messages.clear()
        self._by_action.clear()
        self._by_wf_goal.clear()

    def remove_message(self, message: Message):
        """Remove a single message from memory.

        Removes the specified message from the main message list and all
        indices. If the message is not found in memory, no action is taken.

        Args:
            message: The message to be removed. The message will be removed
                from self.messages, self._by_action, and self._by_wf_goal.
        """
        if not message:
            return
        if message not in self.messages:
            return
        safe_remove(self.messages, message)
        self._unindex_message(message)

    def add_message(self, message: Message):
        """Store a single message in memory.

        Adds the message to the main list and relevant indices if it's not
        already stored.

        Args:
            message (Message): the message to be stored.
        """
        if not message:
            return
        if message in self.messages:
            return
        self.messages.append(message)
        self._index_message(message)

    def add_messages(self, messages: Union[Message, List[Message]], **kwargs):
        """Store one or more messages in memory.

        Args:
            messages (Union[Message, List[Message]]): a single message or a
                list of messages.
            **kwargs (Any): Additional parameters (unused in base implementation).
        """
        if not isinstance(messages, list):
            messages = [messages]
        for message in messages:
            self.add_message(message)

    def get(self, n: Optional[int] = None, **kwargs) -> List[Message]:
        """Retrieve recent messages from memory.

        Returns the most recent messages, up to the specified limit.

        Args:
            n: The maximum number of messages to return. If None, returns all
                messages (the internal list itself, not a copy). 0 returns an
                empty list.
            **kwargs (Any): Additional parameters (unused in base implementation).

        Returns:
            A list of Message objects, ordered from oldest to newest.

        Raises:
            AssertionError: If n is negative.
        """
        return _tail(self.messages, n)

    def get_by_type(self, data: Dict[str, list], key: str, n: Optional[int] = None, **kwargs) -> List[Message]:
        """Retrieve the messages stored under ``key`` in the index ``data``.

        Args:
            data (Dict[str, list]): A dictionary where keys are type strings
                and values are lists of messages.
            key (str): The key in `data` identifying the list of messages to
                retrieve.
            n (int, optional): The maximum number of messages to return (the
                most recent ones). If None, all messages under `key` are
                returned; 0 returns an empty list.
            **kwargs (Any): Additional parameters (unused in base implementation).

        Returns:
            List[Message]: The messages under `key`, truncated to the last `n`;
            an empty list if `data` is empty or `key` is absent.

        Raises:
            AssertionError: If n is negative and `key` is present in `data`.
        """
        if not data or key not in data:
            return []
        return _tail(data[key], n)

    def get_by_action(self, actions: Union[str, List[str]], n: Optional[int] = None, **kwargs) -> List[Message]:
        """Return messages triggered by `actions` in the memory.

        Args:
            actions: A single action name or list of action names to filter by.
            n: Maximum number of messages to return per action. If None,
                returns all matching messages.
            **kwargs (Any): Forwarded to ``get_by_type``.

        Returns:
            A list of Message objects, sorted by timestamp.
        """
        if isinstance(actions, str):
            actions = [actions]
        messages = []
        for action in actions:
            messages.extend(self.get_by_type(self._by_action, key=action, n=n, **kwargs))
        messages = Message.sort_by_timestamp(messages)
        return messages

    def get_by_wf_goal(self, wf_goals: Union[str, List[str]], n: Optional[int] = None, **kwargs) -> List[Message]:
        """Return messages related to `wf_goals` in the memory.

        Args:
            wf_goals: A single workflow goal or list of workflow goals to
                filter by.
            n: Maximum number of messages to return per workflow goal. If
                None, returns all matching messages.
            **kwargs (Any): Forwarded to ``get_by_type``.

        Returns:
            A list of Message objects, sorted by timestamp.
        """
        if isinstance(wf_goals, str):
            wf_goals = [wf_goals]
        messages = []
        for wf_goal in wf_goals:
            # extend, not append: the result must be a flat list of messages.
            messages.extend(self.get_by_type(self._by_wf_goal, key=wf_goal, n=n, **kwargs))
        messages = Message.sort_by_timestamp(messages)
        return messages


class ShortTermMemory(BaseModule):
    """Short-term memory implementation.

    Stores only the most recent N messages (like a sliding window). The
    messages live in an in-process deque; this class itself does not write
    them to any storage backend.

    Attributes:
        buffer: Internal deque holding Message objects, capped at max_size.
        max_size: Maximum number of messages to retain.
        memory_id: Unique identifier for this memory instance.
        timestamp: Creation timestamp.
    """

    buffer: List[Message] = Field(default_factory=list, exclude=True)
    max_size: PositiveInt = Field(default=5, description="Maximum number of messages to keep in short-term memory")
    memory_id: str = Field(default_factory=generate_id)
    timestamp: str = Field(default_factory=get_timestamp)

    @field_validator("buffer", mode="before")
    @classmethod
    def ensure_list(cls, v):
        """Ensure that the buffer is always a list, even if it is null in the JSON."""
        if v is None:
            return []
        return v

    # Convert to deque during initialization
    def model_post_init(self, __context=None):
        """Pydantic V2 hook after model initialization.

        Convert buffer list -> deque, enforce max_size.
        """
        self.buffer = deque(self.buffer, maxlen=self.max_size)

    @property
    def size(self) -> int:
        """Return current number of messages stored."""
        return len(self.buffer)

    def clear(self):
        """Clear all short-term memory."""
        self.buffer.clear()

    def add_message(self, message: Message):
        """Add a single message to short-term memory.

        Falsy messages are ignored. Once the deque is full, appending drops
        the oldest entry (deque ``maxlen`` behaviour).
        """
        if not message:
            return
        self.buffer.append(message)

    def add_messages(self, messages: Union[Message, List[Message]]):
        """Add one or multiple messages."""
        if not isinstance(messages, list):
            messages = [messages]
        for msg in messages:
            self.add_message(msg)

    def get(self, n: Optional[int] = None) -> List[Message]:
        """Retrieve the most recent n messages (default: all).

        Args:
            n: Number of messages to return. If None, return all. A value of
                0 or less returns an empty list.

        Returns:
            List of Message objects, oldest -> newest.
        """
        items = list(self.buffer)
        if n is None:
            return items
        if n <= 0:
            # items[-0:] would be the whole list, and a negative n would
            # slice from the front; neither is "the last n messages".
            return []
        return items[-n:]

    def get_last(self) -> Optional[Message]:
        """Return the latest message, or None if empty."""
        return self.buffer[-1] if self.buffer else None