Spaces:
Sleeping
Sleeping
| """ | |
| FAISS-based semantic memory for task history and retrieval. | |
| Features: | |
| - Cosine similarity search via FAISS IndexFlatIP | |
| - Top-k retrieval with configurable threshold | |
| - Persistent save/load to disk | |
| - Deduplication via embedding similarity | |
| - Metadata storage alongside vectors | |
| """ | |
| try: | |
| from .embeddings import EmbeddingModel | |
| except ImportError: | |
| # Fallback for direct module loading | |
| from embeddings import EmbeddingModel | |
| import faiss | |
| import numpy as np | |
| from typing import List, Dict, Any, Optional | |
| import pickle | |
| import os | |
| import logging | |
| logger = logging.getLogger(__name__) | |
class AgentMemory:
    """
    FAISS-based semantic memory for task history and retrieval.

    Task texts are embedded, normalized to unit length, and stored in a
    ``faiss.IndexFlatIP`` index, so inner-product search is equivalent to
    cosine similarity. A parallel ``metadata`` list holds the original
    task text, the (truncated) result string, and caller-supplied metadata
    for each vector row.

    Features:
    - Cosine similarity search (IndexFlatIP)
    - Top-k retrieval with threshold
    - Persistent save/load
    - Deduplication via embedding similarity

    Example:
        >>> embedder = EmbeddingModel()
        >>> memory = AgentMemory(embedder=embedder)
        >>> memory.add("Calculate 2+2", "4")
        >>> results = memory.search("What is 2+2?")
        >>> print(results[0]["result"])  # "4"
    """
| def __init__( | |
| self, | |
| dimension: int = 384, | |
| k: int = 5, | |
| similarity_threshold: float = 0.75, | |
| dedup_threshold: float = 0.95, | |
| embedder: Optional[EmbeddingModel] = None | |
| ): | |
| """ | |
| Initialize FAISS memory system. | |
| Args: | |
| dimension: Embedding dimension (default 384 for MiniLM-L6-v2) | |
| k: Number of results to retrieve | |
| similarity_threshold: Minimum cosine similarity for retrieval (0.0-1.0) | |
| dedup_threshold: Cosine similarity threshold for deduplication (0.0-1.0) | |
| embedder: Optional EmbeddingModel instance (creates new if None) | |
| """ | |
| self.k = k | |
| self.similarity_threshold = similarity_threshold | |
| self.dedup_threshold = dedup_threshold | |
| # Initialize embedder | |
| if embedder is None: | |
| logger.info("Creating new EmbeddingModel for memory") | |
| self.embedder = EmbeddingModel() | |
| self.dimension = self.embedder.dimension | |
| else: | |
| self.embedder = embedder | |
| self.dimension = dimension | |
| # Initialize FAISS index (cosine similarity via inner product) | |
| # Note: sentence-transformers already normalizes embeddings, so IndexFlatIP = cosine similarity | |
| self.index = faiss.IndexFlatIP(self.dimension) | |
| # Store metadata for each vector | |
| self.metadata: List[Dict[str, Any]] = [] | |
| logger.info( | |
| f"AgentMemory initialized (dim={self.dimension}, k={k}, " | |
| f"threshold={similarity_threshold}, dedup={dedup_threshold})" | |
| ) | |
| def add( | |
| self, | |
| task: str, | |
| result: Any, | |
| metadata: Optional[Dict[str, Any]] = None, | |
| skip_dedup: bool = False | |
| ) -> bool: | |
| """ | |
| Add task-result pair to memory with deduplication. | |
| Args: | |
| task: Task description or query | |
| result: Task result or answer | |
| metadata: Optional metadata dict (e.g., execution_time, tokens) | |
| skip_dedup: Skip deduplication check (for bulk loading) | |
| Returns: | |
| True if added, False if duplicate | |
| Example: | |
| >>> memory.add( | |
| ... task="Calculate 2+2", | |
| ... result="4", | |
| ... metadata={"execution_time": 0.5} | |
| ... ) | |
| True | |
| """ | |
| # Check for near-duplicates | |
| if not skip_dedup and self.is_duplicate(task): | |
| logger.debug(f"Skipping duplicate task: {task[:50]}...") | |
| return False | |
| # Embed task (already normalized by sentence-transformers) | |
| embedding = self.embedder.embed_single(task) | |
| # Ensure normalization (defensive - sentence-transformers should already normalize) | |
| norm = np.linalg.norm(embedding) | |
| if norm > 0: | |
| embedding = embedding / norm | |
| # Add to FAISS | |
| self.index.add(np.array([embedding], dtype=np.float32)) | |
| # Store metadata | |
| meta = { | |
| "task": task, | |
| "result": str(result)[:1000], # Truncate to prevent memory bloat | |
| "metadata": metadata or {} | |
| } | |
| self.metadata.append(meta) | |
| logger.debug(f"Added to memory: {task[:50]}... (total: {self.index.ntotal})") | |
| return True | |
| def search( | |
| self, | |
| query: str, | |
| k: Optional[int] = None, | |
| threshold: Optional[float] = None | |
| ) -> List[Dict[str, Any]]: | |
| """ | |
| Search memory for similar tasks. | |
| Args: | |
| query: Search query | |
| k: Number of results (default: self.k) | |
| threshold: Similarity threshold (default: self.similarity_threshold) | |
| Returns: | |
| List of dicts with keys: task, result, similarity, metadata | |
| Example: | |
| >>> results = memory.search("What is 2+2?", k=3) | |
| >>> for r in results: | |
| ... print(f"{r['task']} -> {r['result']} (sim={r['similarity']:.2f})") | |
| """ | |
| if self.index.ntotal == 0: | |
| logger.debug("Memory is empty, no results") | |
| return [] | |
| k = k or self.k | |
| threshold = threshold or self.similarity_threshold | |
| k = min(k, self.index.ntotal) | |
| # Embed and normalize query | |
| query_embedding = self.embedder.embed_single(query) | |
| norm = np.linalg.norm(query_embedding) | |
| if norm > 0: | |
| query_embedding = query_embedding / norm | |
| # Search FAISS | |
| similarities, indices = self.index.search( | |
| np.array([query_embedding], dtype=np.float32), | |
| k | |
| ) | |
| # Filter by threshold and format results | |
| results = [] | |
| for similarity, idx in zip(similarities[0], indices[0]): | |
| if similarity >= threshold: | |
| meta = self.metadata[idx] | |
| results.append({ | |
| "task": meta["task"], | |
| "result": meta["result"], | |
| "similarity": float(similarity), | |
| "metadata": meta.get("metadata", {}) | |
| }) | |
| logger.debug(f"Memory search: {len(results)}/{k} results above threshold {threshold}") | |
| return results | |
| def is_duplicate( | |
| self, | |
| task: str, | |
| threshold: Optional[float] = None | |
| ) -> bool: | |
| """ | |
| Check if task is very similar to existing tasks. | |
| Args: | |
| task: Task to check | |
| threshold: Similarity threshold (default: self.dedup_threshold) | |
| Returns: | |
| True if duplicate found, False otherwise | |
| Example: | |
| >>> memory.add("Calculate 2+2", "4") | |
| >>> memory.is_duplicate("Calculate 2+2") # True | |
| >>> memory.is_duplicate("What is 3+3?") # False | |
| """ | |
| if self.index.ntotal == 0: | |
| return False | |
| threshold = threshold or self.dedup_threshold | |
| # Search with k=1 | |
| similar = self.search(task, k=1, threshold=threshold) | |
| is_dup = len(similar) > 0 and similar[0]["similarity"] >= threshold | |
| if is_dup: | |
| logger.debug(f"Duplicate detected (sim={similar[0]['similarity']:.3f}): {task[:50]}...") | |
| return is_dup | |
| def clear(self): | |
| """ | |
| Clear all memory. | |
| Example: | |
| >>> memory.clear() | |
| >>> memory.get_stats()["total_items"] | |
| 0 | |
| """ | |
| self.index.reset() | |
| self.metadata.clear() | |
| logger.info("Memory cleared") | |
| def save(self, path: str): | |
| """ | |
| Save memory to disk. | |
| Args: | |
| path: Base path (will create .index and .meta files) | |
| Example: | |
| >>> memory.save("cache/agent_memory") | |
| # Creates: cache/agent_memory.index, cache/agent_memory.meta | |
| """ | |
| # Create directory if needed | |
| os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True) | |
| # Save FAISS index | |
| faiss.write_index(self.index, f"{path}.index") | |
| # Save metadata | |
| with open(f"{path}.meta", "wb") as f: | |
| pickle.dump({ | |
| "metadata": self.metadata, | |
| "config": { | |
| "dimension": self.dimension, | |
| "k": self.k, | |
| "similarity_threshold": self.similarity_threshold, | |
| "dedup_threshold": self.dedup_threshold | |
| } | |
| }, f) | |
| logger.info(f"Memory saved to {path} ({self.index.ntotal} items)") | |
| def load(self, path: str): | |
| """ | |
| Load memory from disk. | |
| Args: | |
| path: Base path (will load .index and .meta files) | |
| Example: | |
| >>> memory = AgentMemory() | |
| >>> memory.load("cache/agent_memory") | |
| """ | |
| if not os.path.exists(f"{path}.index"): | |
| logger.warning(f"Memory file not found: {path}.index") | |
| return | |
| if not os.path.exists(f"{path}.meta"): | |
| logger.warning(f"Metadata file not found: {path}.meta") | |
| return | |
| try: | |
| # Load FAISS index | |
| self.index = faiss.read_index(f"{path}.index") | |
| # Load metadata | |
| with open(f"{path}.meta", "rb") as f: | |
| data = pickle.load(f) | |
| self.metadata = data["metadata"] | |
| # Load config if available | |
| if "config" in data: | |
| config = data["config"] | |
| self.dimension = config.get("dimension", self.dimension) | |
| self.k = config.get("k", self.k) | |
| self.similarity_threshold = config.get("similarity_threshold", self.similarity_threshold) | |
| self.dedup_threshold = config.get("dedup_threshold", self.dedup_threshold) | |
| logger.info(f"Memory loaded from {path} ({self.index.ntotal} items)") | |
| except Exception as e: | |
| logger.error(f"Failed to load memory from {path}: {e}", exc_info=True) | |
| raise | |
| def get_stats(self) -> Dict[str, Any]: | |
| """ | |
| Get memory statistics. | |
| Returns: | |
| Dict with memory stats | |
| Example: | |
| >>> stats = memory.get_stats() | |
| >>> print(f"Total items: {stats['total_items']}") | |
| """ | |
| return { | |
| "total_items": self.index.ntotal, | |
| "dimension": self.dimension, | |
| "k": self.k, | |
| "similarity_threshold": self.similarity_threshold, | |
| "dedup_threshold": self.dedup_threshold | |
| } | |
| def get_all_tasks(self) -> List[str]: | |
| """ | |
| Get all task strings in memory. | |
| Returns: | |
| List of task strings | |
| Example: | |
| >>> tasks = memory.get_all_tasks() | |
| >>> print(f"Memory contains {len(tasks)} tasks") | |
| """ | |
| return [meta["task"] for meta in self.metadata] | |
| def __len__(self) -> int: | |
| """Get number of items in memory.""" | |
| return self.index.ntotal | |
| def __repr__(self) -> str: | |
| """String representation.""" | |
| return f"AgentMemory(items={self.index.ntotal}, dim={self.dimension}, k={self.k})" | |