import os
import json
import numpy as np
from datetime import datetime, timezone
import faiss
import torch


class LongTermMemory:
    """
    FAISS-powered semantic long-term memory.

    Stores:
        • vector embeddings (inner-product FAISS index over L2-normalized
          vectors, i.e. cosine similarity)
        • associated text
        • metadata
        • timestamps (UTC, ISO-8601)

    The FAISS index and the JSON metadata list are kept positionally in
    sync: vector i in the index corresponds to ``self.meta_store[i]``.
    """

    def __init__(
        self,
        index_path="memory/storage/ltm.index",
        meta_path="memory/storage/ltm_meta.json",
        dim: int = 128,
    ):
        """
        Load an existing index/metadata pair from disk, or create a fresh
        inner-product index of dimensionality ``dim``.

        Args:
            index_path: Path of the serialized FAISS index file.
            meta_path: Path of the JSON metadata sidecar file.
            dim: Embedding dimensionality for a newly created index.
        """
        self.index_path = index_path
        self.meta_path = meta_path
        self.dim = dim

        # Ensure both target directories exist — index and metadata may
        # live in different locations.
        os.makedirs(os.path.dirname(index_path), exist_ok=True)
        meta_dir = os.path.dirname(meta_path)
        if meta_dir:
            os.makedirs(meta_dir, exist_ok=True)

        # ===== LOAD OR CREATE FAISS INDEX =====
        if os.path.exists(self.index_path):
            self.index = faiss.read_index(self.index_path)
            print("[LTM] Loaded existing FAISS index.")
        else:
            # IndexFlatIP = exact inner-product search; with normalized
            # vectors this is cosine similarity.
            self.index = faiss.IndexFlatIP(dim)
            print("[LTM] Created new FAISS index.")

        # ===== LOAD METADATA =====
        self.meta_store = self._load_meta()

    # ---------------------------------------------------
    # INTERNAL UTILITIES
    # ---------------------------------------------------
    def _load_meta(self):
        """
        Load the metadata list from disk, dropping corrupted or legacy
        entries (those missing "embedding" or "text"). Returns [] when
        the file is absent or unreadable.
        """
        if os.path.exists(self.meta_path):
            try:
                with open(self.meta_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
                # Filter corrupted or legacy entries
                return [
                    entry for entry in data
                    if "embedding" in entry and "text" in entry
                ]
            except Exception:
                # Best-effort recovery: a broken metadata file should not
                # prevent the memory system from starting.
                print("[LTM] Metadata corrupted — starting fresh.")
                return []
        return []

    def _save_meta(self):
        """Persist the full metadata list to disk as pretty-printed JSON."""
        with open(self.meta_path, "w", encoding="utf-8") as f:
            json.dump(self.meta_store, f, indent=2)

    def _normalize(self, vec: np.ndarray) -> np.ndarray:
        """
        L2-normalize each row of ``vec``.

        Accepts a 1-D ``(dim,)`` or 2-D ``(n, dim)`` array and always
        returns a 2-D array (FAISS requires 2-D float input). The small
        epsilon avoids division by zero for all-zero vectors.
        """
        vec = np.atleast_2d(vec)  # fix: 1-D input previously crashed on axis=1
        norm = np.linalg.norm(vec, axis=1, keepdims=True) + 1e-8
        return vec / norm

    # ---------------------------------------------------
    # STORE MEMORY
    # ---------------------------------------------------
    def store(self, embedding: torch.Tensor, text: str, meta=None):
        """
        Store embedding + text + metadata.

        Args:
            embedding: Tensor or ndarray of shape ``(dim,)`` or ``(1, dim)``.
            text: Human-readable text associated with the vector.
            meta: Optional dict of arbitrary JSON-serializable metadata.

        Side effects: appends to the FAISS index and metadata list, and
        writes both to disk immediately (write-through persistence).
        """
        if isinstance(embedding, torch.Tensor):
            embedding = embedding.detach().cpu().numpy()

        embedding = self._normalize(embedding)

        # Ensure float32 for FAISS
        embedding = embedding.astype("float32")

        # --- Add vector to FAISS and persist ---
        self.index.add(embedding)
        faiss.write_index(self.index, self.index_path)

        entry = {
            "text": text,
            "embedding": embedding.tolist(),
            "meta": meta or {},
            # timezone-aware UTC timestamp (datetime.utcnow() is deprecated)
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }

        self.meta_store.append(entry)
        self._save_meta()

    # ---------------------------------------------------
    # RETRIEVE MEMORY
    # ---------------------------------------------------
    def retrieve(self, query_embedding: torch.Tensor, k: int = 5):
        """
        Semantic search for top-k relevant memories.

        Args:
            query_embedding: Tensor or ndarray of shape ``(dim,)`` or
                ``(1, dim)``.
            k: Maximum number of results.

        Returns:
            List of dicts with keys "text", "embedding", "score"
            (inner-product similarity), "meta", "timestamp" — possibly
            fewer than ``k``, empty when the index is empty.
        """
        if isinstance(query_embedding, torch.Tensor):
            query_embedding = query_embedding.detach().cpu().numpy()

        query_embedding = self._normalize(query_embedding)
        query_embedding = query_embedding.astype("float32")

        if self.index.ntotal == 0:
            return []

        distances, indices = self.index.search(query_embedding, k)

        results = []
        for i, idx in enumerate(indices[0]):
            # FAISS pads with -1 when fewer than k vectors exist; the old
            # bounds check let -1 through and returned meta_store[-1]
            # (the most recent entry) as a spurious match.
            if idx < 0 or idx >= len(self.meta_store):
                continue
            entry = self.meta_store[idx]
            if "embedding" not in entry:
                continue
            results.append({
                "text": entry.get("text", ""),
                "embedding": entry["embedding"],
                "score": float(distances[0][i]),
                "meta": entry.get("meta", {}),
                "timestamp": entry.get("timestamp"),
            })

        return results

    # ---------------------------------------------------
    # VECTOR RETRIEVAL (FOR ATTENTION FUSION)
    # ---------------------------------------------------
    def retrieve_vectors(self, query_embedding: torch.Tensor, k: int = 5):
        """
        Returns only embeddings for fast attention fusion.

        Returns:
            A float32 tensor of shape ``(m, dim)`` with ``m <= k``, or
            ``None`` when no memories match.
        """
        memories = self.retrieve(query_embedding, k)
        if not memories:
            return None

        vectors = [
            np.array(m["embedding"], dtype=np.float32) for m in memories
        ]
        return torch.tensor(np.stack(vectors))

    # ---------------------------------------------------
    # UTILITY
    # ---------------------------------------------------
    def size(self):
        """Number of stored memories (vectors currently in the index)."""
        return self.index.ntotal

    def all(self):
        """Debug view — avoid using in production."""
        return self.meta_store