# virtual-me-agent / memory_system.py
# Uploaded by camdog920 ("Upload memory_system.py", commit 2881c16, verified)
"""
Long-Term Memory System - Virtual Me Agent
Uses ChromaDB for vector-based memory retrieval of past conversations.
"""
import hashlib
import json
import os
import re
from datetime import datetime
from typing import Any, Dict, List, Optional

import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
class MemoryManager:
    """Manages long-term memory using vector embeddings for semantic retrieval.

    Each ``user_id`` gets its own ChromaDB collection, ``virtual_me_<user_id>``,
    configured for cosine distance and persisted under ``db_path``.
    """

    def __init__(self,
                 user_id: str = "default",
                 db_path: str = "./memory_db",
                 embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """
        Open (or create) the persistent store and load the embedding model.

        Args:
            user_id: Owner of the memories; selects the per-user collection.
            db_path: Directory for the ChromaDB store (created if missing).
            embedding_model: SentenceTransformer model name used for embeddings.
        """
        self.user_id = user_id
        self.db_path = db_path
        self.embedding_model_name = embedding_model
        self.embedding_model = None
        # Initialize ChromaDB
        os.makedirs(db_path, exist_ok=True)
        self.client = chromadb.PersistentClient(
            path=db_path,
            settings=Settings(anonymized_telemetry=False)
        )
        # Get or create collection for this user
        self.collection = self._get_or_create_collection()
        # Load the model now; _embed() would otherwise load it on first use.
        self._load_embedding_model()

    def _collection_name(self) -> str:
        """Return the ChromaDB collection name used for this user."""
        return f"virtual_me_{self.user_id}"

    def _get_or_create_collection(self):
        """Get or create this user's collection, using cosine distance."""
        return self.client.get_or_create_collection(
            name=self._collection_name(),
            metadata={"hnsw:space": "cosine"}
        )

    def _load_embedding_model(self):
        """Load the SentenceTransformer model if it is not loaded yet (no-op otherwise)."""
        if self.embedding_model is None:
            print(f"Loading embedding model: {self.embedding_model_name}")
            self.embedding_model = SentenceTransformer(self.embedding_model_name)
            print("Embedding model loaded.")

    def _embed(self, texts: List[str]) -> List[List[float]]:
        """Embed texts into vectors, one list of floats per input text."""
        self._load_embedding_model()
        return self.embedding_model.encode(texts, convert_to_numpy=True).tolist()

    def _generate_id(self, text: str, timestamp: Optional[str] = None) -> str:
        """
        Generate an ID for a memory from its text and timestamp.

        The ID is the MD5 hex digest of ``"<text>_<timestamp>"``; when no
        timestamp is given the current time is used. The same (text, timestamp)
        pair therefore always yields the same ID.
        NOTE(review): confirm how the installed ChromaDB version treats a
        duplicate ID passed to ``collection.add``.
        """
        content = f"{text}_{timestamp or datetime.now().isoformat()}"
        return hashlib.md5(content.encode()).hexdigest()

    @staticmethod
    def _memories_from_get(results: Dict[str, Any]) -> List[Dict]:
        """Convert a flat ``collection.get`` result into a list of memory dicts.

        Missing metadata entries are normalised to ``{}`` so callers can use
        ``.get`` on them safely.
        """
        documents = results["documents"]
        metadatas = results["metadatas"]
        return [
            {
                "id": memory_id,
                "content": documents[i] if documents else "",
                "metadata": (metadatas[i] if metadatas else None) or {},
            }
            for i, memory_id in enumerate(results["ids"] or [])
        ]

    def add_memory(self,
                   content: str,
                   metadata: Optional[Dict[str, Any]] = None,
                   memory_type: str = "conversation") -> str:
        """
        Add a new memory to long-term storage.

        Args:
            content: The text to remember; this is the text that is embedded.
            metadata: Additional info. Only str/int/float/bool values are
                stored; other values are dropped. A "timestamp" entry is used
                as the memory's timestamp (a datetime is stored as an ISO
                string), and a "type" entry overrides ``memory_type``.
            memory_type: Category of memory (conversation, fact, preference, etc.)

        Returns:
            The memory ID.
        """
        supplied = metadata or {}
        timestamp = supplied.get("timestamp")
        if timestamp is None:
            timestamp = datetime.now().isoformat()
        elif isinstance(timestamp, datetime):
            # Store in the same ISO string format as the default so the
            # metadata value is a simple type like every other stored value.
            timestamp = timestamp.isoformat()
        memory_id = self._generate_id(content, timestamp)

        full_metadata = {
            "timestamp": timestamp,
            "type": memory_type,
            "content_preview": content[:100],
        }
        # Caller metadata is merged last, so it can override the defaults
        # above. Only simple serializable types are kept.
        full_metadata.update(
            (key, value) for key, value in supplied.items()
            if isinstance(value, (str, int, float, bool))
        )

        # Generate embedding and store
        self.collection.add(
            ids=[memory_id],
            embeddings=self._embed([content]),
            documents=[content],
            metadatas=[full_metadata]
        )
        return memory_id

    def get_relevant_memories(self,
                              query: str,
                              n_results: int = 5,
                              memory_type: Optional[str] = None) -> List[Dict]:
        """
        Retrieve memories relevant to a query.

        Args:
            query: Text to find relevant memories for.
            n_results: Maximum number of memories to retrieve; values < 1
                return an empty list.
            memory_type: If given, only memories whose "type" metadata equals
                this value are searched.

        Returns:
            List of memory dicts with id, content, metadata and distance
            (cosine distance; smaller means closer). Empty when the store is
            empty.
            NOTE(review): when ``memory_type`` matches fewer than
            ``n_results`` items, behavior depends on the ChromaDB version.
        """
        if n_results <= 0:
            return []
        total = self.collection.count()
        if total == 0:
            # ChromaDB's query() rejects n_results < 1, so an empty collection
            # (e.g. a brand-new user) must short-circuit here.
            return []

        where_filter = {"type": memory_type} if memory_type else None
        results = self.collection.query(
            query_embeddings=self._embed([query]),
            n_results=min(n_results, total),
            where=where_filter,
            include=["documents", "metadatas", "distances"]
        )

        # query() returns one result list per query embedding; we sent one.
        ids = results["ids"][0] if results["ids"] else []
        documents = results["documents"]
        metadatas = results["metadatas"]
        distances = results["distances"]
        return [
            {
                "id": memory_id,
                "content": documents[0][i] if documents else "",
                "metadata": (metadatas[0][i] if metadatas else None) or {},
                "distance": distances[0][i] if distances else 1.0,
            }
            for i, memory_id in enumerate(ids)
        ]

    def get_recent_memories(self, n_results: int = 10) -> List[Dict]:
        """Get the ``n_results`` most recent memories, newest first, by timestamp."""
        # ChromaDB doesn't have a native sort, so fetch everything and sort here.
        memories = self._memories_from_get(
            self.collection.get(include=["documents", "metadatas"])
        )
        # Assumes timestamps are ISO-8601 strings as written by add_memory, which
        # sort chronologically as text; memories without one sort last.
        memories.sort(key=lambda memory: memory["metadata"].get("timestamp", ""),
                      reverse=True)
        return memories[:n_results]

    def get_memory_summary(self) -> Dict:
        """Return the total memory count, a per-type count, and the user ID."""
        count = self.collection.count()
        all_data = self.collection.get(include=["metadatas"])
        type_counts: Dict[str, int] = {}
        for meta in all_data["metadatas"] or []:
            kind = (meta or {}).get("type", "unknown")
            type_counts[kind] = type_counts.get(kind, 0) + 1
        return {
            "total_memories": count,
            "type_distribution": type_counts,
            "user_id": self.user_id
        }

    def search_by_type(self, memory_type: str, n_results: int = 50) -> List[Dict]:
        """Return up to ``n_results`` memories whose "type" metadata equals ``memory_type``."""
        results = self.collection.get(
            where={"type": memory_type},
            include=["documents", "metadatas"]
        )
        return self._memories_from_get(results)[:n_results]

    def delete_memory(self, memory_id: str) -> bool:
        """
        Delete a specific memory.

        Returns True if the delete call did not raise, False otherwise (the
        error is printed). NOTE(review): a True result does not by itself
        prove the ID existed — confirm ChromaDB's behavior for unknown IDs.
        """
        try:
            self.collection.delete(ids=[memory_id])
            return True
        except Exception as e:
            print(f"Error deleting memory: {e}")
            return False

    def clear_all_memories(self):
        """Clear all memories for this user by dropping and recreating the collection."""
        self.client.delete_collection(name=self._collection_name())
        self.collection = self._get_or_create_collection()
class InteractionTracker:
    """Tracks interaction patterns for learning and improvement.

    Keeps in-process counters in ``interaction_stats`` and persists each
    interaction / feedback event as a memory via the supplied memory manager.
    """

    # Keywords recognised by _extract_topics; this order is the output order.
    COMMON_TOPICS = (
        "work", "family", "friends", "hobbies", "music", "movies", "books",
        "travel", "food", "sports", "technology", "politics", "health",
        "money", "education", "love", "stress", "happiness", "goals",
        "dreams", "fears", "memories", "plans", "advice", "learning",
    )
    # Match a topic only at the start of a word ("working" -> work) so that
    # embedded substrings ("network", "glove", "transports") are not tagged.
    _TOPIC_PATTERNS = tuple(
        (topic, re.compile(r"\b" + re.escape(topic))) for topic in COMMON_TOPICS
    )
    # Prefix written by log_feedback; _feedback_text strips it back off.
    _FEEDBACK_PREFIX = "Feedback on my response to '{}': "

    def __init__(self, memory_manager: "MemoryManager"):
        """
        Args:
            memory_manager: Store that receives interaction/feedback memories.
        """
        self.memory = memory_manager
        # NOTE: total_conversations, avg_response_time and emotional_patterns
        # are initialised here but not updated anywhere in this class.
        self.interaction_stats = {
            "total_conversations": 0,
            "total_messages": 0,
            "avg_response_time": 0.0,
            "user_satisfaction_scores": [],
            "topics_discussed": set(),
            "emotional_patterns": []
        }

    def log_interaction(self,
                        user_message: str,
                        assistant_response: str,
                        duration_ms: Optional[float] = None,
                        user_feedback: Optional[str] = None):
        """
        Log a single interaction for learning.

        Increments ``total_messages``, records the topics detected in the user
        message, and stores the exchange as an "interaction" memory.
        """
        self.interaction_stats["total_messages"] += 1
        # Extract topics (simple keyword extraction)
        topics = self._extract_topics(user_message)
        self.interaction_stats["topics_discussed"].update(topics)
        # Store in memory
        self.memory.add_memory(
            content=f"Interaction: User: '{user_message}' | Me: '{assistant_response}'",
            metadata={
                "topics": ",".join(topics),
                "feedback": user_feedback or "none",
                "duration_ms": duration_ms or 0
            },
            memory_type="interaction",
        )

    def log_feedback(self, original_message: str, feedback: str, rating: Optional[int] = None):
        """
        Log explicit feedback for learning.

        A falsy rating (None or 0) is stored as 0 and is not added to
        ``user_satisfaction_scores``.
        """
        self.memory.add_memory(
            content=f"{self._FEEDBACK_PREFIX.format(original_message)}{feedback}",
            metadata={
                "rating": rating or 0,
                "original_message": original_message
            },
            memory_type="feedback",
        )
        if rating:
            self.interaction_stats["user_satisfaction_scores"].append(rating)

    def _extract_topics(self, text: str) -> List[str]:
        """Return the COMMON_TOPICS that start a word in ``text`` (case-insensitive)."""
        text_lower = (text or "").lower()
        return [topic for topic, pattern in self._TOPIC_PATTERNS
                if pattern.search(text_lower)]

    @classmethod
    def _feedback_text(cls, memory: Dict) -> str:
        """Return the lowercased feedback part of a stored feedback memory.

        Strips the "Feedback on my response to '<original>': " prefix when the
        stored ``original_message`` metadata allows it to be rebuilt, so that
        keyword checks do not match words from the user's original message.
        """
        content = memory.get("content") or ""
        original = (memory.get("metadata") or {}).get("original_message")
        if original is not None:
            prefix = cls._FEEDBACK_PREFIX.format(original)
            if content.startswith(prefix):
                content = content[len(prefix):]
        return content.lower()

    def get_learning_insights(self) -> Dict:
        """Get insights from all interactions for improving the clone.

        Totals come from the memory store's per-type counts. Correction and
        preference examples come from at most 100 feedback memories.
        """
        # Use the full per-type counts so totals are not capped by the sample below.
        type_counts = self.memory.get_memory_summary()["type_distribution"]
        feedback_memories = self.memory.search_by_type("feedback", n_results=100)
        scores = self.interaction_stats["user_satisfaction_scores"]

        insights = {
            "total_interactions": type_counts.get("interaction", 0),
            "total_feedback": type_counts.get("feedback", 0),
            # Sorted so the output order is stable (set iteration order is arbitrary).
            "common_topics": sorted(self.interaction_stats["topics_discussed"]),
            "avg_satisfaction": sum(scores) / len(scores) if scores else 0,
        }

        # Analyze feedback patterns, looking only at the feedback text itself.
        corrections = []
        preferences = []
        for memory in feedback_memories:
            text = self._feedback_text(memory)
            if "wrong" in text or "incorrect" in text:
                corrections.append(memory["content"])
            if any(word in text for word in ("prefer", "like", "want")):
                preferences.append(memory["content"])
        insights["common_corrections"] = corrections[:5]
        insights["expressed_preferences"] = preferences[:5]
        return insights