# memorychat/memory_manager.py
# Initial commit: Memory Chat application (commit 0919d5b)
import json
import os
import time
from typing import List, Dict, Optional, Tuple
from datetime import datetime
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from rich.console import Console
console = Console()
class MemoryManager:
    """Manages user memories using vector embeddings and FAISS for efficient retrieval.

    Memories are persisted to ``memories.json`` (embeddings included) and
    mirrored into a human-readable Markdown timeline. Semantic lookup is served
    by an in-memory FAISS L2 index whose row *i* always corresponds to
    ``self.memories[i]`` — keeping that alignment is an invariant of this class.
    """

    def __init__(self, memory_dir: str = "memories"):
        """
        Initialize the memory manager.

        Args:
            memory_dir: Directory to store memory files
        """
        self.memory_dir = memory_dir
        self.memories_file = os.path.join(memory_dir, "memories.json")
        self.timeline_file = os.path.join(memory_dir, "timeline.md")

        # Create memory directory if it doesn't exist
        os.makedirs(memory_dir, exist_ok=True)

        # Sentence transformer used to embed both stored memories and queries;
        # the same model must be used for both or distances are meaningless.
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2')

        # Load existing memories or start with an empty list.
        self.memories = self._load_memories()

        # FAISS flat index over memory embeddings (exact L2 search).
        self.dimension = 384  # Output dimension of all-MiniLM-L6-v2
        self.index = faiss.IndexFlatL2(self.dimension)
        self._build_index()

    def _load_memories(self) -> List[Dict]:
        """Load memories from the JSON file; return [] if absent or unreadable."""
        if not os.path.exists(self.memories_file):
            return []
        try:
            with open(self.memories_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        except (OSError, json.JSONDecodeError) as e:
            # A corrupt/unreadable store degrades to an empty memory set rather
            # than crashing the application at startup.
            console.print(f"[red]Error loading memories: {e}[/red]")
            return []

    def _save_memories(self):
        """Persist the full memory list (embeddings included) to the JSON file."""
        try:
            with open(self.memories_file, 'w', encoding='utf-8') as f:
                json.dump(self.memories, f, indent=2)
        except (OSError, TypeError) as e:
            # TypeError covers non-serializable values sneaking into a memory.
            console.print(f"[red]Error saving memories: {e}[/red]")

    def _build_index(self):
        """Rebuild the FAISS index so that row i corresponds to self.memories[i].

        Entries missing an 'embedding' key (e.g. a hand-edited JSON file) are
        re-embedded in place instead of being skipped — skipping them would
        desynchronize FAISS row numbers from list positions and make every
        subsequent retrieval return the wrong memories.
        """
        if not self.memories:
            return
        vectors = []
        for memory in self.memories:
            if 'embedding' not in memory:
                embedding = self.embedder.encode(memory.get('content', '')).astype('float32')
                memory['embedding'] = embedding.tolist()
            vectors.append(np.asarray(memory['embedding'], dtype='float32'))
        self.index.add(np.stack(vectors))

    def add_memory(self, content: str, context: str = "", memory_type: str = "general") -> Dict:
        """
        Add a new memory to the memory store.

        Args:
            content: The main content of the memory
            context: Additional context about when/where this occurred
            memory_type: Type of memory (general, fact, preference, etc.)

        Returns:
            The created memory object
        """
        # Create embedding for the memory (float32 is what FAISS expects).
        embedding = self.embedder.encode(content).astype('float32')

        memory = {
            # max-existing-id + 1 stays unique even after clears or manual
            # edits, unlike len(memories) + 1.
            "id": max((m.get("id", 0) for m in self.memories), default=0) + 1,
            "content": content,
            "context": context,
            "type": memory_type,
            "timestamp": datetime.now().isoformat(),
            "embedding": embedding.tolist(),
            "importance": self._calculate_importance(content)
        }

        # Append to the list and the index together to preserve row alignment.
        self.memories.append(memory)
        self.index.add(embedding.reshape(1, -1))

        self._save_memories()
        self._update_timeline()

        console.print(f"[green]✓ Memory added: {content[:50]}...[/green]")
        return memory

    def _calculate_importance(self, content: str) -> float:
        """
        Calculate the importance score of a memory based on its content.

        Args:
            content: The memory content

        Returns:
            Importance score between 0 and 1
        """
        # Heuristic: base score, +0.1 per importance keyword (substring match),
        # +0.1 for long content, capped at 1.0.
        importance = 0.5  # Base importance

        important_keywords = [
            "love", "family", "important", "urgent", "must", "remember",
            "birthday", "anniversary", "special", "favorite", "hate",
            "never", "always", "often", "every", "daily", "weekly"
        ]

        content_lower = content.lower()
        for keyword in important_keywords:
            if keyword in content_lower:
                importance += 0.1

        # Longer memories might be more important
        if len(content) > 100:
            importance += 0.1

        return min(importance, 1.0)

    def retrieve_memories(self, query: str, k: int = 5) -> List[Dict]:
        """
        Retrieve relevant memories based on a query.

        Args:
            query: The search query
            k: Number of memories to retrieve

        Returns:
            List of relevant memory copies (with a 'relevance_score' key)
            sorted by relevance; lower L2 distance means more relevant.
        """
        if not self.memories:
            return []

        query_embedding = self.embedder.encode(query).astype('float32')
        distances, indices = self.index.search(query_embedding.reshape(1, -1), k)

        relevant_memories = []
        for dist, idx in zip(distances[0], indices[0]):
            # FAISS pads results with -1 when fewer than k vectors are indexed;
            # a plain `idx < len(...)` check would let -1 through and silently
            # return the *last* memory via negative indexing.
            if 0 <= idx < len(self.memories):
                # Copy so the transient relevance score is never persisted
                # into memories.json by a later save.
                memory = dict(self.memories[idx])
                memory['relevance_score'] = float(dist)
                relevant_memories.append(memory)

        # FAISS already returns ascending distances, but sort defensively.
        relevant_memories.sort(key=lambda m: m['relevance_score'])
        return relevant_memories

    def get_recent_memories(self, limit: int = 10) -> List[Dict]:
        """Get the most recent memories (newest first, up to `limit`)."""
        # ISO-8601 timestamps sort correctly as strings.
        return sorted(self.memories, key=lambda m: m['timestamp'], reverse=True)[:limit]

    def get_memory_types(self) -> Dict[str, int]:
        """Get a mapping of memory type -> number of memories of that type."""
        type_counts: Dict[str, int] = {}
        for memory in self.memories:
            memory_type = memory.get('type', 'general')
            type_counts[memory_type] = type_counts.get(memory_type, 0) + 1
        return type_counts

    def _update_timeline(self):
        """Rewrite the human-readable Markdown timeline file (newest date first)."""
        timeline_content = "# Memory Timeline\n\n"

        # Group memories by calendar date (ISO timestamp prefix YYYY-MM-DD).
        memories_by_date: Dict[str, List[Dict]] = {}
        for memory in sorted(self.memories, key=lambda m: m['timestamp'], reverse=True):
            date = memory['timestamp'][:10]
            memories_by_date.setdefault(date, []).append(memory)

        for date in sorted(memories_by_date.keys(), reverse=True):
            timeline_content += f"## {date}\n\n"
            for memory in memories_by_date[date]:
                # [11:19] slices the HH:MM:SS portion of the ISO timestamp.
                timeline_content += f"- **{memory['type'].title()}** ({memory['timestamp'][11:19]}): {memory['content']}\n"
            timeline_content += "\n"

        try:
            with open(self.timeline_file, 'w', encoding='utf-8') as f:
                f.write(timeline_content)
        except OSError as e:
            # Timeline is a convenience mirror; failure to write it is non-fatal.
            console.print(f"[red]Error updating timeline: {e}[/red]")

    def get_summary(self) -> Dict:
        """Get a summary of the memory store (counts, types, recent entries)."""
        return {
            "total_memories": len(self.memories),
            "memory_types": self.get_memory_types(),
            "recent_memories": self.get_recent_memories(5),
            "timeline_file": self.timeline_file
        }

    def clear_memories(self):
        """Delete all memories, reset the FAISS index, and rewrite both files."""
        self.memories = []
        self.index = faiss.IndexFlatL2(self.dimension)
        self._save_memories()
        self._update_timeline()
        console.print("[yellow]All memories cleared.[/yellow]")