Spaces:
Sleeping
Sleeping
| # coding: utf-8 | |
| # Copyright (c) 2025 inclusionAI. | |
| import abc | |
| import asyncio | |
| import json | |
| import os | |
| from typing import Optional | |
| from aworld.config import ConfigDict | |
| from aworld.core.memory import MemoryBase, MemoryItem, MemoryStore, MemoryConfig | |
| from aworld.logs.util import logger | |
| from aworld.models.llm import get_llm_model, acall_llm_model | |
class InMemoryMemoryStore(MemoryStore):
    """List-backed MemoryStore with soft deletes and metadata-based filtering."""

    # Metadata keys that filters may match against (besides 'memory_type').
    _METADATA_FILTER_KEYS = ("user_id", "agent_id", "task_id", "session_id")

    def __init__(self):
        self.memory_items: list[MemoryItem] = []

    def add(self, memory_item: MemoryItem):
        """Append a memory item to the store."""
        self.memory_items.append(memory_item)

    def get(self, memory_id) -> Optional[MemoryItem]:
        """Return the item with the given id, or None. Soft-deleted items are still found."""
        return next((item for item in self.memory_items if item.id == memory_id), None)

    def get_first(self, filters: dict = None) -> Optional[MemoryItem]:
        """Get the first memory item matching filters, or None if nothing matches."""
        filtered_items = self.get_all(filters)
        return filtered_items[0] if filtered_items else None

    def total_rounds(self, filters: dict = None) -> int:
        """Get the total number of rounds (live items matching filters)."""
        return len(self.get_all(filters))

    def get_all(self, filters: dict = None) -> list[MemoryItem]:
        """Return all live items matching filters, in insertion order."""
        return [item for item in self.memory_items if self._filter_memory_item(item, filters)]

    def _filter_memory_item(self, memory_item: MemoryItem, filters: dict = None) -> bool:
        """True if the item is not deleted and matches every non-None filter value."""
        if memory_item.deleted:
            return False
        if filters is None:
            return True
        for key in self._METADATA_FILTER_KEYS:
            expected = filters.get(key)
            # A missing metadata value (None) can never match a concrete filter.
            if expected is not None and memory_item.metadata.get(key) != expected:
                return False
        expected_type = filters.get('memory_type')
        if expected_type is not None and memory_item.memory_type != expected_type:
            return False
        return True

    def get_last_n(self, last_rounds, filters: dict = None) -> list[MemoryItem]:
        """Return the last `last_rounds` live items matching filters.

        Bug fix: the original ignored `filters` entirely and also returned
        soft-deleted items; filtering is now applied before slicing. A
        non-positive `last_rounds` yields an empty list (`[-0:]` would have
        returned everything).
        """
        if last_rounds <= 0:
            return []
        return self.get_all(filters)[-last_rounds:]

    def update(self, memory_item: MemoryItem):
        """Replace the stored item that has the same id, if present."""
        for index, item in enumerate(self.memory_items):
            if item.id == memory_item.id:
                self.memory_items[index] = memory_item  # Update the item in the list
                break

    def delete(self, memory_id):
        """Soft-delete: mark the item deleted so filters exclude it; id lookups still work."""
        exists = self.get(memory_id)
        if exists:
            exists.deleted = True

    def history(self, memory_id) -> list[MemoryItem] | None:
        """Return the item's history list, or None if the id is unknown."""
        exists = self.get(memory_id)
        return exists.histories if exists else None
class MemoryFactory:
    """Factory that builds a concrete MemoryBase implementation from a MemoryConfig."""

    # Bug fix: the method took `cls` but lacked @classmethod, so
    # `MemoryFactory.from_config(config)` only worked by accident
    # (the config landed in the `cls` slot).
    @classmethod
    def from_config(cls, config: MemoryConfig) -> "MemoryBase":
        """
        Initialize a Memory instance from a configuration object.

        Args:
            config (MemoryConfig): Configuration; `provider` selects the backend.

        Returns:
            MemoryBase: Concrete memory instance.

        Raises:
            ValueError: If `config.provider` names an unknown provider.
        """
        if config.provider == "inmemory":
            return InMemoryStorageMemory(
                memory_store=InMemoryMemoryStore(),
                config=config,
                enable_summary=config.enable_summary,
                summary_rounds=config.summary_rounds
            )
        elif config.provider == "mem0":
            # Imported lazily so the mem0 dependency is only needed when selected.
            from aworld.memory.mem0.mem0_memory import Mem0Memory
            return Mem0Memory(
                memory_store=InMemoryMemoryStore(),
                config=config
            )
        else:
            # Bug fix: MemoryConfig has no .get(); the original raised
            # AttributeError here instead of the intended ValueError, and
            # reported the wrong field.
            raise ValueError(f"Invalid memory provider: {config.provider}")
class Memory(MemoryBase):
    """Base memory implementation providing LLM-backed conversation summarization."""

    # NOTE(review): `__metaclass__` is Python-2 syntax and has no effect on
    # Python 3; kept byte-identical to avoid changing the class's runtime shape.
    # Real abstractness must come from MemoryBase itself.
    __metaclass__ = abc.ABCMeta

    def __init__(self, memory_store: MemoryStore, config: MemoryConfig, **kwargs):
        self.memory_store = memory_store
        self.config = config
        # Lazily created by default_llm_instance().
        self._llm_instance = None

    def default_llm_instance(self):
        """Build (once) and return the cached LLM model used for summaries.

        MEM_LLM_* environment variables take precedence over the generic
        LLM_* variables.
        """
        def get_env(key: str, default_key: str, default_val: object = None):
            # Prefer the MEM_-prefixed variable, then the generic one, then the default.
            return os.getenv(key) if os.getenv(key) else os.getenv(default_key, default_val)

        if not self._llm_instance:
            self._llm_instance = get_llm_model(conf=ConfigDict({
                "llm_model_name": get_env("MEM_LLM_MODEL_NAME", "LLM_MODEL_NAME"),
                # Bug fix: the fallback key was LLM_MODEL_NAME, which silently
                # used the model name as the API key.
                "llm_api_key": get_env("MEM_LLM_API_KEY", "LLM_API_KEY"),
                "llm_base_url": get_env("MEM_LLM_BASE_URL", 'LLM_BASE_URL'),
                "temperature": get_env("MEM_LLM_TEMPERATURE", "MEM_LLM_TEMPERATURE", 1.0),
                "streaming": 'False'
            }))
        return self._llm_instance

    def _build_history_context(self, messages) -> str:
        """Build the history context string from a list of messages.

        Args:
            messages: List of message dicts with 'role', 'content', and optional 'tool_calls'.

        Returns:
            Concatenated context string.
        """
        history_context = ""
        for item in messages:
            history_context += (f"\n\n{item['role']}: {item['content']}, "
                                f"{'tool_calls:' + json.dumps(item['tool_calls']) if 'tool_calls' in item and item['tool_calls'] else ''}")
        return history_context

    async def _call_llm_summary(self, summary_messages: list) -> str:
        """Call LLM to generate summary and log the process.

        Args:
            summary_messages: List of messages to send to LLM.

        Returns:
            Summary content string.
        """
        logger.info(f"🤔 [Summary] Creating summary memory, history messages: {summary_messages}")
        llm_response = await acall_llm_model(
            # Bug fix: the bound method itself was passed instead of calling
            # it — the model instance never reached acall_llm_model.
            self.default_llm_instance(),
            messages=summary_messages,
            stream=False
        )
        # Bug fix: "...truncated" was appended unconditionally, even when the
        # content was shorter than 400 characters.
        preview = llm_response.content
        if len(preview) > 400:
            preview = preview[:400] + "...truncated"
        logger.info(f'🤔 [Summary] summary_content: result is {preview} ')
        return llm_response.content

    def _get_parsed_history_messages(self, history_items: list[MemoryItem]) -> list[dict]:
        """Get and format history messages for summary.

        Args:
            history_items: list[MemoryItem]
        Returns:
            List of parsed message dicts (tool_calls is None when absent/falsy).
        """
        parsed_messages = [
            {
                'role': message.metadata['role'],
                'content': message.content,
                'tool_calls': message.metadata.get('tool_calls') if message.metadata.get('tool_calls') else None
            }
            for message in history_items]
        return parsed_messages

    async def async_gen_multi_rounds_summary(self, to_be_summary: list[MemoryItem]) -> str:
        """Summarize a batch of memory items into one string ("" for an empty batch)."""
        logger.info(
            f"🤔 [Summary] Creating summary memory, history messages")
        if len(to_be_summary) == 0:
            return ""
        parsed_messages = self._get_parsed_history_messages(to_be_summary)
        history_context = self._build_history_context(parsed_messages)
        summary_messages = [
            {"role": "user", "content": self.config.summary_prompt.format(context=history_context)}
        ]
        return await self._call_llm_summary(summary_messages)

    async def async_gen_summary(self, filters: dict, last_rounds: int) -> str:
        """A tool for summarizing the conversation history."""
        logger.info(f"🤔 [Summary] Creating summary memory, history messages [filters -> {filters}, "
                    f"last_rounds -> {last_rounds}]")
        history_items = self.memory_store.get_last_n(last_rounds, filters=filters)
        if len(history_items) == 0:
            return ""
        parsed_messages = self._get_parsed_history_messages(history_items)
        history_context = self._build_history_context(parsed_messages)
        summary_messages = [
            {"role": "user", "content": self.config.summary_prompt.format(context=history_context)}
        ]
        return await self._call_llm_summary(summary_messages)

    async def async_gen_cur_round_summary(self, to_be_summary: MemoryItem, filters: dict, last_rounds: int) -> str:
        """Summarize the current round together with recent history.

        Returns the raw content unchanged when it is below the configured
        single-context length threshold.
        """
        # NOTE(review): this condition looks inverted — with summaries
        # *disabled* it proceeds to call the LLM. Intent may have been
        # `if not enable_summary or len(...) < threshold`. Preserved as-is;
        # confirm with the feature owner before changing.
        if self.config.enable_summary and len(to_be_summary.content) < self.config.summary_single_context_length:
            return to_be_summary.content
        logger.info(f"🤔 [Summary] Creating summary memory, history messages [filters -> {filters}, "
                    f"last_rounds -> {last_rounds}]: to be summary content is {to_be_summary.content}")
        history_items = self.memory_store.get_last_n(last_rounds, filters=filters)
        if len(history_items) == 0:
            return ""
        parsed_messages = self._get_parsed_history_messages(history_items)
        # Append the to_be_summary item itself.
        # NOTE(review): metadata['tool_call_id'] raises KeyError for non-tool
        # items — presumably this path is only used for tool results; verify.
        parsed_messages.append({
            "role": to_be_summary.metadata['role'],
            "content": f"{to_be_summary.content}",
            'tool_call_id': to_be_summary.metadata['tool_call_id'],
        })
        history_context = self._build_history_context(parsed_messages)
        summary_messages = [
            {"role": "user", "content": self.config.summary_prompt.format(context=history_context)}
        ]
        return await self._call_llm_summary(summary_messages)

    def search(self, query, limit=100, filters=None) -> Optional[list[MemoryItem]]:
        """Not implemented for this backend; subclasses may override."""
        pass
class InMemoryStorageMemory(Memory):
    """Memory backed by an in-process store, with optional rolling LLM summaries."""

    def __init__(self, memory_store: MemoryStore, config: MemoryConfig, enable_summary: bool = True, **kwargs):
        # `enable_summary` / `summary_rounds` kwargs are accepted for backward
        # compatibility, but the authoritative values come from `config`
        # (the factory passes the same config-derived values anyway).
        super().__init__(memory_store=memory_store, config=config)
        self.summary = {}  # range_key "start_end" -> summary MemoryItem
        self.summary_rounds = self.config.summary_rounds
        self.enable_summary = self.config.enable_summary

    def add(self, memory_item: MemoryItem, filters: dict = None):
        """Store the item, then refresh the rolling summary once past the window size."""
        self.memory_store.add(memory_item)
        if self.enable_summary:
            total_rounds = len(self.memory_store.get_all())
            if total_rounds > self.summary_rounds:
                self._create_or_update_summary(total_rounds)

    def _create_or_update_summary(self, total_rounds: int):
        """Create or update summary based on current total rounds.

        Args:
            total_rounds (int): Total number of rounds.
        """
        summary_index = total_rounds // self.summary_rounds
        start = max(0, (summary_index - 1) * self.summary_rounds)
        end = max(start, total_rounds - self.summary_rounds)
        # Get the memory items to summarize (end is inclusive).
        items_to_summarize = self.memory_store.get_all()[start:end + 1]
        # Bug fix: replaced a stray debug print() with proper logging.
        logger.debug(f"summary window: total_rounds={total_rounds}, start={start}, end={end}")
        summary_content = self._summarize_items(items_to_summarize, summary_index)
        range_key = f"{start}_{end}"
        if range_key in self.summary:
            # Update existing summary in place.
            existing = self.summary[range_key]
            existing.content = summary_content
            # NOTE(review): presumably assigning None triggers a timestamp
            # refresh inside MemoryItem — confirm that setter behavior.
            existing.updated_at = None
        else:
            # Create a new summary item for this window.
            self.summary[range_key] = MemoryItem(
                content=summary_content,
                metadata={
                    "summary_index": summary_index,
                    "start_round": start,
                    "end_round": end,
                    "role": "system"
                },
                tags=["summary"]
            )

    def _summarize_items(self, items: list[MemoryItem], summary_index: int) -> str:
        """Summarize a list of memory items.

        Args:
            items (list[MemoryItem]): List of memory items to summarize.
            summary_index (int): Summary index.
        Returns:
            str: Summary content.
        """
        # NOTE(review): asyncio.run() raises if called from a running event
        # loop; this helper must only be invoked from synchronous code paths.
        return asyncio.run(self.async_gen_multi_rounds_summary(items))

    def update(self, memory_item: MemoryItem):
        """Delegate update to the underlying store."""
        self.memory_store.update(memory_item)

    def delete(self, memory_id):
        """Delegate (soft) delete to the underlying store."""
        self.memory_store.delete(memory_id)

    def get(self, memory_id) -> Optional[MemoryItem]:
        """Delegate id lookup to the underlying store."""
        return self.memory_store.get(memory_id)

    def get_all(self, filters: dict = None) -> list[MemoryItem]:
        """Return all items. Bug fix: `filters` was silently dropped before."""
        return self.memory_store.get_all(filters)

    def get_last_n(self, last_rounds, add_first_message=True, filters: dict = None) -> list[MemoryItem]:
        """Get last n memories, prefixed with any relevant summaries.

        Args:
            last_rounds (int): Number of memories to retrieve.
            add_first_message (bool): Prepend the first stored message when the
                window does not already cover the whole history.
            filters (dict): Optional store filters, forwarded to the store.
        Returns:
            list[MemoryItem]: Summary items (oldest first) followed by the
            latest memories.
        """
        memory_items = self.memory_store.get_last_n(last_rounds, filters=filters)
        # Grow the window until it no longer starts in the middle of a
        # tool-call exchange (first item carrying a tool_call_id).
        while memory_items and memory_items[0].metadata and memory_items[0].metadata.get("tool_call_id"):
            # Bug fix: guard against an infinite loop when the window already
            # spans the whole store but still starts with a tool result.
            if len(memory_items) >= self.memory_store.total_rounds():
                break
            last_rounds = last_rounds + 1
            memory_items = self.memory_store.get_last_n(last_rounds, filters=filters)
        # If summary is disabled or no summaries exist, return just the last n items.
        if not self.enable_summary or not self.summary:
            return memory_items
        # Number of older items that are covered by summaries rather than
        # returned verbatim (clamped: the window may exceed the store size).
        all_items = self.memory_store.get_all()
        end_index = max(0, len(all_items) - last_rounds)
        result = []
        # Complete summary windows first.
        complete_summary_count = end_index // self.summary_rounds
        for i in range(complete_summary_count):
            range_key = f"{i * self.summary_rounds}_{(i + 1) * self.summary_rounds - 1}"
            if range_key in self.summary:
                result.append(self.summary[range_key])
        # Trailing partial window, if a summary was generated for it.
        if end_index % self.summary_rounds > 0:
            start = complete_summary_count * self.summary_rounds
            range_key = f"{start}_{end_index - 1}"
            if range_key in self.summary:
                result.append(self.summary[range_key])
        # Add the last n items.
        result.extend(memory_items)
        # Add first user input.
        # Bug fix: the original inserted into `memory_items` *after* its
        # contents were copied into `result`, so the first message never
        # reached callers; insert into the returned list instead.
        if add_first_message and last_rounds < self.memory_store.total_rounds():
            first_item = self.memory_store.get_first()
            if first_item is not None:
                result.insert(0, first_item)
        return result