Spaces:
Running
Running
| """ | |
| Semantic / Vector Memory — RAG Layer | |
| ===================================== | |
| Long-term knowledge stored in ChromaDB with sentence-transformer embeddings. | |
| Also persists each entry as a Markdown file under memory/vector/docs/*.md (the default md_dir) | |
| for human-readability and version control. | |
| This is the RAG backbone: | |
| • Add documents → embed + store | |
| • Query by natural language → cosine similarity search | |
| • Full CRUD with automatic re-embedding on update | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| from .models import MemoryEntry, MemoryTier, SearchResult | |
| logger = logging.getLogger(__name__) | |
| # ── optional heavy deps (graceful fallback) ────────────────── | |
| try: | |
| import chromadb | |
| from chromadb.config import Settings as ChromaSettings | |
| CHROMA_AVAILABLE = True | |
| except ImportError: | |
| CHROMA_AVAILABLE = False | |
| try: | |
| from sentence_transformers import SentenceTransformer | |
| ST_AVAILABLE = True | |
| except ImportError: | |
| ST_AVAILABLE = False | |
class _SentenceTransformerEmbedder:
    """Adapter that lets a sentence-transformers model act as a ChromaDB embedding function."""

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """Load the sentence-transformers model called *model_name*.

        Raises:
            ImportError: if the ``sentence_transformers`` package could not be
                imported when this module was loaded.
        """
        if not ST_AVAILABLE:
            raise ImportError("sentence-transformers is required for semantic memory")
        self.model_name = model_name
        self.model = SentenceTransformer(model_name)

    def __call__(self, input: List[str]) -> List[List[float]]:
        """Encode the given texts and return one embedding (list of floats) per text."""
        # NOTE: `input` shadows the builtin, but it is this callable's public
        # parameter name, so it is deliberately left unchanged.
        vectors = self.model.encode(input, show_progress_bar=False)
        return vectors.tolist()

    def name(self) -> str:
        """Return the identifier of this embedder (required by ChromaDB's protocol)."""
        return "sentence-transformers_{}".format(self.model_name)
class SemanticMemory:
    """ChromaDB-backed vector store with a Markdown file mirror.

    Every entry is written twice: into a ChromaDB collection (when the
    ``chromadb`` package could be imported) and as ``<md_dir>/<entry id>.md``.
    Without ChromaDB the instance runs in file-only mode: lookups read the
    Markdown files and search degrades to a case-insensitive substring match.
    """

    COLLECTION_NAME = "memory_semantic"
    DEFAULT_MODEL = "all-MiniLM-L6-v2"

    # Key under which the user-supplied ``MemoryEntry.metadata`` dict is kept
    # (JSON-encoded) inside the ChromaDB metadata of each record.
    _USER_METADATA_KEY = "metadata_json"

    def __init__(
        self,
        vector_dir: str = "memory/vector",
        md_dir: str = "memory/vector/docs",
        model_name: str = DEFAULT_MODEL,
        collection_name: str = COLLECTION_NAME,
    ):
        """Create the storage directories and open (or create) the collection.

        Args:
            vector_dir: Directory holding the persistent ChromaDB files
                (stored in its ``chroma_db`` sub-directory).
            md_dir: Directory receiving one Markdown file per entry.
            model_name: sentence-transformers model used for embeddings when
                that package is available.
            collection_name: Name of the ChromaDB collection.
        """
        self.vector_dir = Path(vector_dir)
        self.md_dir = Path(md_dir)
        self.vector_dir.mkdir(parents=True, exist_ok=True)
        self.md_dir.mkdir(parents=True, exist_ok=True)
        self.model_name = model_name
        self.collection_name = collection_name

        if not CHROMA_AVAILABLE:
            self._client = None
            self._collection = None
            self._embed_fn = None
            logger.warning("chromadb not installed — semantic memory operates in file-only mode")
            return

        self._client = chromadb.PersistentClient(
            path=str(self.vector_dir / "chroma_db"),
        )
        collection_kwargs: Dict[str, Any] = {
            "name": collection_name,
            "metadata": {"hnsw:space": "cosine"},
        }
        if ST_AVAILABLE:
            self._embed_fn = _SentenceTransformerEmbedder(model_name)
            collection_kwargs["embedding_function"] = self._embed_fn
        else:
            # No embedding_function passed: Chroma's built-in default embedder is used.
            self._embed_fn = None
        self._collection = self._client.get_or_create_collection(**collection_kwargs)
        logger.info(
            "SemanticMemory ready — ChromaDB @ %s | model=%s | docs=%d",
            self.vector_dir, model_name, self._collection.count(),
        )

    # ── CRUD ─────────────────────────────────────────────────

    def create(
        self,
        content: str,
        title: str = "",
        tags: Optional[List[str]] = None,
        importance: float = 0.5,
        metadata: Optional[Dict[str, Any]] = None,
        source: str = "",
    ) -> MemoryEntry:
        """Add a new document to the vector store and the Markdown mirror.

        Args:
            content: Document text.
            title: Optional title; defaults to the first 80 characters of *content*.
            tags: Optional list of tags.
            importance: Importance value stored with the entry.
            metadata: Optional free-form metadata dict.
            source: Optional origin description.

        Returns:
            The newly created entry.
        """
        # One timestamp for both fields so a fresh entry has created_at == updated_at.
        now = datetime.utcnow().isoformat()
        entry = MemoryEntry(
            content=content,
            title=title or content[:80],
            tier=MemoryTier.SEMANTIC,
            tags=tags or [],
            importance=importance,
            metadata=metadata or {},
            source=source,
            created_at=now,
            updated_at=now,
        )
        self._upsert_vector(entry)
        self._persist_md(entry)
        return entry

    def read(self, entry_id: str) -> Optional[MemoryEntry]:
        """Retrieve an entry by ID, or ``None`` when it does not exist.

        When the entry is found in the vector store this is not a pure read:
        ``access_count`` is incremented, ``updated_at`` is refreshed, and the
        entry is written back to both the vector store and the Markdown mirror.
        In file-only mode, or if the vector store raises, the Markdown file is
        read instead and nothing is written.
        """
        if self._collection is None:
            return self._read_from_md(entry_id)
        try:
            result = self._collection.get(ids=[entry_id], include=["documents", "metadatas"])
            if not result["ids"]:
                return None
            entry = self._result_to_entry(result, 0)
            entry.access_count += 1
            entry.updated_at = datetime.utcnow().isoformat()
            self._upsert_vector(entry)
            self._persist_md(entry)
            return entry
        except Exception as exc:
            logger.error("read failed: %s", exc)
            return self._read_from_md(entry_id)

    def update(self, entry_id: str, **kwargs) -> Optional[MemoryEntry]:
        """Update fields of an existing entry and store it again.

        Keyword arguments naming an attribute of the entry are assigned to it;
        ``id``, ``tier`` and ``created_at`` are protected and unknown names are
        ignored. The entry is then upserted with its (possibly new) content —
        presumably re-embedded by the collection — and the mirror is rewritten.

        Returns:
            The updated entry, or ``None`` when *entry_id* is unknown.
        """
        entry = self.read(entry_id)
        if not entry:
            return None
        protected = ("id", "tier", "created_at")
        for field_name, value in kwargs.items():
            if field_name not in protected and hasattr(entry, field_name):
                setattr(entry, field_name, value)
        entry.updated_at = datetime.utcnow().isoformat()
        self._upsert_vector(entry)
        self._persist_md(entry)
        return entry

    def delete(self, entry_id: str) -> bool:
        """Remove an entry from the vector store and from disk.

        Returns:
            ``True`` when the entry was removed from at least one of the two
            stores, ``False`` when it existed in neither.
        """
        removed = False
        if self._collection is not None:
            try:
                # The delete call does not report whether anything matched,
                # so look the ID up first.
                found = self._collection.get(ids=[entry_id], include=["metadatas"])
                existed = bool(found["ids"])
                self._collection.delete(ids=[entry_id])
                removed = existed
            except Exception as exc:
                # Best effort: still try to remove the Markdown mirror below.
                logger.warning("vector delete failed for %s: %s", entry_id, exc)
        md_path = self.md_dir / f"{entry_id}.md"
        try:
            md_path.unlink()
        except FileNotFoundError:
            pass
        else:
            removed = True
        return removed

    # ── search / RAG ─────────────────────────────────────────

    def search(
        self,
        query: str,
        limit: int = 5,
        where: Optional[Dict[str, Any]] = None,
    ) -> List[SearchResult]:
        """Semantic similarity search. This is the RAG retrieval endpoint.

        Args:
            query: Natural-language query text.
            limit: Maximum number of results; non-positive values yield ``[]``.
            where: Optional metadata filter passed to the collection query.
                It is ignored by the keyword fallback.

        Returns:
            Results in the order returned by the collection, each carrying a
            similarity ``score`` in ``[0, 1]`` and the raw ``distance``. In
            file-only mode, or if the vector query raises, keyword-match
            results over the Markdown mirror are returned instead.
        """
        if limit <= 0:
            return []
        if self._collection is None:
            return self._keyword_fallback(query, limit)

        try:
            total = self._collection.count()
            if total == 0:
                return []
            query_kwargs: Dict[str, Any] = {
                "query_texts": [query],
                "n_results": min(limit, total),
                "include": ["documents", "metadatas", "distances"],
            }
            if where:
                query_kwargs["where"] = where
            results = self._collection.query(**query_kwargs)
        except Exception as exc:
            logger.error("vector search failed: %s", exc)
            return self._keyword_fallback(query, limit)

        if not results or not results.get("ids") or not results["ids"][0]:
            return []

        distance_row = results["distances"][0] if results.get("distances") else None
        search_results: List[SearchResult] = []
        for idx in range(len(results["ids"][0])):
            entry = self._query_result_to_entry(results, idx)
            dist = distance_row[idx] if distance_row is not None else 0.0
            # cosine distance → similarity, clamped into [0, 1]
            score = min(1.0, max(0.0, 1.0 - dist))
            search_results.append(SearchResult(entry=entry, score=score, distance=dist))
        return search_results

    def list_entries(self, limit: int = 100, tag: Optional[str] = None) -> List[MemoryEntry]:
        """List stored entries, optionally restricted to those carrying *tag*.

        At most *limit* entries are returned. The tag filter is applied before
        the limit, so up to *limit* matching entries come back whenever that
        many exist. A non-positive *limit* yields ``[]``.
        """
        if limit <= 0:
            return []
        if self._collection is None:
            return self._list_from_md(limit, tag)

        fetch_kwargs: Dict[str, Any] = {"include": ["documents", "metadatas"]}
        if not tag:
            fetch_kwargs["limit"] = limit
        # With a tag, every record is fetched: tags are stored as a JSON string
        # in the record metadata, so they are filtered here on the client.
        result = self._collection.get(**fetch_kwargs)

        entries: List[MemoryEntry] = []
        for idx in range(len(result["ids"])):
            entry = self._result_to_entry(result, idx)
            if tag and tag not in entry.tags:
                continue
            entries.append(entry)
            if len(entries) >= limit:
                break
        return entries

    def count(self) -> int:
        """Return the number of stored entries (Markdown files in file-only mode)."""
        if self._collection is not None:
            return self._collection.count()
        return sum(1 for _ in self.md_dir.glob("*.md"))

    # ── internals ────────────────────────────────────────────

    def _upsert_vector(self, entry: MemoryEntry) -> None:
        """Insert or replace *entry* in the collection; no-op in file-only mode."""
        if self._collection is None:
            return
        meta = {
            "title": entry.title,
            "tier": entry.tier.value,
            "tags": json.dumps(entry.tags),
            "importance": entry.importance,
            "access_count": entry.access_count,
            "created_at": entry.created_at,
            "updated_at": entry.updated_at,
            "source": entry.source,
            # Keep the free-form metadata so it survives a read()/update()
            # round-trip; default=str keeps non-JSON values from raising.
            self._USER_METADATA_KEY: json.dumps(
                getattr(entry, "metadata", None) or {}, default=str
            ),
        }
        self._collection.upsert(
            ids=[entry.id],
            documents=[entry.content],
            metadatas=[meta],
        )

    def _persist_md(self, entry: MemoryEntry) -> None:
        """Write *entry* to ``<md_dir>/<entry.id>.md`` as UTF-8 Markdown."""
        path = self.md_dir / f"{entry.id}.md"
        path.write_text(entry.to_markdown(), encoding="utf-8")

    def _read_from_md(self, entry_id: str) -> Optional[MemoryEntry]:
        """Load an entry from its Markdown mirror, or ``None`` if the file is missing."""
        path = self.md_dir / f"{entry_id}.md"
        if not path.exists():
            return None
        entry = MemoryEntry.from_markdown(path.read_text(encoding="utf-8"))
        # Same normalisation as the other Markdown-based readers in this class.
        entry.tier = MemoryTier.SEMANTIC
        return entry

    @staticmethod
    def _decode_json(raw: Any, default: Any) -> Any:
        """Parse *raw* as JSON; return *default* if it is missing, invalid or of another type."""
        if raw is None:
            return default
        try:
            value = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return default
        return value if isinstance(value, type(default)) else default

    def _build_entry(self, entry_id: str, doc: Optional[str], meta: Optional[Dict[str, Any]]) -> MemoryEntry:
        """Rebuild a MemoryEntry from one record's id, document and metadata."""
        meta = meta or {}
        return MemoryEntry(
            id=entry_id,
            content=doc or "",
            title=meta.get("title", ""),
            tier=MemoryTier.SEMANTIC,
            tags=self._decode_json(meta.get("tags"), []),
            importance=float(meta.get("importance", 0.5)),
            access_count=int(meta.get("access_count", 0)),
            created_at=meta.get("created_at", ""),
            updated_at=meta.get("updated_at", ""),
            source=meta.get("source", ""),
            # Absent on records written before this key existed → empty dict.
            metadata=self._decode_json(meta.get(self._USER_METADATA_KEY), {}),
        )

    def _result_to_entry(self, result: dict, idx: int) -> MemoryEntry:
        """Convert row *idx* of a ``collection.get`` result (flat lists) into an entry."""
        meta = result["metadatas"][idx] if result.get("metadatas") else {}
        doc = result["documents"][idx] if result.get("documents") else ""
        return self._build_entry(result["ids"][idx], doc, meta)

    def _query_result_to_entry(self, results: dict, idx: int) -> MemoryEntry:
        """Convert hit *idx* of a single-query ``collection.query`` result into an entry.

        Query results are nested one level deeper than ``get`` results (one
        list per query text); only the first query's hits are read.
        """
        meta = results["metadatas"][0][idx] if results.get("metadatas") else {}
        doc = results["documents"][0][idx] if results.get("documents") else ""
        return self._build_entry(results["ids"][0][idx], doc, meta)

    def _keyword_fallback(self, query: str, limit: int) -> List[SearchResult]:
        """Case-insensitive substring search over the Markdown mirror.

        Used when ChromaDB is unavailable or the vector query failed. Every
        hit gets the fixed score 0.5; files that cannot be read or parsed are
        skipped (and logged).
        """
        matches: List[SearchResult] = []
        if limit <= 0:
            return matches
        needle = query.lower()
        for md_file in self.md_dir.glob("*.md"):
            try:
                text = md_file.read_text(encoding="utf-8")
                if needle not in text.lower():
                    continue
                entry = MemoryEntry.from_markdown(text)
            except Exception as exc:
                logger.warning("skipping unreadable memory file %s: %s", md_file, exc)
                continue
            entry.tier = MemoryTier.SEMANTIC
            matches.append(SearchResult(entry=entry, score=0.5))
            if len(matches) >= limit:
                break
        return matches

    def _list_from_md(self, limit: int, tag: Optional[str]) -> List[MemoryEntry]:
        """List up to *limit* entries from the Markdown mirror, in reverse filename order.

        Entries lacking *tag* (when given) are skipped before the limit is
        applied; files that cannot be read or parsed are skipped (and logged).
        """
        entries: List[MemoryEntry] = []
        if limit <= 0:
            return entries
        for md_file in sorted(self.md_dir.glob("*.md"), reverse=True):
            try:
                entry = MemoryEntry.from_markdown(md_file.read_text(encoding="utf-8"))
            except Exception as exc:
                logger.warning("skipping unreadable memory file %s: %s", md_file, exc)
                continue
            entry.tier = MemoryTier.SEMANTIC
            if tag and tag not in entry.tags:
                continue
            entries.append(entry)
            if len(entries) >= limit:
                break
        return entries