"""Long-term memory with persistent vector storage using ChromaDB.""" from __future__ import annotations import asyncio import hashlib import logging from datetime import datetime, timezone from typing import Any from uuid import uuid4 from pydantic import BaseModel, Field logger = logging.getLogger(__name__) class Document(BaseModel): """A document stored in long-term memory.""" id: str = Field(default_factory=lambda: str(uuid4())) content: str embedding: list[float] | None = None metadata: dict[str, Any] = Field(default_factory=dict) created_at: datetime = Field(default_factory=datetime.utcnow) updated_at: datetime = Field(default_factory=datetime.utcnow) model_config = {"arbitrary_types_allowed": True} class SearchResult(BaseModel): """A search result from long-term memory.""" document: Document score: float distance: float | None = None model_config = {"arbitrary_types_allowed": True} class LongTermMemory: """ Long-term persistent memory using ChromaDB for vector storage. This memory layer provides semantic search capabilities using embeddings. It persists across episodes and sessions, storing knowledge that should be retained long-term. Attributes: collection_name: Name of the ChromaDB collection. persist_directory: Directory for persistent storage. top_k: Default number of results to return from search. """ def __init__( self, collection_name: str = "scraperl_memory", persist_directory: str = "./data/chroma", top_k: int = 10, embedding_function: Any | None = None, ) -> None: """ Initialize long-term memory. Args: collection_name: Name of the ChromaDB collection. persist_directory: Directory for persistent storage. top_k: Default number of results to return from search. embedding_function: Optional custom embedding function. """ self.collection_name = collection_name self.persist_directory = persist_directory self.top_k = top_k self._embedding_function = embedding_function self._client: Any = None self._collection: Any = None self._initialized = False self._lock = asyncio.Lock() async def initialize(self) -> None: """ Initialize ChromaDB client and collection. This should be called before using other methods. """ if self._initialized: return async with self._lock: if self._initialized: return try: import chromadb from chromadb.config import Settings # Create persistent client self._client = chromadb.Client( Settings( chroma_db_impl="duckdb+parquet", persist_directory=self.persist_directory, anonymized_telemetry=False, ) ) # Get or create collection self._collection = self._client.get_or_create_collection( name=self.collection_name, embedding_function=self._embedding_function, metadata={"hnsw:space": "cosine"}, ) self._initialized = True logger.info( f"Initialized long-term memory: collection={self.collection_name}" ) except ImportError: logger.warning( "ChromaDB not available. Long-term memory will use in-memory fallback." ) self._use_fallback() except Exception as e: logger.warning( f"Failed to initialize ChromaDB: {e}. Using in-memory fallback." ) self._use_fallback() def _use_fallback(self) -> None: """Use in-memory fallback when ChromaDB is unavailable.""" self._client = None self._collection = None self._fallback_store: dict[str, Document] = {} self._initialized = True @property def is_initialized(self) -> bool: """Check if memory is initialized.""" return self._initialized @property def _using_fallback(self) -> bool: """Check if using in-memory fallback.""" return self._collection is None def _generate_id(self, content: str) -> str: """Generate a deterministic ID from content.""" return hashlib.sha256(content.encode()).hexdigest()[:16] async def store( self, content: str, document_id: str | None = None, metadata: dict[str, Any] | None = None, embedding: list[float] | None = None, ) -> Document: """ Store a document in long-term memory. Args: content: Text content to store. document_id: Optional custom ID. Generated from content if not provided. metadata: Optional metadata dictionary. embedding: Optional pre-computed embedding vector. Returns: The stored document. """ if not self._initialized: await self.initialize() async with self._lock: doc_id = document_id or self._generate_id(content) now = datetime.now(timezone.utc) document = Document( id=doc_id, content=content, embedding=embedding, metadata=metadata or {}, created_at=now, updated_at=now, ) if self._using_fallback: self._fallback_store[doc_id] = document else: # Store in ChromaDB try: self._collection.upsert( ids=[doc_id], documents=[content], metadatas=[ { **document.metadata, "created_at": now.isoformat(), "updated_at": now.isoformat(), } ], embeddings=[embedding] if embedding else None, ) except Exception as e: logger.error(f"Failed to store document: {e}") raise return document async def search( self, query: str, top_k: int | None = None, where: dict[str, Any] | None = None, query_embedding: list[float] | None = None, ) -> list[SearchResult]: """ Search for similar documents using semantic search. Args: query: Search query text. top_k: Number of results to return. Uses default if not specified. where: Optional metadata filter. query_embedding: Optional pre-computed query embedding. Returns: List of search results with scores. """ if not self._initialized: await self.initialize() k = top_k or self.top_k async with self._lock: if self._using_fallback: # Simple substring matching for fallback results = [] query_lower = query.lower() for doc in self._fallback_store.values(): if query_lower in doc.content.lower(): results.append( SearchResult(document=doc, score=1.0, distance=0.0) ) return results[:k] try: # Query ChromaDB query_params: dict[str, Any] = { "n_results": k, } if query_embedding: query_params["query_embeddings"] = [query_embedding] else: query_params["query_texts"] = [query] if where: query_params["where"] = where results = self._collection.query(**query_params) # Parse results search_results = [] if results and results.get("ids"): for i, doc_id in enumerate(results["ids"][0]): content = ( results["documents"][0][i] if results.get("documents") else "" ) metadata = ( results["metadatas"][0][i] if results.get("metadatas") else {} ) distance = ( results["distances"][0][i] if results.get("distances") else None ) doc = Document( id=doc_id, content=content, metadata=metadata, ) # Convert distance to score (cosine similarity) score = 1 - distance if distance is not None else 1.0 search_results.append( SearchResult( document=doc, score=score, distance=distance, ) ) return search_results except Exception as e: logger.error(f"Search failed: {e}") return [] async def get(self, document_id: str) -> Document | None: """ Retrieve a document by ID. Args: document_id: The document ID to retrieve. Returns: The document or None if not found. """ if not self._initialized: await self.initialize() async with self._lock: if self._using_fallback: return self._fallback_store.get(document_id) try: result = self._collection.get(ids=[document_id]) if result and result["ids"]: return Document( id=result["ids"][0], content=result["documents"][0] if result.get("documents") else "", metadata=result["metadatas"][0] if result.get("metadatas") else {}, ) return None except Exception as e: logger.error(f"Failed to get document: {e}") return None async def delete(self, document_id: str) -> bool: """ Delete a document from long-term memory. Args: document_id: The document ID to delete. Returns: True if document was deleted, False otherwise. """ if not self._initialized: await self.initialize() async with self._lock: if self._using_fallback: if document_id in self._fallback_store: del self._fallback_store[document_id] return True return False try: self._collection.delete(ids=[document_id]) return True except Exception as e: logger.error(f"Failed to delete document: {e}") return False async def delete_where(self, where: dict[str, Any]) -> int: """ Delete documents matching a metadata filter. Args: where: Metadata filter for documents to delete. Returns: Number of documents deleted. """ if not self._initialized: await self.initialize() async with self._lock: if self._using_fallback: to_delete = [] for doc_id, doc in self._fallback_store.items(): if all(doc.metadata.get(k) == v for k, v in where.items()): to_delete.append(doc_id) for doc_id in to_delete: del self._fallback_store[doc_id] return len(to_delete) try: # Get matching IDs first result = self._collection.get(where=where) if result and result["ids"]: self._collection.delete(ids=result["ids"]) return len(result["ids"]) return 0 except Exception as e: logger.error(f"Failed to delete documents: {e}") return 0 async def count(self) -> int: """ Get the total number of documents stored. Returns: Document count. """ if not self._initialized: await self.initialize() async with self._lock: if self._using_fallback: return len(self._fallback_store) try: return self._collection.count() except Exception as e: logger.error(f"Failed to count documents: {e}") return 0 async def clear(self) -> int: """ Clear all documents from memory. Returns: Number of documents that were cleared. """ if not self._initialized: await self.initialize() async with self._lock: if self._using_fallback: count = len(self._fallback_store) self._fallback_store.clear() return count try: count = self._collection.count() # Delete and recreate collection self._client.delete_collection(self.collection_name) self._collection = self._client.create_collection( name=self.collection_name, embedding_function=self._embedding_function, metadata={"hnsw:space": "cosine"}, ) return count except Exception as e: logger.error(f"Failed to clear memory: {e}") return 0 async def persist(self) -> None: """Persist changes to disk.""" if self._client and hasattr(self._client, "persist"): try: self._client.persist() except Exception as e: logger.error(f"Failed to persist memory: {e}") async def shutdown(self) -> None: """Shutdown long-term memory and persist data.""" if self._initialized and not self._using_fallback: await self.persist() self._initialized = False logger.info("Long-term memory shutdown complete") async def get_stats(self) -> dict[str, Any]: """ Get statistics about long-term memory. Returns: Dictionary with memory statistics. """ count = await self.count() return { "initialized": self._initialized, "using_fallback": self._using_fallback, "collection_name": self.collection_name, "persist_directory": self.persist_directory, "document_count": count, "top_k": self.top_k, }