"""Context memory management for the agent."""

from __future__ import annotations

from dataclasses import dataclass, field
from datetime import datetime
from typing import Any


@dataclass
class MemoryEntry:
    """A single entry in short-term memory.

    Attributes:
        key: Key the entry is stored under.
        value: Stored value (any object).
        timestamp: Naive local time at which the entry was created; used to
            choose which entries to evict first and to order search results.
        source: Free-form label describing where the value came from.
        relevance: Relevance score. ``ContextMemory.store`` never sets it and
            ``ContextMemory.search`` does not currently order by it.
    """

    key: str
    value: Any
    timestamp: datetime = field(default_factory=datetime.now)
    source: str = "unknown"
    relevance: float = 1.0


class ContextMemory:
    """Manages context and working memory for the agent.

    Three independent stores are kept:

    * short-term memory: ``MemoryEntry`` objects keyed by name, capped at
      ``max_entries`` (the oldest entries are evicted when the cap is exceeded);
    * working memory: a plain key/value scratch dict with no cap;
    * conversation history: an append-only list of turns with no cap.
    """

    def __init__(self, max_entries: int = 100) -> None:
        """Initialize memory.

        Args:
            max_entries: Maximum number of short-term entries to keep. Working
                memory and conversation history are not bounded by this.
        """
        self.max_entries = max_entries
        self._short_term: dict[str, MemoryEntry] = {}
        self._working: dict[str, Any] = {}
        self._conversation: list[dict[str, str]] = []

    def store(self, key: str, value: Any, source: str = "agent") -> None:
        """Store a value in short-term memory.

        Storing under an existing key replaces that entry (with a fresh
        timestamp). If the number of entries then exceeds ``max_entries``,
        the oldest entries are evicted.

        Args:
            key: Memory key.
            value: Value to store.
            source: Source of the information.
        """
        # Remove any previous entry first so the dict's insertion order tracks
        # recency. _trim_oldest breaks timestamp ties by insertion order, so on
        # a coarse clock a re-stored key must not keep its old position and be
        # evicted ahead of entries that are actually older.
        self._short_term.pop(key, None)
        self._short_term[key] = MemoryEntry(key=key, value=value, source=source)

        if len(self._short_term) > self.max_entries:
            self._trim_oldest()

    def retrieve(self, key: str) -> Any | None:
        """Retrieve a value from short-term memory.

        Args:
            key: Memory key.

        Returns:
            The stored value, or ``None`` if the key is absent. (A stored
            value of ``None`` is indistinguishable from a missing key.)
        """
        entry = self._short_term.get(key)
        return entry.value if entry is not None else None

    def update_working(self, key: str, value: Any) -> None:
        """Set a value in working memory, overwriting any previous value.

        Args:
            key: Memory key.
            value: Value to store.
        """
        self._working[key] = value

    def get_working(self, key: str, default: Any = None) -> Any:
        """Get a value from working memory.

        Args:
            key: Memory key.
            default: Value returned if the key is not present.

        Returns:
            The stored value, or ``default``.
        """
        return self._working.get(key, default)

    def add_conversation_turn(self, role: str, content: str) -> None:
        """Append a turn to the conversation history.

        Each turn is stored as a dict with ``role``, ``content`` and an ISO-8601
        ``timestamp`` (naive local time) string.

        Args:
            role: Message role (e.g. user/assistant).
            content: Message content.
        """
        self._conversation.append({
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat(),
        })

    def get_conversation_history(self, limit: int = 10) -> list[dict[str, str]]:
        """Get the most recent conversation turns, oldest first.

        Args:
            limit: Maximum number of turns to return. Values <= 0 return an
                empty list.

        Returns:
            A new list holding up to ``limit`` of the latest turns. The turn
            dicts themselves are the stored objects, not copies.
        """
        if limit <= 0:
            # Slicing with [-0:] would return the whole history, and a negative
            # limit would return an arbitrary tail; neither matches "at most".
            return []
        return self._conversation[-limit:]

    def get_context_summary(self) -> dict[str, Any]:
        """Get a summary of the current context.

        Returns:
            Dict with the short-term keys, working-memory keys (both in dict
            insertion order) and the number of conversation turns.
        """
        return {
            "short_term_keys": list(self._short_term.keys()),
            "working_memory_keys": list(self._working.keys()),
            "conversation_length": len(self._conversation),
        }

    def clear_working(self) -> None:
        """Clear working memory only."""
        self._working.clear()

    def clear_all(self) -> None:
        """Clear short-term memory, working memory and conversation history."""
        self._short_term.clear()
        self._working.clear()
        self._conversation.clear()

    def _trim_oldest(self) -> None:
        """Evict the oldest short-term entries to get back within capacity.

        Removes 10% of the current entries (at least one), choosing those with
        the earliest timestamps.
        """
        if not self._short_term:
            return

        # sorted() is stable: entries with equal timestamps are evicted in dict
        # insertion order, oldest insertion first.
        oldest_first = sorted(
            self._short_term,
            key=lambda k: self._short_term[k].timestamp,
        )

        # Evicting a batch rather than a single entry presumably amortizes the
        # sort over the following stores.
        to_remove = max(1, len(oldest_first) // 10)
        for key in oldest_first[:to_remove]:
            del self._short_term[key]

    def search(self, query: str) -> list[MemoryEntry]:
        """Search short-term memory for entries matching ``query``.

        Matching is a case-insensitive substring test against ``str(value)``
        and the entry key. An empty query matches every entry.

        Args:
            query: Search query.

        Returns:
            Matching entries, newest first.
        """
        query_lower = query.lower()
        results = []
        for entry in self._short_term.values():
            # Simple keyword matching; value is checked before key.
            value_str = str(entry.value).lower()
            if query_lower in value_str or query_lower in entry.key.lower():
                results.append(entry)

        # "Relevance" ordering is currently just recency.
        results.sort(key=lambda e: e.timestamp, reverse=True)
        return results