# mvi-ai-engine/memory/long_term.py
# Repository page residue: "Musombi's picture"
# Commit message: "Update memory/long_term.py"
# Commit: 7608950
import os
import json
import numpy as np
from datetime import datetime
import faiss
import torch
class LongTermMemory:
    """
    FAISS-powered semantic long-term memory.

    Stores:
    • vector embeddings
    • associated text
    • metadata
    • timestamps

    Vectors are L2-normalised and kept in a FAISS index (a fresh index is an
    ``IndexFlatIP``, so scores are inner products of unit vectors).  The
    text, metadata and timestamp of each vector live in ``meta_store``, a
    list persisted as JSON.  Row ``i`` of the index is described by
    ``meta_store[i]``; every method below preserves that alignment.
    """

    def __init__(
        self,
        index_path="memory/storage/ltm.index",
        meta_path="memory/storage/ltm_meta.json",
        dim: int = 128
    ):
        """
        Load the index and metadata from disk, or start empty.

        Args:
            index_path: File the FAISS index is read from / written to.
            meta_path: JSON file holding the per-vector metadata list.
            dim: Vector dimension used when a new index has to be created.
        """
        self.index_path = index_path
        self.meta_path = meta_path
        self.dim = dim

        # Both files need an existing parent directory.  dirname() is "" for
        # a bare filename, and os.makedirs("") raises, so skip that case.
        for path in (index_path, meta_path):
            parent = os.path.dirname(path)
            if parent:
                os.makedirs(parent, exist_ok=True)

        # ===== LOAD OR CREATE FAISS INDEX =====
        if os.path.exists(self.index_path):
            self.index = faiss.read_index(self.index_path)
            print("[LTM] Loaded existing FAISS index.")
        else:
            self.index = faiss.IndexFlatIP(dim)
            print("[LTM] Created new FAISS index.")

        # ===== LOAD METADATA =====
        self.meta_store = self._load_meta()

        # Filtered legacy entries or an interrupted save leave the two stores
        # with different lengths; positions would then point at the wrong
        # text, so re-derive the index from the metadata.
        if self.index.ntotal != len(self.meta_store):
            self._rebuild_index()

    # ---------------------------------------------------
    # INTERNAL UTILITIES
    # ---------------------------------------------------
    def _load_meta(self):
        """
        Read the metadata list from ``meta_path``.

        Returns:
            The entries that are dicts containing both ``"embedding"`` and
            ``"text"``; an empty list when the file is missing, unreadable,
            not valid JSON, or not a JSON list.
        """
        if not os.path.exists(self.meta_path):
            return []

        try:
            with open(self.meta_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print("[LTM] Metadata corrupted — starting fresh.")
            return []

        if not isinstance(data, list):
            print("[LTM] Metadata corrupted — starting fresh.")
            return []

        # Filter corrupted or legacy entries one by one, so a single bad
        # record does not discard the valid ones.
        return [
            entry for entry in data
            if isinstance(entry, dict)
            and "embedding" in entry
            and "text" in entry
        ]

    def _save_meta(self):
        """Write ``meta_store`` to ``meta_path`` as indented JSON."""
        # Write to a sibling file and swap it in, so a crash mid-write cannot
        # leave a truncated metadata file behind.
        tmp_path = f"{self.meta_path}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(self.meta_store, f, indent=2)
        os.replace(tmp_path, self.meta_path)

    def _normalize(self, vec: np.ndarray):
        """
        Scale each row of ``vec`` to (almost exactly) unit L2 norm.

        A 1-D vector is treated as a single row.  The ``1e-8`` added to the
        norm avoids division by zero; an all-zero row stays all-zero.

        Returns:
            An array with at least two dimensions.
        """
        vec = np.atleast_2d(np.asarray(vec))
        norm = np.linalg.norm(vec, axis=1, keepdims=True) + 1e-8
        return vec / norm

    def _prepare(self, vectors):
        """
        Convert ``vectors`` to the normalised float32 matrix FAISS expects.

        Raises:
            ValueError: If the result is not 2-D with the index's dimension.
        """
        if isinstance(vectors, torch.Tensor):
            vectors = vectors.detach().cpu().numpy()

        # Ensure float32 for FAISS
        vectors = self._normalize(vectors).astype("float32")

        if vectors.ndim != 2 or vectors.shape[1] != self.index.d:
            raise ValueError(
                f"expected vectors of dimension {self.index.d}, "
                f"got array of shape {vectors.shape}"
            )
        return vectors

    def _rebuild_index(self):
        """
        Recreate the index from the embeddings saved in ``meta_store``.

        Entries whose embedding cannot be read as whole vectors of the index
        dimension are dropped; an entry holding several vectors is split into
        one entry per vector.  The result is only held in memory until the
        next ``store`` call persists it.
        """
        dim = self.index.d
        rows = []
        kept = []

        for entry in self.meta_store:
            try:
                vec = np.asarray(entry["embedding"], dtype="float32")
            except (TypeError, ValueError):
                continue
            if vec.size == 0 or vec.size % dim != 0:
                continue
            for row in vec.reshape(-1, dim):
                rows.append(row)
                # Keep the nested (1, dim) layout that store() writes.
                kept.append({**entry, "embedding": [row.tolist()]})

        index = faiss.IndexFlatIP(dim)
        if rows:
            index.add(np.stack(rows))

        self.index = index
        self.meta_store = kept
        print("[LTM] Index/metadata mismatch — rebuilt index from metadata.")

    # ---------------------------------------------------
    # STORE MEMORY
    # ---------------------------------------------------
    def store(self, embedding: torch.Tensor, text: str, meta=None):
        """
        Store embedding + text + metadata and persist both stores.

        Args:
            embedding: A tensor or array holding one vector (1-D or
                ``(1, dim)``).  If several rows are given, each row is stored
                as its own memory sharing ``text`` and ``meta``.
            text: Text associated with the embedding.
            meta: Optional JSON-serialisable metadata; defaults to ``{}``.

        Raises:
            ValueError: If the embedding dimension does not match the index.
        """
        vectors = self._prepare(embedding)

        # --- Add vector to FAISS ---
        self.index.add(vectors)
        faiss.write_index(self.index, self.index_path)

        # Naive UTC timestamp, same format as entries already on disk.
        timestamp = datetime.utcnow().isoformat()

        # One metadata entry per index row keeps positions aligned.
        for row in vectors:
            self.meta_store.append({
                "text": text,
                "embedding": [row.tolist()],
                "meta": meta or {},
                "timestamp": timestamp
            })
        self._save_meta()

    # ---------------------------------------------------
    # RETRIEVE MEMORY
    # ---------------------------------------------------
    def retrieve(self, query_embedding: torch.Tensor, k: int = 5):
        """
        Semantic search for top-k relevant memories.

        Args:
            query_embedding: A tensor or array holding the query vector.
                Only the first row of a multi-row query is used.
            k: Maximum number of memories to return.

        Returns:
            A list of dicts with keys ``text``, ``embedding``, ``score``,
            ``meta`` and ``timestamp``, in the order FAISS returned them.
            Empty when nothing is stored or ``k`` is not positive.

        Raises:
            ValueError: If the query dimension does not match the index.
        """
        if self.index.ntotal == 0 or k <= 0:
            return []

        query = self._prepare(query_embedding)
        scores, indices = self.index.search(query, k)

        results = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads with -1 when fewer than k vectors exist; a bare
            # "idx < len" check would let -1 select the last entry.
            if not 0 <= idx < len(self.meta_store):
                continue
            entry = self.meta_store[int(idx)]
            if "embedding" not in entry:
                continue
            results.append({
                "text": entry.get("text", ""),
                "embedding": entry["embedding"],
                "score": float(score),
                "meta": entry.get("meta", {}),
                "timestamp": entry.get("timestamp")
            })
        return results

    # ---------------------------------------------------
    # VECTOR RETRIEVAL (FOR ATTENTION FUSION)
    # ---------------------------------------------------
    def retrieve_vectors(self, query_embedding: torch.Tensor, k: int = 5):
        """
        Returns only embeddings for fast attention fusion.

        Returns:
            A float32 tensor stacking the stored embeddings of the retrieved
            memories along a new first axis (entries written by ``store``
            are ``(1, dim)``, giving ``(n, 1, dim)``), or ``None`` when
            nothing was retrieved.
        """
        memories = self.retrieve(query_embedding, k)
        if not memories:
            return None

        stacked = np.stack(
            [np.array(m["embedding"], dtype=np.float32) for m in memories]
        )
        return torch.tensor(stacked)

    # ---------------------------------------------------
    # UTILITY
    # ---------------------------------------------------
    def size(self):
        """Number of stored memories"""
        return self.index.ntotal

    def all(self):
        """Debug view — avoid using in production"""
        return self.meta_store