import os
import json
import numpy as np
from datetime import datetime, timezone

import faiss
import torch

class LongTermMemory:
    """
    FAISS-powered semantic long-term memory.

    Stores:
      • vector embeddings
      • associated text
      • metadata
      • timestamps
    """

    def __init__(
        self,
        index_path="memory/storage/ltm.index",
        meta_path="memory/storage/ltm_meta.json",
        dim: int = 128,
    ):
        self.index_path = index_path
        self.meta_path = meta_path
        self.dim = dim

        # Ensure both parent directories exist (they may differ).
        os.makedirs(os.path.dirname(index_path), exist_ok=True)
        os.makedirs(os.path.dirname(meta_path), exist_ok=True)

        # ===== LOAD OR CREATE FAISS INDEX =====
        if os.path.exists(self.index_path):
            self.index = faiss.read_index(self.index_path)
            print("[LTM] Loaded existing FAISS index.")
        else:
            # Inner-product index; with L2-normalized vectors this is
            # equivalent to cosine similarity.
            self.index = faiss.IndexFlatIP(dim)
            print("[LTM] Created new FAISS index.")

        # ===== LOAD METADATA =====
        self.meta_store = self._load_meta()

    # ---------------------------------------------------
    # INTERNAL UTILITIES
    # ---------------------------------------------------
    def _load_meta(self):
        if os.path.exists(self.meta_path):
            try:
                with open(self.meta_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
                # Filter out corrupted or legacy entries.
                clean = []
                for entry in data:
                    if "embedding" in entry and "text" in entry:
                        clean.append(entry)
                return clean
            except Exception:
                print("[LTM] Metadata corrupted — starting fresh.")
                return []
        return []

    def _save_meta(self):
        with open(self.meta_path, "w", encoding="utf-8") as f:
            json.dump(self.meta_store, f, indent=2)

    def _normalize(self, vec: np.ndarray) -> np.ndarray:
        # Accept both single vectors (dim,) and batches (n, dim).
        vec = np.atleast_2d(vec)
        norm = np.linalg.norm(vec, axis=1, keepdims=True) + 1e-8
        return vec / norm

    # ---------------------------------------------------
    # STORE MEMORY
    # ---------------------------------------------------
    def store(self, embedding: torch.Tensor, text: str, meta=None):
        """
        Store an embedding together with its text and metadata.
        """
        if isinstance(embedding, torch.Tensor):
            embedding = embedding.detach().cpu().numpy()

        embedding = self._normalize(embedding)

        # FAISS requires float32 input.
        embedding = embedding.astype("float32")

        # --- Add vector to FAISS and persist the index ---
        # Note: writing the full index on every store keeps it durable,
        # but costs O(index size) per call; batch stores if throughput matters.
        self.index.add(embedding)
        faiss.write_index(self.index, self.index_path)

        entry = {
            "text": text,
            "embedding": embedding.tolist(),
            "meta": meta or {},
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        self.meta_store.append(entry)
        self._save_meta()

    # ---------------------------------------------------
    # RETRIEVE MEMORY
    # ---------------------------------------------------
    def retrieve(self, query_embedding: torch.Tensor, k: int = 5):
        """
        Semantic search for the top-k most relevant memories.
        """
        if isinstance(query_embedding, torch.Tensor):
            query_embedding = query_embedding.detach().cpu().numpy()

        query_embedding = self._normalize(query_embedding)
        query_embedding = query_embedding.astype("float32")

        if self.index.ntotal == 0:
            return []

        distances, indices = self.index.search(query_embedding, k)

        results = []
        for i, idx in enumerate(indices[0]):
            # FAISS pads results with -1 when fewer than k vectors exist,
            # so the lower bound check matters here.
            if 0 <= idx < len(self.meta_store):
                entry = self.meta_store[idx]
                if "embedding" not in entry:
                    continue
                results.append({
                    "text": entry.get("text", ""),
                    "embedding": entry["embedding"],
                    "score": float(distances[0][i]),
                    "meta": entry.get("meta", {}),
                    "timestamp": entry.get("timestamp"),
                })
        return results

    # ---------------------------------------------------
    # VECTOR RETRIEVAL (FOR ATTENTION FUSION)
    # ---------------------------------------------------
    def retrieve_vectors(self, query_embedding: torch.Tensor, k: int = 5):
        """
        Returns only the embeddings, as a (k, dim) tensor, for fast
        attention fusion.
        """
        memories = self.retrieve(query_embedding, k)
        if len(memories) == 0:
            return None

        vectors = []
        for m in memories:
            # Flatten so both (dim,) and nested (1, dim) stored shapes
            # stack to a clean (k, dim) batch.
            vec = np.asarray(m["embedding"], dtype=np.float32).reshape(-1)
            vectors.append(vec)

        stacked = np.stack(vectors)
        return torch.from_numpy(stacked)

    # ---------------------------------------------------
    # UTILITY
    # ---------------------------------------------------
    def size(self):
        """Number of stored memories."""
        return self.index.ntotal

    def all(self):
        """Debug view — avoid using in production."""
        return self.meta_store
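
# ---------------------------------------------------
# USAGE SKETCH (illustration only, not part of the module)
# ---------------------------------------------------
# A minimal, hypothetical demo of the class above. It assumes faiss
# (e.g. faiss-cpu) and torch are installed, uses random vectors as
# stand-ins for a real encoder's embeddings, and writes to throwaway
# demo paths rather than the module's defaults.
if __name__ == "__main__":
    ltm = LongTermMemory(
        index_path="memory/storage/demo.index",
        meta_path="memory/storage/demo_meta.json",
        dim=128,
    )

    # Store two memories with random stand-in embeddings.
    ltm.store(torch.randn(1, 128), "The sky was clear last night.")
    ltm.store(torch.randn(1, 128), "FAISS uses inner-product search.")

    # Query with another random vector; scores are cosine similarities
    # because every vector is L2-normalized before indexing.
    for hit in ltm.retrieve(torch.randn(1, 128), k=2):
        print(f"{hit['score']:.3f}  {hit['text']}")

    # (k, dim) tensor of raw memory vectors, e.g. for attention fusion.
    vecs = ltm.retrieve_vectors(torch.randn(1, 128), k=2)
    print(f"[demo] {ltm.size()} memories stored; vectors: {tuple(vecs.shape)}")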