""" FAISS-based semantic memory for task history and retrieval. Features: - Cosine similarity search via FAISS IndexFlatIP - Top-k retrieval with configurable threshold - Persistent save/load to disk - Deduplication via embedding similarity - Metadata storage alongside vectors """ try: from .embeddings import EmbeddingModel except ImportError: # Fallback for direct module loading from embeddings import EmbeddingModel import faiss import numpy as np from typing import List, Dict, Any, Optional import pickle import os import logging logger = logging.getLogger(__name__) class AgentMemory: """ FAISS-based semantic memory for task history and retrieval. Features: - Cosine similarity search (IndexFlatIP) - Top-k retrieval with threshold - Persistent save/load - Deduplication via embedding similarity Example: >>> embedder = EmbeddingModel() >>> memory = AgentMemory(embedder=embedder) >>> memory.add("Calculate 2+2", "4") >>> results = memory.search("What is 2+2?") >>> print(results[0]["result"]) # "4" """ def __init__( self, dimension: int = 384, k: int = 5, similarity_threshold: float = 0.75, dedup_threshold: float = 0.95, embedder: Optional[EmbeddingModel] = None ): """ Initialize FAISS memory system. Args: dimension: Embedding dimension (default 384 for MiniLM-L6-v2) k: Number of results to retrieve similarity_threshold: Minimum cosine similarity for retrieval (0.0-1.0) dedup_threshold: Cosine similarity threshold for deduplication (0.0-1.0) embedder: Optional EmbeddingModel instance (creates new if None) """ self.k = k self.similarity_threshold = similarity_threshold self.dedup_threshold = dedup_threshold # Initialize embedder if embedder is None: logger.info("Creating new EmbeddingModel for memory") self.embedder = EmbeddingModel() self.dimension = self.embedder.dimension else: self.embedder = embedder self.dimension = dimension # Initialize FAISS index (cosine similarity via inner product) # Note: sentence-transformers already normalizes embeddings, so IndexFlatIP = cosine similarity self.index = faiss.IndexFlatIP(self.dimension) # Store metadata for each vector self.metadata: List[Dict[str, Any]] = [] logger.info( f"AgentMemory initialized (dim={self.dimension}, k={k}, " f"threshold={similarity_threshold}, dedup={dedup_threshold})" ) def add( self, task: str, result: Any, metadata: Optional[Dict[str, Any]] = None, skip_dedup: bool = False ) -> bool: """ Add task-result pair to memory with deduplication. Args: task: Task description or query result: Task result or answer metadata: Optional metadata dict (e.g., execution_time, tokens) skip_dedup: Skip deduplication check (for bulk loading) Returns: True if added, False if duplicate Example: >>> memory.add( ... task="Calculate 2+2", ... result="4", ... metadata={"execution_time": 0.5} ... ) True """ # Check for near-duplicates if not skip_dedup and self.is_duplicate(task): logger.debug(f"Skipping duplicate task: {task[:50]}...") return False # Embed task (already normalized by sentence-transformers) embedding = self.embedder.embed_single(task) # Ensure normalization (defensive - sentence-transformers should already normalize) norm = np.linalg.norm(embedding) if norm > 0: embedding = embedding / norm # Add to FAISS self.index.add(np.array([embedding], dtype=np.float32)) # Store metadata meta = { "task": task, "result": str(result)[:1000], # Truncate to prevent memory bloat "metadata": metadata or {} } self.metadata.append(meta) logger.debug(f"Added to memory: {task[:50]}... (total: {self.index.ntotal})") return True def search( self, query: str, k: Optional[int] = None, threshold: Optional[float] = None ) -> List[Dict[str, Any]]: """ Search memory for similar tasks. Args: query: Search query k: Number of results (default: self.k) threshold: Similarity threshold (default: self.similarity_threshold) Returns: List of dicts with keys: task, result, similarity, metadata Example: >>> results = memory.search("What is 2+2?", k=3) >>> for r in results: ... print(f"{r['task']} -> {r['result']} (sim={r['similarity']:.2f})") """ if self.index.ntotal == 0: logger.debug("Memory is empty, no results") return [] k = k or self.k threshold = threshold or self.similarity_threshold k = min(k, self.index.ntotal) # Embed and normalize query query_embedding = self.embedder.embed_single(query) norm = np.linalg.norm(query_embedding) if norm > 0: query_embedding = query_embedding / norm # Search FAISS similarities, indices = self.index.search( np.array([query_embedding], dtype=np.float32), k ) # Filter by threshold and format results results = [] for similarity, idx in zip(similarities[0], indices[0]): if similarity >= threshold: meta = self.metadata[idx] results.append({ "task": meta["task"], "result": meta["result"], "similarity": float(similarity), "metadata": meta.get("metadata", {}) }) logger.debug(f"Memory search: {len(results)}/{k} results above threshold {threshold}") return results def is_duplicate( self, task: str, threshold: Optional[float] = None ) -> bool: """ Check if task is very similar to existing tasks. Args: task: Task to check threshold: Similarity threshold (default: self.dedup_threshold) Returns: True if duplicate found, False otherwise Example: >>> memory.add("Calculate 2+2", "4") >>> memory.is_duplicate("Calculate 2+2") # True >>> memory.is_duplicate("What is 3+3?") # False """ if self.index.ntotal == 0: return False threshold = threshold or self.dedup_threshold # Search with k=1 similar = self.search(task, k=1, threshold=threshold) is_dup = len(similar) > 0 and similar[0]["similarity"] >= threshold if is_dup: logger.debug(f"Duplicate detected (sim={similar[0]['similarity']:.3f}): {task[:50]}...") return is_dup def clear(self): """ Clear all memory. Example: >>> memory.clear() >>> memory.get_stats()["total_items"] 0 """ self.index.reset() self.metadata.clear() logger.info("Memory cleared") def save(self, path: str): """ Save memory to disk. Args: path: Base path (will create .index and .meta files) Example: >>> memory.save("cache/agent_memory") # Creates: cache/agent_memory.index, cache/agent_memory.meta """ # Create directory if needed os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True) # Save FAISS index faiss.write_index(self.index, f"{path}.index") # Save metadata with open(f"{path}.meta", "wb") as f: pickle.dump({ "metadata": self.metadata, "config": { "dimension": self.dimension, "k": self.k, "similarity_threshold": self.similarity_threshold, "dedup_threshold": self.dedup_threshold } }, f) logger.info(f"Memory saved to {path} ({self.index.ntotal} items)") def load(self, path: str): """ Load memory from disk. Args: path: Base path (will load .index and .meta files) Example: >>> memory = AgentMemory() >>> memory.load("cache/agent_memory") """ if not os.path.exists(f"{path}.index"): logger.warning(f"Memory file not found: {path}.index") return if not os.path.exists(f"{path}.meta"): logger.warning(f"Metadata file not found: {path}.meta") return try: # Load FAISS index self.index = faiss.read_index(f"{path}.index") # Load metadata with open(f"{path}.meta", "rb") as f: data = pickle.load(f) self.metadata = data["metadata"] # Load config if available if "config" in data: config = data["config"] self.dimension = config.get("dimension", self.dimension) self.k = config.get("k", self.k) self.similarity_threshold = config.get("similarity_threshold", self.similarity_threshold) self.dedup_threshold = config.get("dedup_threshold", self.dedup_threshold) logger.info(f"Memory loaded from {path} ({self.index.ntotal} items)") except Exception as e: logger.error(f"Failed to load memory from {path}: {e}", exc_info=True) raise def get_stats(self) -> Dict[str, Any]: """ Get memory statistics. Returns: Dict with memory stats Example: >>> stats = memory.get_stats() >>> print(f"Total items: {stats['total_items']}") """ return { "total_items": self.index.ntotal, "dimension": self.dimension, "k": self.k, "similarity_threshold": self.similarity_threshold, "dedup_threshold": self.dedup_threshold } def get_all_tasks(self) -> List[str]: """ Get all task strings in memory. Returns: List of task strings Example: >>> tasks = memory.get_all_tasks() >>> print(f"Memory contains {len(tasks)} tasks") """ return [meta["task"] for meta in self.metadata] def __len__(self) -> int: """Get number of items in memory.""" return self.index.ntotal def __repr__(self) -> str: """String representation.""" return f"AgentMemory(items={self.index.ntotal}, dim={self.dimension}, k={self.k})"