"""Persistent semantic memory store.

Embeds text with a SentenceTransformer model, indexes the vectors in a
FAISS L2 index, and asynchronously flushes both the metadata list and the
index to disk from a background thread.
"""
| import os | |
| import pickle | |
| import threading | |
| import logging | |
| from queue import Queue, Empty | |
| from datetime import datetime | |
| from functools import lru_cache | |
| import numpy as np | |
| import faiss | |
| from sentence_transformers import SentenceTransformer | |
| from time import perf_counter | |
# Configuration (overridable via environment variables).
MEMORY_FILE = os.environ.get("MEMORY_FILE", "memory.pkl")
INDEX_FILE = os.environ.get("INDEX_FILE", "memory.index")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "all-MiniLM-L6-v2")

# Logging setup
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")

# Load embedding model once at import time (expensive operation).
embedding_model = SentenceTransformer(EMBEDDING_MODEL)

# Initialize memory store and FAISS index: resume from disk when both files
# load cleanly, otherwise fall back to an empty store.
try:
    # NOTE(review): pickle.load is unsafe on untrusted files — acceptable
    # only because this is a locally produced cache file.
    with open(MEMORY_FILE, "rb") as f:  # was: open() with no close (fd leak)
        memory_data = pickle.load(f)
    memory_index = faiss.read_index(INDEX_FILE)
    logging.info("Loaded existing memory and index.")
except Exception as e:  # missing/corrupt files: start fresh, but say why
    memory_data = []
    dimension = embedding_model.get_sentence_embedding_dimension()
    memory_index = faiss.IndexFlatL2(dimension)
    logging.info("Initialized new memory and index (%s).", e)
# Queue and worker for async flushing: producers enqueue a token after each
# mutation; the worker coalesces all pending tokens into one disk write.
_write_queue = Queue()


def _flush_worker():
    """Background thread: batch pending write requests and persist to disk."""
    while True:
        batch = []
        # Block up to 5s for the first token so the idle thread stays cheap.
        try:
            batch.append(_write_queue.get(timeout=5))
        except Empty:
            pass
        # Drain anything else that arrived so one flush covers many updates.
        # EAFP drain — the old empty()/get_nowait() pair was race-prone LBYL.
        while True:
            try:
                batch.append(_write_queue.get_nowait())
            except Empty:
                break
        if batch:
            try:
                # was: pickle.dump(..., open(...)) — leaked the file handle
                # and could leave a partially written, unflushed file.
                with open(MEMORY_FILE, "wb") as f:
                    pickle.dump(memory_data, f)
                faiss.write_index(memory_index, INDEX_FILE)
                # NOTE(review): memory_data/memory_index are mutated by other
                # threads without a lock, so a flush can interleave with a
                # concurrent embed_and_store — confirm whether a Lock is needed.
                logging.info("Flushed %d entries to disk.", len(batch))
            except Exception as e:
                logging.error("Flush error: %s", e)


# Start flush thread (daemon: must not block interpreter exit).
t = threading.Thread(target=_flush_worker, daemon=True)
t.start()
def get_embedding(text: str) -> np.ndarray:
    """Encode *text* with the sentence-transformer model, logging wall time.

    Returns the model's embedding vector (a 1-D float array).
    """
    start = perf_counter()
    vec = embedding_model.encode(text)
    elapsed = perf_counter() - start
    # Lazy %-args: the message is only rendered when INFO is enabled
    # (avoids f-string work per call on this hot path). Output unchanged.
    logging.info("get_embedding: %.3fs for '%s...'", elapsed, text[:20])
    return vec
def embed_and_store(text: str, agent: str = "system", topic: str = ""):
    """Embed *text*, add it to the FAISS index, record metadata, and queue
    an async disk flush.

    Errors are logged rather than raised so a single bad entry cannot take
    down the caller.
    """
    try:
        vec = get_embedding(text)
        # FAISS expects a 2-D float32 batch; wrap the single vector.
        memory_index.add(np.asarray([vec], dtype="float32"))
        memory_data.append({
            "text": text,
            "agent": agent,
            "topic": topic,
            # NOTE(review): naive local time — consider an aware UTC
            # timestamp (datetime.now(timezone.utc)); confirm with callers.
            "timestamp": datetime.now().isoformat(),
        })
        # Token value is irrelevant; the flush worker only counts pending tokens.
        _write_queue.put(True)
        logging.info("Queued memory: %s / '%s...'", agent, text[:20])
    except Exception as e:
        logging.error("embed_and_store error: %s", e)
def retrieve_relevant(query: str, k: int = 5) -> list:
    """Return up to *k* stored memory entries most similar to *query*.

    Each result is a copy of the stored entry dict with an added
    'similarity' key (1 - L2 distance; can be negative for distant
    matches). Returns an empty list on any error.
    """
    try:
        q_vec = get_embedding(query)
        D, I = memory_index.search(np.asarray([q_vec], dtype="float32"), k)
        results = []
        for dist, idx in zip(D[0], I[0]):
            # BUG FIX: FAISS pads results with idx == -1 when the index holds
            # fewer than k vectors; the old `idx < len(memory_data)` check let
            # -1 through and negative indexing silently returned the LAST
            # entry. Guard both bounds.
            if 0 <= idx < len(memory_data):
                entry_copy = memory_data[idx].copy()  # don't mutate the store
                entry_copy['similarity'] = 1 - dist
                results.append(entry_copy)
        return results
    except Exception as e:
        logging.error("retrieve_relevant error: %s", e)
        return []