"""Unified memory manager providing access to all memory layers.""" from __future__ import annotations import logging from enum import Enum from typing import Any from pydantic import BaseModel, Field from app.config import Settings from app.memory.long_term import Document, LongTermMemory, SearchResult from app.memory.shared import Message, SharedMemory from app.memory.short_term import MemoryEntry, ShortTermMemory from app.memory.working import WorkingMemory, WorkingMemoryItem logger = logging.getLogger(__name__) class MemoryType(str, Enum): """Types of memory layers.""" SHORT_TERM = "short_term" WORKING = "working" LONG_TERM = "long_term" SHARED = "shared" class MemoryStats(BaseModel): """Statistics for all memory layers.""" short_term: dict[str, Any] = Field(default_factory=dict) working: dict[str, Any] = Field(default_factory=dict) long_term: dict[str, Any] = Field(default_factory=dict) shared: dict[str, Any] = Field(default_factory=dict) class MemoryManager: """ Unified interface to all memory layers. The MemoryManager provides a single entry point for interacting with different types of memory (short-term, working, long-term, shared). It handles initialization, coordination, and lifecycle management. Attributes: short_term: Episode-scoped dictionary memory. working: LRU-based reasoning scratch space. long_term: Persistent vector storage. shared: Multi-agent shared state. """ def __init__(self, settings: Settings) -> None: """ Initialize memory manager with settings. Args: settings: Application settings. """ self._settings = settings self._initialized = False # Initialize memory layers self.short_term = ShortTermMemory( max_size=settings.short_term_memory_size, ) self.working = WorkingMemory( capacity=settings.working_memory_size, ) self.long_term = LongTermMemory( collection_name=settings.chroma_collection_name, persist_directory=settings.chroma_persist_directory, top_k=settings.long_term_memory_top_k, ) self.shared = SharedMemory() async def initialize(self) -> None: """ Initialize all memory layers. This should be called during application startup. """ if self._initialized: return try: # Initialize long-term memory (ChromaDB) await self.long_term.initialize() self._initialized = True logger.info("Memory manager initialized successfully") except Exception as e: logger.error(f"Failed to initialize memory manager: {e}") raise async def shutdown(self) -> None: """ Shutdown all memory layers gracefully. This should be called during application shutdown. """ try: # Persist long-term memory await self.long_term.shutdown() # Clear working memory await self.working.clear() self._initialized = False logger.info("Memory manager shutdown complete") except Exception as e: logger.error(f"Error during memory manager shutdown: {e}") raise @property def is_initialized(self) -> bool: """Check if memory manager is initialized.""" return self._initialized # ========================================================================= # Unified Store Interface # ========================================================================= async def store( self, key: str, value: Any, memory_type: MemoryType = MemoryType.SHORT_TERM, **kwargs: Any, ) -> Any: """ Store a value in the specified memory layer. Args: key: Key or identifier for the stored value. value: Value to store. memory_type: Which memory layer to use. **kwargs: Additional arguments passed to the specific layer. Returns: The created entry/document (varies by memory type). Raises: ValueError: If memory_type is invalid. """ match memory_type: case MemoryType.SHORT_TERM: tags = kwargs.get("tags") return await self.short_term.set(key, value, tags=tags) case MemoryType.WORKING: priority = kwargs.get("priority", 0.0) metadata = kwargs.get("metadata") return await self.working.push( content=value, item_id=key, priority=priority, metadata=metadata, ) case MemoryType.LONG_TERM: if not isinstance(value, str): value = str(value) metadata = kwargs.get("metadata") embedding = kwargs.get("embedding") return await self.long_term.store( content=value, document_id=key, metadata=metadata, embedding=embedding, ) case MemoryType.SHARED: await self.shared.set_state(key, value) return value case _: raise ValueError(f"Invalid memory type: {memory_type}") # ========================================================================= # Unified Retrieve Interface # ========================================================================= async def retrieve( self, key: str, memory_type: MemoryType = MemoryType.SHORT_TERM, default: Any = None, ) -> Any: """ Retrieve a value from the specified memory layer. Args: key: Key or identifier to look up. memory_type: Which memory layer to query. default: Default value if not found. Returns: The stored value or default. Raises: ValueError: If memory_type is invalid. """ match memory_type: case MemoryType.SHORT_TERM: return await self.short_term.get(key, default=default) case MemoryType.WORKING: item = await self.working.peek_by_id(key) return item.content if item else default case MemoryType.LONG_TERM: doc = await self.long_term.get(key) return doc.content if doc else default case MemoryType.SHARED: return await self.shared.get_state(key, default=default) case _: raise ValueError(f"Invalid memory type: {memory_type}") # ========================================================================= # Unified Search Interface # ========================================================================= async def search( self, query: str, memory_type: MemoryType = MemoryType.LONG_TERM, top_k: int = 10, **kwargs: Any, ) -> list[Any]: """ Search for values in the specified memory layer. Args: query: Search query. memory_type: Which memory layer to search. top_k: Maximum number of results. **kwargs: Additional arguments for specific layers. Returns: List of matching entries/documents. Raises: ValueError: If memory_type is invalid or doesn't support search. """ match memory_type: case MemoryType.SHORT_TERM: # Search by tag or return all keys containing query tag = kwargs.get("tag") if tag: return list((await self.short_term.get_by_tag(tag)).items())[:top_k] keys = await self.short_term.list_keys() matching = [k for k in keys if query.lower() in k.lower()] results = [] for key in matching[:top_k]: value = await self.short_term.get(key) results.append((key, value)) return results case MemoryType.WORKING: # Search working memory items def matches(item: WorkingMemoryItem) -> bool: content_str = str(item.content).lower() return query.lower() in content_str items = await self.working.search(matches) return items[:top_k] case MemoryType.LONG_TERM: where = kwargs.get("where") query_embedding = kwargs.get("query_embedding") return await self.long_term.search( query=query, top_k=top_k, where=where, query_embedding=query_embedding, ) case MemoryType.SHARED: # Search state keys all_state = await self.shared.get_all_state() matching = [ (k, v) for k, v in all_state.items() if query.lower() in k.lower() or query.lower() in str(v).lower() ] return matching[:top_k] case _: raise ValueError(f"Invalid memory type: {memory_type}") # ========================================================================= # Unified Clear Interface # ========================================================================= async def clear( self, memory_type: MemoryType | None = None, ) -> dict[str, int]: """ Clear memory layers. Args: memory_type: Specific layer to clear, or None for all. Returns: Dictionary with counts of cleared items per layer. """ results: dict[str, int] = {} if memory_type is None or memory_type == MemoryType.SHORT_TERM: results["short_term"] = await self.short_term.clear() if memory_type is None or memory_type == MemoryType.WORKING: results["working"] = await self.working.clear() if memory_type is None or memory_type == MemoryType.LONG_TERM: results["long_term"] = await self.long_term.clear() if memory_type is None or memory_type == MemoryType.SHARED: shared_results = await self.shared.clear() results["shared_channels"] = shared_results["channels"] results["shared_state"] = shared_results["state_keys"] return results # ========================================================================= # Episode Management # ========================================================================= async def start_episode(self, episode_id: str) -> None: """ Start a new episode, clearing episode-scoped memory. Args: episode_id: Unique identifier for the episode. """ await self.short_term.set_episode(episode_id) await self.working.clear() logger.debug(f"Started episode: {episode_id}") async def end_episode(self) -> dict[str, int]: """ End the current episode, clearing temporary memory. Returns: Counts of cleared items. """ results = { "short_term": await self.short_term.clear(), "working": await self.working.clear(), } logger.debug(f"Ended episode: {results}") return results # ========================================================================= # Statistics # ========================================================================= async def get_stats(self) -> MemoryStats: """ Get statistics from all memory layers. Returns: MemoryStats with info from each layer. """ return MemoryStats( short_term=await self.short_term.get_stats(), working=await self.working.get_stats(), long_term=await self.long_term.get_stats(), shared=await self.shared.get_stats(), ) # ========================================================================= # Convenience Methods # ========================================================================= async def remember( self, content: str, metadata: dict[str, Any] | None = None, ) -> Document: """ Store content in long-term memory for later retrieval. This is a convenience method for storing knowledge. Args: content: Text content to remember. metadata: Optional metadata. Returns: The stored document. """ return await self.long_term.store(content=content, metadata=metadata) async def recall( self, query: str, top_k: int = 5, ) -> list[SearchResult]: """ Recall relevant memories based on a query. This is a convenience method for semantic search. Args: query: Search query. top_k: Number of results to return. Returns: List of relevant search results. """ return await self.long_term.search(query=query, top_k=top_k) async def think( self, thought: str, priority: float = 0.0, ) -> WorkingMemoryItem: """ Add a thought to working memory. This is a convenience method for reasoning steps. Args: thought: The thought content. priority: Priority score. Returns: The working memory item. """ return await self.working.push(content=thought, priority=priority) async def broadcast( self, channel: str, message: Any, sender: str | None = None, ) -> Message: """ Broadcast a message to a shared channel. This is a convenience method for multi-agent communication. Args: channel: Channel name. message: Message payload. sender: Optional sender identifier. Returns: The published message. """ return await self.shared.publish( channel=channel, payload=message, sender=sender, )