Spaces:
Paused
Paused
| """ | |
| ChromaDB vector store manager with persistent storage. | |
| Handles embedding via sentence-transformers with automatic | |
| CUDA / MPS (Apple Silicon) / CPU device selection. | |
| """ | |
| import logging | |
| import os | |
| from typing import List, Dict, Any, Optional | |
| from pathlib import Path | |
| import chromadb | |
| from chromadb.config import Settings | |
| from sentence_transformers import SentenceTransformer | |
| from utils.device import get_device, device_info | |
| logger = logging.getLogger(__name__) | |
| COLLECTION_NAME = "multimodal_rag" | |
| EMBED_MODEL = os.environ.get("EMBED_MODEL", "all-MiniLM-L6-v2") | |
| class VectorStoreManager: | |
| def __init__(self, persist_dir: str = "./vectorstore"): | |
| self.persist_dir = persist_dir | |
| os.makedirs(persist_dir, exist_ok=True) | |
| # Detect best available device (CUDA > MPS > CPU) | |
| self.device = get_device() | |
| info = device_info() | |
| logger.info(f"Embedding device: {info['label']}") | |
| # Persistent Chroma client | |
| self.client = chromadb.PersistentClient( | |
| path=persist_dir, | |
| settings=Settings(anonymized_telemetry=False), | |
| ) | |
| # Load embedding model onto selected device | |
| logger.info(f"Loading embedding model: {EMBED_MODEL} on {self.device}") | |
| self.embedder = SentenceTransformer(EMBED_MODEL, device=self.device) | |
| self.collection = self.client.get_or_create_collection( | |
| name=COLLECTION_NAME, | |
| metadata={"hnsw:space": "cosine"}, | |
| ) | |
| logger.info(f"Vector store ready — {self.collection.count()} chunks on {self.device}") | |
| def _embed(self, texts: List[str]) -> List[List[float]]: | |
| # normalize_embeddings=True improves cosine similarity quality | |
| return self.embedder.encode( | |
| texts, | |
| show_progress_bar=False, | |
| normalize_embeddings=True, | |
| device=self.device, | |
| ).tolist() | |
| def add_documents(self, chunks: List[Dict[str, Any]], source_name: str, | |
| chunk_offset: int = 0) -> int: | |
| """Add document chunks. Returns number of chunks added. | |
| chunk_offset: first chunk's ID index (used when calling in sub-batches | |
| so that IDs remain globally unique across calls for the same source). | |
| """ | |
| if not chunks: | |
| return 0 | |
| texts = [c["text"] for c in chunks] | |
| metadatas = [c["metadata"] for c in chunks] | |
| # Build unique IDs: source + absolute chunk index | |
| ids = [ | |
| f"{source_name}__chunk_{chunk_offset + i}" | |
| for i in range(len(chunks)) | |
| ] | |
| # Strip non-serializable fields from metadata (e.g. large base64 images) | |
| clean_metas = [] | |
| for m in metadatas: | |
| clean = {k: v for k, v in m.items() if isinstance(v, (str, int, float, bool))} | |
| clean_metas.append(clean) | |
| embeddings = self._embed(texts) | |
| # Upsert in batches of 100 | |
| batch = 100 | |
| for i in range(0, len(texts), batch): | |
| self.collection.upsert( | |
| ids=ids[i:i+batch], | |
| embeddings=embeddings[i:i+batch], | |
| documents=texts[i:i+batch], | |
| metadatas=clean_metas[i:i+batch], | |
| ) | |
| logger.info(f"Added {len(chunks)} chunks for '{source_name}'") | |
| return len(chunks) | |
| def remove_document(self, source_name: str) -> int: | |
| """Remove all chunks belonging to a source file.""" | |
| results = self.collection.get(where={"source": source_name}) | |
| ids = results.get("ids", []) | |
| if ids: | |
| self.collection.delete(ids=ids) | |
| logger.info(f"Removed {len(ids)} chunks for '{source_name}'") | |
| return len(ids) | |
| def clear_all(self) -> int: | |
| """Remove every chunk from the collection. Returns count removed.""" | |
| count = self.collection.count() | |
| if count == 0: | |
| return 0 | |
| all_ids = self.collection.get()["ids"] | |
| if all_ids: | |
| self.collection.delete(ids=all_ids) | |
| logger.info(f"Cleared {count} chunks from collection") | |
| return count | |
| def query(self, query_text: str, n_results: int = 5, source_filter: List[str] = None) -> List[Dict[str, Any]]: | |
| """Semantic search. Returns list of {text, metadata, distance}. | |
| source_filter: if provided, restrict results to chunks from these sources only. | |
| """ | |
| count = self.collection.count() | |
| if count == 0: | |
| return [] | |
| where = None | |
| if source_filter: | |
| where = ( | |
| {"source": {"$in": source_filter}} | |
| if len(source_filter) > 1 | |
| else {"source": source_filter[0]} | |
| ) | |
| # Count only filtered chunks to avoid asking for more than available | |
| filtered_ids = self.collection.get(where=where)["ids"] | |
| n = min(n_results, len(filtered_ids)) | |
| else: | |
| n = min(n_results, count) | |
| if n == 0: | |
| return [] | |
| embedding = self._embed([query_text])[0] | |
| results = self.collection.query( | |
| query_embeddings=[embedding], | |
| n_results=n, | |
| include=["documents", "metadatas", "distances"], | |
| where=where, | |
| ) | |
| output = [] | |
| docs = results.get("documents", [[]])[0] | |
| metas = results.get("metadatas", [[]])[0] | |
| dists = results.get("distances", [[]])[0] | |
| for doc, meta, dist in zip(docs, metas, dists): | |
| output.append({"text": doc, "metadata": meta, "distance": dist}) | |
| return output | |
| def query_per_source(self, query_text: str, n_per_source: int = 2) -> List[Dict[str, Any]]: | |
| """Fetch top n_per_source chunks from every source independently. | |
| Ensures all documents are represented regardless of collection size. | |
| """ | |
| results = [] | |
| for source in self.list_sources(): | |
| chunks = self.query(query_text, n_results=n_per_source, source_filter=[source]) | |
| results.extend(chunks) | |
| return results | |
| def list_sources(self) -> List[str]: | |
| """List all unique source document names.""" | |
| if self.collection.count() == 0: | |
| return [] | |
| results = self.collection.get(include=["metadatas"]) | |
| sources = set() | |
| for m in results.get("metadatas", []): | |
| if m and "source" in m: | |
| sources.add(m["source"]) | |
| return sorted(sources) | |
| def total_chunks(self) -> int: | |
| return self.collection.count() | |