import json
import os
import time
from typing import List, Dict, Optional, Tuple
from datetime import datetime

from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from rich.console import Console

console = Console()


class MemoryManager:
    """Manages user memories using vector embeddings and FAISS for efficient retrieval."""

    def __init__(self, memory_dir: str = "memories"):
        """
        Initialize the memory manager.

        Args:
            memory_dir: Directory to store memory files
        """
        self.memory_dir = memory_dir
        self.memories_file = os.path.join(memory_dir, "memories.json")
        self.timeline_file = os.path.join(memory_dir, "timeline.md")

        # Create memory directory if it doesn't exist
        os.makedirs(memory_dir, exist_ok=True)

        # Initialize sentence transformer for embeddings
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2')

        # Load existing memories or create empty list
        self.memories = self._load_memories()

        # Initialize FAISS index
        self.dimension = 384  # Dimension of MiniLM embeddings
        self.index = faiss.IndexFlatL2(self.dimension)
        # Maps FAISS row position -> index into self.memories.  Required
        # because memories without a stored 'embedding' key are skipped when
        # the index is built, so FAISS row i does not necessarily equal
        # self.memories[i].
        self._index_to_memory: List[int] = []
        self._build_index()

    def _load_memories(self) -> List[Dict]:
        """Load memories from the JSON file; return [] if absent or unreadable."""
        if os.path.exists(self.memories_file):
            try:
                with open(self.memories_file, 'r') as f:
                    return json.load(f)
            except Exception as e:
                console.print(f"[red]Error loading memories: {e}[/red]")
                return []
        return []

    def _save_memories(self):
        """Persist the in-memory list of memories to the JSON file (best effort)."""
        try:
            with open(self.memories_file, 'w') as f:
                json.dump(self.memories, f, indent=2)
        except Exception as e:
            console.print(f"[red]Error saving memories: {e}[/red]")

    def _build_index(self):
        """Rebuild the FAISS index and the row->memory mapping from stored embeddings.

        Memories lacking an 'embedding' key are skipped; the mapping records
        which memory each FAISS row actually came from.
        """
        self._index_to_memory = []
        embeddings = []
        for pos, memory in enumerate(self.memories):
            if 'embedding' in memory:
                embeddings.append(np.array(memory['embedding'], dtype='float32'))
                self._index_to_memory.append(pos)
        if embeddings:
            self.index.add(np.array(embeddings))

    def add_memory(self, content: str, context: str = "", memory_type: str = "general") -> Dict:
        """
        Add a new memory to the memory store.

        Args:
            content: The main content of the memory
            context: Additional context about when/where this occurred
            memory_type: Type of memory (general, fact, preference, etc.)

        Returns:
            The created memory object
        """
        # Create embedding for the memory
        embedding = self.embedder.encode(content).astype('float32')

        # Create memory object
        memory = {
            "id": len(self.memories) + 1,
            "content": content,
            "context": context,
            "type": memory_type,
            "timestamp": datetime.now().isoformat(),
            "embedding": embedding.tolist(),
            "importance": self._calculate_importance(content)
        }

        # Add to memories list and keep the FAISS row mapping in sync.
        self.memories.append(memory)
        self._index_to_memory.append(len(self.memories) - 1)

        # Add to FAISS index
        self.index.add(embedding.reshape(1, -1))

        # Save to file
        self._save_memories()

        # Update timeline
        self._update_timeline()

        console.print(f"[green]✓ Memory added: {content[:50]}...[/green]")
        return memory

    def _calculate_importance(self, content: str) -> float:
        """
        Calculate the importance score of a memory based on its content.

        Args:
            content: The memory content

        Returns:
            Importance score between 0 and 1
        """
        # Simple importance calculation based on content length and keywords
        importance = 0.5  # Base importance

        # Keywords that indicate higher importance
        important_keywords = [
            "love", "family", "important", "urgent", "must", "remember",
            "birthday", "anniversary", "special", "favorite", "hate",
            "never", "always", "often", "every", "daily", "weekly"
        ]

        content_lower = content.lower()
        for keyword in important_keywords:
            if keyword in content_lower:
                importance += 0.1

        # Longer memories might be more important
        if len(content) > 100:
            importance += 0.1

        return min(importance, 1.0)

    def retrieve_memories(self, query: str, k: int = 5) -> List[Dict]:
        """
        Retrieve relevant memories based on a query.

        Args:
            query: The search query
            k: Number of memories to retrieve

        Returns:
            List of relevant memories (shallow copies carrying a transient
            'relevance_score' key) sorted by ascending L2 distance.
        """
        if not self.memories or self.index.ntotal == 0:
            return []

        # Create embedding for the query
        query_embedding = self.embedder.encode(query).astype('float32')

        # Never ask FAISS for more neighbours than the index holds; it would
        # pad the result with -1 indices.
        k = min(k, self.index.ntotal)
        distances, indices = self.index.search(query_embedding.reshape(1, -1), k)

        relevant_memories = []
        for dist, idx in zip(distances[0], indices[0]):
            # FAISS uses -1 for "no result"; a plain `idx < len(...)` check
            # would let -1 through and (via Python negative indexing)
            # silently return the *last* memory.
            if 0 <= idx < len(self._index_to_memory):
                # Copy so the transient relevance score is never persisted
                # back to memories.json by a later _save_memories().
                memory = dict(self.memories[self._index_to_memory[idx]])
                memory['relevance_score'] = float(dist)
                relevant_memories.append(memory)

        # Sort by relevance (lower distance = more relevant)
        relevant_memories.sort(key=lambda m: m['relevance_score'])
        return relevant_memories

    def get_recent_memories(self, limit: int = 10) -> List[Dict]:
        """Get the most recent memories, newest first."""
        return sorted(self.memories, key=lambda x: x['timestamp'], reverse=True)[:limit]

    def get_memory_types(self) -> Dict[str, int]:
        """Get statistics about memory types (type name -> count)."""
        type_counts: Dict[str, int] = {}
        for memory in self.memories:
            memory_type = memory.get('type', 'general')
            type_counts[memory_type] = type_counts.get(memory_type, 0) + 1
        return type_counts

    def _update_timeline(self):
        """Update the human-readable timeline file (best effort)."""
        timeline_content = "# Memory Timeline\n\n"

        # Group memories by date
        memories_by_date: Dict[str, List[Dict]] = {}
        for memory in sorted(self.memories, key=lambda x: x['timestamp'], reverse=True):
            date = memory['timestamp'][:10]  # YYYY-MM-DD
            if date not in memories_by_date:
                memories_by_date[date] = []
            memories_by_date[date].append(memory)

        # Build timeline, newest date first
        for date in sorted(memories_by_date.keys(), reverse=True):
            timeline_content += f"## {date}\n\n"
            for memory in memories_by_date[date]:
                timeline_content += f"- **{memory['type'].title()}** ({memory['timestamp'][11:19]}): {memory['content']}\n"
            timeline_content += "\n"

        try:
            with open(self.timeline_file, 'w') as f:
                f.write(timeline_content)
        except Exception as e:
            console.print(f"[red]Error updating timeline: {e}[/red]")

    def get_summary(self) -> Dict:
        """Get a summary of the memory store."""
        return {
            "total_memories": len(self.memories),
            "memory_types": self.get_memory_types(),
            "recent_memories": self.get_recent_memories(5),
            "timeline_file": self.timeline_file
        }

    def clear_memories(self):
        """Clear all memories and reset the FAISS index and row mapping."""
        self.memories = []
        self.index = faiss.IndexFlatL2(self.dimension)
        self._index_to_memory = []
        self._save_memories()
        self._update_timeline()
        console.print("[yellow]All memories cleared.[/yellow]")