# Spaces: Runtime error
# File size: 5,136 Bytes
# Commits referenced in the page gutter: c1e438c, 871b8bf, 7608950
import os
import json
import numpy as np
from datetime import datetime
import faiss
import torch
class LongTermMemory:
    """
    FAISS-powered semantic long-term memory.

    Stores:
    • vector embeddings
    • associated text
    • metadata
    • timestamps

    Vectors live in a FAISS inner-product index and are L2-normalised
    before insertion and before search. The text, metadata, timestamp and a
    copy of each vector live in a JSON list (``meta_store``). Entry ``i`` of
    that list describes row ``i`` of the index, so the two must always grow
    together. Both files are rewritten on every ``store`` call.
    """

    def __init__(
        self,
        index_path="memory/storage/ltm.index",
        meta_path="memory/storage/ltm_meta.json",
        dim: int = 128
    ):
        """
        Open the memory, loading persisted state when it exists.

        Args:
            index_path: File the FAISS index is read from / written to.
            meta_path: JSON file holding the per-vector metadata list.
            dim: Vector dimensionality used when a new index is created.
        """
        self.index_path = index_path
        self.meta_path = meta_path
        self.dim = dim

        # os.path.dirname() is "" for a bare filename and os.makedirs("")
        # raises, so only create directories that are actually named. The
        # metadata directory is created too because _save_meta writes there.
        for path in (index_path, meta_path):
            parent = os.path.dirname(path)
            if parent:
                os.makedirs(parent, exist_ok=True)

        # ===== LOAD OR CREATE FAISS INDEX =====
        if os.path.exists(self.index_path):
            self.index = faiss.read_index(self.index_path)
            print("[LTM] Loaded existing FAISS index.")
        else:
            self.index = faiss.IndexFlatIP(dim)
            print("[LTM] Created new FAISS index.")

        # ===== LOAD METADATA =====
        self.meta_store = self._load_meta()

        # Index rows and metadata entries are matched purely by position, so
        # a count mismatch (e.g. metadata was reset or filtered) means search
        # hits can be attributed to the wrong entry. Surface it loudly.
        if self.index.ntotal != len(self.meta_store):
            print(
                f"[LTM] Warning: index holds {self.index.ntotal} vectors but "
                f"metadata holds {len(self.meta_store)} entries; "
                "retrieved results may be misaligned."
            )

    # ---------------------------------------------------
    # INTERNAL UTILITIES
    # ---------------------------------------------------
    def _load_meta(self):
        """
        Read the metadata list from ``meta_path``.

        Returns:
            The entries that are dicts containing both ``"embedding"`` and
            ``"text"``; an empty list when the file is missing, unreadable,
            not valid JSON, or not a JSON list.
        """
        if not os.path.exists(self.meta_path):
            return []

        try:
            with open(self.meta_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print("[LTM] Metadata corrupted — starting fresh.")
            return []

        if not isinstance(data, list):
            print("[LTM] Metadata corrupted — starting fresh.")
            return []

        # Filter corrupted or legacy entries. The isinstance check keeps one
        # stray non-dict value from discarding every valid entry.
        return [
            entry
            for entry in data
            if isinstance(entry, dict)
            and "embedding" in entry
            and "text" in entry
        ]

    def _save_meta(self):
        """Persist ``meta_store`` to ``meta_path`` as indented JSON."""
        # Write to a sibling file and swap it in, so an interrupted write
        # cannot leave a truncated metadata file behind.
        tmp_path = f"{self.meta_path}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(self.meta_store, f, indent=2)
        os.replace(tmp_path, self.meta_path)

    def _normalize(self, vec: np.ndarray):
        """
        Scale vectors to (approximately) unit L2 length.

        Args:
            vec: A single vector (1-D) or one vector per row (2-D).

        Returns:
            An array of the same shape; the norm is taken over the last
            axis. A small epsilon keeps all-zero vectors from dividing by 0.
        """
        vec = np.asarray(vec)
        norm = np.linalg.norm(vec, axis=-1, keepdims=True) + 1e-8
        return vec / norm

    def _prepare(self, embedding):
        """
        Convert an embedding to the normalised float32 matrix FAISS expects.

        Args:
            embedding: A ``torch.Tensor`` or anything ``np.asarray`` accepts.

        Returns:
            A C-contiguous float32 array; 1-D input becomes a single row.
        """
        if isinstance(embedding, torch.Tensor):
            embedding = embedding.detach().cpu().numpy()
        matrix = np.atleast_2d(np.asarray(embedding))
        # Ensure float32 for FAISS
        return np.ascontiguousarray(self._normalize(matrix), dtype=np.float32)

    # ---------------------------------------------------
    # STORE MEMORY
    # ---------------------------------------------------
    def store(self, embedding: torch.Tensor, text: str, meta=None):
        """
        Store embedding + text + metadata and persist both files.

        Args:
            embedding: One vector (1-D or a single row) or a batch of rows.
            text: Text associated with the embedding.
            meta: Optional metadata object; an empty dict is stored when it
                is missing or falsy.

        One metadata entry is appended per index row, all sharing ``text``,
        ``meta`` and the timestamp, so positions in the index and in
        ``meta_store`` stay aligned even for batched input.
        """
        vectors = self._prepare(embedding)

        # --- Add vector to FAISS ---
        self.index.add(vectors)
        faiss.write_index(self.index, self.index_path)

        # Naive UTC timestamp; format kept as-is for already-stored entries.
        timestamp = datetime.utcnow().isoformat()
        for row in vectors:
            self.meta_store.append({
                "text": text,
                # Nested as a one-row matrix, matching previously stored
                # entries (and the shape retrieve_vectors stacks).
                "embedding": [row.tolist()],
                "meta": meta or {},
                "timestamp": timestamp,
            })
        self._save_meta()

    # ---------------------------------------------------
    # RETRIEVE MEMORY
    # ---------------------------------------------------
    def retrieve(self, query_embedding: torch.Tensor, k: int = 5):
        """
        Semantic search for top-k relevant memories.

        Args:
            query_embedding: Query vector (1-D or a single row). If several
                rows are given, only the first row's results are returned.
            k: Maximum number of memories to return.

        Returns:
            A list of dicts with keys ``text``, ``embedding``, ``score``,
            ``meta`` and ``timestamp``, in the order FAISS returned them.
            Empty when the index is empty or ``k`` is not positive.
        """
        total = self.index.ntotal
        if total == 0 or k <= 0:
            return []

        query = self._prepare(query_embedding)
        # Never ask for more neighbours than exist.
        scores, ids = self.index.search(query, min(k, total))

        results = []
        for score, idx in zip(scores[0], ids[0]):
            idx = int(idx)
            # FAISS pads missing results with -1; without the lower bound a
            # -1 would silently index the *last* metadata entry.
            if not 0 <= idx < len(self.meta_store):
                continue
            entry = self.meta_store[idx]
            if "embedding" not in entry:
                continue
            results.append({
                "text": entry.get("text", ""),
                "embedding": entry["embedding"],
                "score": float(score),
                "meta": entry.get("meta", {}),
                "timestamp": entry.get("timestamp"),
            })
        return results

    # ---------------------------------------------------
    # VECTOR RETRIEVAL (FOR ATTENTION FUSION)
    # ---------------------------------------------------
    def retrieve_vectors(self, query_embedding: torch.Tensor, k: int = 5):
        """
        Returns only embeddings for fast attention fusion.

        Args:
            query_embedding: Query vector, as for ``retrieve``.
            k: Maximum number of memories to fetch.

        Returns:
            A float32 tensor stacking the stored embeddings of the retrieved
            memories along a new first axis, or ``None`` when nothing was
            retrieved. Entries written by ``store`` hold a one-row matrix, so
            the result is ``(n, 1, dim)`` for those.
        """
        memories = self.retrieve(query_embedding, k)
        if not memories:
            return None

        stacked = np.stack(
            [np.array(m["embedding"], dtype=np.float32) for m in memories]
        )
        return torch.tensor(stacked)

    # ---------------------------------------------------
    # UTILITY
    # ---------------------------------------------------
    def size(self):
        """Number of stored memories"""
        return self.index.ntotal

    def all(self):
        """Debug view — avoid using in production"""
        return self.meta_store