# Provenance: uploaded by chmielvu — "feat: add production refinements (Phase 1-3)", commit 4454066 (verified)
"""
FAISS-based semantic memory for task history and retrieval.
Features:
- Cosine similarity search via FAISS IndexFlatIP
- Top-k retrieval with configurable threshold
- Persistent save/load to disk
- Deduplication via embedding similarity
- Metadata storage alongside vectors
"""
try:
from .embeddings import EmbeddingModel
except ImportError:
# Fallback for direct module loading
from embeddings import EmbeddingModel
import faiss
import numpy as np
from typing import List, Dict, Any, Optional
import pickle
import os
import logging
logger = logging.getLogger(__name__)
class AgentMemory:
    """
    FAISS-based semantic memory for task history and retrieval.

    Features:
    - Cosine similarity search (IndexFlatIP over L2-normalized embeddings)
    - Top-k retrieval with a configurable similarity threshold
    - Persistent save/load (``<path>.index`` + ``<path>.meta`` files)
    - Deduplication via embedding similarity

    Example:
        >>> embedder = EmbeddingModel()
        >>> memory = AgentMemory(embedder=embedder)
        >>> memory.add("Calculate 2+2", "4")
        >>> results = memory.search("What is 2+2?")
        >>> print(results[0]["result"])  # "4"
    """

    def __init__(
        self,
        dimension: int = 384,
        k: int = 5,
        similarity_threshold: float = 0.75,
        dedup_threshold: float = 0.95,
        embedder: Optional[EmbeddingModel] = None
    ):
        """
        Initialize FAISS memory system.

        Args:
            dimension: Embedding dimension (default 384 for MiniLM-L6-v2).
                Ignored when ``embedder`` is None (the embedder's own
                dimension is used instead).
            k: Default number of results to retrieve.
            similarity_threshold: Minimum cosine similarity for retrieval (0.0-1.0).
            dedup_threshold: Cosine similarity threshold for deduplication (0.0-1.0).
            embedder: Optional EmbeddingModel instance (creates new if None).
        """
        self.k = k
        self.similarity_threshold = similarity_threshold
        self.dedup_threshold = dedup_threshold
        # Initialize embedder
        if embedder is None:
            logger.info("Creating new EmbeddingModel for memory")
            self.embedder = EmbeddingModel()
            self.dimension = self.embedder.dimension
        else:
            self.embedder = embedder
            # A mismatch here would make every index.add() fail later with an
            # opaque FAISS assertion, so surface it early.
            embedder_dim = getattr(embedder, "dimension", None)
            if embedder_dim is not None and embedder_dim != dimension:
                logger.warning(
                    "dimension argument (%d) does not match embedder dimension (%d); "
                    "using the dimension argument as before",
                    dimension, embedder_dim
                )
            self.dimension = dimension
        # Initialize FAISS index (cosine similarity via inner product).
        # Note: sentence-transformers already normalizes embeddings, so
        # IndexFlatIP == cosine similarity.
        self.index = faiss.IndexFlatIP(self.dimension)
        # Parallel list: metadata[i] describes the vector at FAISS row i.
        self.metadata: List[Dict[str, Any]] = []
        logger.info(
            f"AgentMemory initialized (dim={self.dimension}, k={k}, "
            f"threshold={similarity_threshold}, dedup={dedup_threshold})"
        )

    def add(
        self,
        task: str,
        result: Any,
        metadata: Optional[Dict[str, Any]] = None,
        skip_dedup: bool = False
    ) -> bool:
        """
        Add task-result pair to memory with deduplication.

        Args:
            task: Task description or query.
            result: Task result or answer (stringified, truncated to 1000 chars).
            metadata: Optional metadata dict (e.g., execution_time, tokens).
            skip_dedup: Skip deduplication check (for bulk loading).

        Returns:
            True if added, False if duplicate.

        Example:
            >>> memory.add(
            ...     task="Calculate 2+2",
            ...     result="4",
            ...     metadata={"execution_time": 0.5}
            ... )
            True
        """
        # Check for near-duplicates
        if not skip_dedup and self.is_duplicate(task):
            logger.debug(f"Skipping duplicate task: {task[:50]}...")
            return False
        # Embed task (already normalized by sentence-transformers)
        embedding = self.embedder.embed_single(task)
        # Ensure normalization (defensive - sentence-transformers should already normalize)
        norm = np.linalg.norm(embedding)
        if norm > 0:
            embedding = embedding / norm
        # Add to FAISS
        self.index.add(np.array([embedding], dtype=np.float32))
        # Store metadata at the same position as the vector's FAISS row.
        meta = {
            "task": task,
            "result": str(result)[:1000],  # Truncate to prevent memory bloat
            "metadata": metadata or {}
        }
        self.metadata.append(meta)
        logger.debug(f"Added to memory: {task[:50]}... (total: {self.index.ntotal})")
        return True

    def search(
        self,
        query: str,
        k: Optional[int] = None,
        threshold: Optional[float] = None
    ) -> List[Dict[str, Any]]:
        """
        Search memory for similar tasks.

        Args:
            query: Search query.
            k: Number of results (default: self.k). An explicit 0 returns [].
            threshold: Similarity threshold (default: self.similarity_threshold).
                An explicit 0.0 is honored (matches everything).

        Returns:
            List of dicts with keys: task, result, similarity, metadata,
            ordered by decreasing similarity.

        Example:
            >>> results = memory.search("What is 2+2?", k=3)
            >>> for r in results:
            ...     print(f"{r['task']} -> {r['result']} (sim={r['similarity']:.2f})")
        """
        if self.index.ntotal == 0:
            logger.debug("Memory is empty, no results")
            return []
        # Use `is None` rather than `or`: a caller-supplied 0 / 0.0 is a
        # legitimate value and must not silently fall back to the default.
        k = self.k if k is None else k
        threshold = self.similarity_threshold if threshold is None else threshold
        k = min(k, self.index.ntotal)
        if k <= 0:
            return []
        # Embed and normalize query
        query_embedding = self.embedder.embed_single(query)
        norm = np.linalg.norm(query_embedding)
        if norm > 0:
            query_embedding = query_embedding / norm
        # Search FAISS
        similarities, indices = self.index.search(
            np.array([query_embedding], dtype=np.float32),
            k
        )
        # Filter by threshold and format results
        results = []
        for similarity, idx in zip(similarities[0], indices[0]):
            # FAISS pads missing results with index -1; skip them defensively.
            if idx < 0:
                continue
            if similarity >= threshold:
                meta = self.metadata[idx]
                results.append({
                    "task": meta["task"],
                    "result": meta["result"],
                    "similarity": float(similarity),
                    "metadata": meta.get("metadata", {})
                })
        logger.debug(f"Memory search: {len(results)}/{k} results above threshold {threshold}")
        return results

    def is_duplicate(
        self,
        task: str,
        threshold: Optional[float] = None
    ) -> bool:
        """
        Check if task is very similar to existing tasks.

        Args:
            task: Task to check.
            threshold: Similarity threshold (default: self.dedup_threshold).
                An explicit 0.0 is honored.

        Returns:
            True if duplicate found, False otherwise.

        Example:
            >>> memory.add("Calculate 2+2", "4")
            >>> memory.is_duplicate("Calculate 2+2")  # True
            >>> memory.is_duplicate("What is 3+3?")   # False
        """
        if self.index.ntotal == 0:
            return False
        # `is None`, not `or`: see search() for rationale.
        threshold = self.dedup_threshold if threshold is None else threshold
        # Search with k=1; search() already filters by threshold, the extra
        # comparison below is a belt-and-suspenders check.
        similar = self.search(task, k=1, threshold=threshold)
        is_dup = len(similar) > 0 and similar[0]["similarity"] >= threshold
        if is_dup:
            logger.debug(f"Duplicate detected (sim={similar[0]['similarity']:.3f}): {task[:50]}...")
        return is_dup

    def clear(self):
        """
        Clear all memory (both the FAISS index and the metadata list).

        Example:
            >>> memory.clear()
            >>> memory.get_stats()["total_items"]
            0
        """
        self.index.reset()
        self.metadata.clear()
        logger.info("Memory cleared")

    def save(self, path: str):
        """
        Save memory to disk.

        Args:
            path: Base path (will create .index and .meta files).

        Example:
            >>> memory.save("cache/agent_memory")
            # Creates: cache/agent_memory.index, cache/agent_memory.meta
        """
        # Create directory if needed
        os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
        # Save FAISS index
        faiss.write_index(self.index, f"{path}.index")
        # Save metadata (plus config so load() can restore thresholds)
        with open(f"{path}.meta", "wb") as f:
            pickle.dump({
                "metadata": self.metadata,
                "config": {
                    "dimension": self.dimension,
                    "k": self.k,
                    "similarity_threshold": self.similarity_threshold,
                    "dedup_threshold": self.dedup_threshold
                }
            }, f)
        logger.info(f"Memory saved to {path} ({self.index.ntotal} items)")

    def load(self, path: str):
        """
        Load memory from disk.

        Missing files are logged and skipped (no exception); corrupt files
        raise after logging.

        SECURITY NOTE(review): the .meta file is unpickled — pickle can
        execute arbitrary code, so only load files this process wrote.
        Consider JSON for the metadata if the path is user-controlled.

        Args:
            path: Base path (will load .index and .meta files).

        Example:
            >>> memory = AgentMemory()
            >>> memory.load("cache/agent_memory")
        """
        if not os.path.exists(f"{path}.index"):
            logger.warning(f"Memory file not found: {path}.index")
            return
        if not os.path.exists(f"{path}.meta"):
            logger.warning(f"Metadata file not found: {path}.meta")
            return
        try:
            # Load FAISS index
            self.index = faiss.read_index(f"{path}.index")
            # Load metadata
            with open(f"{path}.meta", "rb") as f:
                data = pickle.load(f)
            self.metadata = data["metadata"]
            # Load config if available (older saves may lack it)
            if "config" in data:
                config = data["config"]
                self.dimension = config.get("dimension", self.dimension)
                self.k = config.get("k", self.k)
                self.similarity_threshold = config.get("similarity_threshold", self.similarity_threshold)
                self.dedup_threshold = config.get("dedup_threshold", self.dedup_threshold)
            logger.info(f"Memory loaded from {path} ({self.index.ntotal} items)")
        except Exception as e:
            logger.error(f"Failed to load memory from {path}: {e}", exc_info=True)
            raise

    def get_stats(self) -> Dict[str, Any]:
        """
        Get memory statistics.

        Returns:
            Dict with keys: total_items, dimension, k,
            similarity_threshold, dedup_threshold.

        Example:
            >>> stats = memory.get_stats()
            >>> print(f"Total items: {stats['total_items']}")
        """
        return {
            "total_items": self.index.ntotal,
            "dimension": self.dimension,
            "k": self.k,
            "similarity_threshold": self.similarity_threshold,
            "dedup_threshold": self.dedup_threshold
        }

    def get_all_tasks(self) -> List[str]:
        """
        Get all task strings in memory, in insertion order.

        Returns:
            List of task strings.

        Example:
            >>> tasks = memory.get_all_tasks()
            >>> print(f"Memory contains {len(tasks)} tasks")
        """
        return [meta["task"] for meta in self.metadata]

    def __len__(self) -> int:
        """Get number of items in memory."""
        return self.index.ntotal

    def __repr__(self) -> str:
        """String representation."""
        return f"AgentMemory(items={self.index.ntotal}, dim={self.dimension}, k={self.k})"