# MultiModalRag / utils / vector_store.py
# irajkoohi's picture
# chore: update app [space deploy]
# 6c21523
# Raw
# History Blame Contribute Delete
# 6.5 kB
"""
ChromaDB vector store manager with persistent storage.
Handles embedding via sentence-transformers with automatic
CUDA / MPS (Apple Silicon) / CPU device selection.
"""
import logging
import os
from typing import List, Dict, Any, Optional
from pathlib import Path
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from utils.device import get_device, device_info
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Name of the single Chroma collection this module reads and writes.
COLLECTION_NAME = "multimodal_rag"
# sentence-transformers model name; can be overridden through the
# EMBED_MODEL environment variable, otherwise "all-MiniLM-L6-v2" is used.
EMBED_MODEL = os.environ.get("EMBED_MODEL", "all-MiniLM-L6-v2")
class VectorStoreManager:
    """Persistent ChromaDB collection paired with a sentence-transformers embedder.

    Chunks are stored under IDs of the form ``<source_name>__chunk_<index>``
    and grouped by the ``source`` metadata key, which is what
    ``remove_document``, ``list_sources`` and the ``source_filter`` argument
    of ``query`` match on.
    """

    # Number of records sent to Chroma per ``upsert`` call.
    UPSERT_BATCH_SIZE = 100

    def __init__(self, persist_dir: str = "./vectorstore"):
        """Open (or create) the collection and load the embedding model.

        Args:
            persist_dir: Directory for Chroma's on-disk data; created if missing.
        """
        self.persist_dir = persist_dir
        os.makedirs(persist_dir, exist_ok=True)

        # Detect best available device (CUDA > MPS > CPU)
        self.device = get_device()
        info = device_info()
        logger.info(f"Embedding device: {info['label']}")

        # Persistent Chroma client
        self.client = chromadb.PersistentClient(
            path=persist_dir,
            settings=Settings(anonymized_telemetry=False),
        )

        # Load embedding model onto selected device
        logger.info(f"Loading embedding model: {EMBED_MODEL} on {self.device}")
        self.embedder = SentenceTransformer(EMBED_MODEL, device=self.device)

        # Cosine space matches the normalized embeddings produced by _embed.
        self.collection = self.client.get_or_create_collection(
            name=COLLECTION_NAME,
            metadata={"hnsw:space": "cosine"},
        )
        logger.info(f"Vector store ready — {self.collection.count()} chunks on {self.device}")

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _embed(self, texts: List[str]) -> List[List[float]]:
        """Encode ``texts`` into one embedding (list of floats) per input text."""
        if not texts:
            return []
        # normalize_embeddings=True improves cosine similarity quality
        return self.embedder.encode(
            texts,
            show_progress_bar=False,
            normalize_embeddings=True,
            device=self.device,
        ).tolist()

    @staticmethod
    def _clean_metadata(metadata: Optional[Dict[str, Any]], source_name: str) -> Dict[str, Any]:
        """Return a storable copy of ``metadata`` that always carries ``source``.

        Only str/int/float/bool values are kept (this drops e.g. large base64
        images). ``source`` falls back to ``source_name`` when the chunk did
        not provide a scalar one, so the chunk stays reachable by
        ``remove_document`` / ``list_sources`` and the stored dict is never
        empty.
        """
        clean = {
            key: value
            for key, value in (metadata or {}).items()
            if isinstance(value, (str, int, float, bool))
        }
        clean.setdefault("source", source_name)
        return clean

    @staticmethod
    def _source_where(source_filter) -> Optional[Dict[str, Any]]:
        """Build the Chroma ``where`` clause for ``source_filter`` (None = no filter)."""
        if not source_filter:
            return None
        # A bare string would otherwise be treated as a sequence of characters.
        if isinstance(source_filter, str):
            source_filter = [source_filter]
        sources = list(source_filter)
        if len(sources) == 1:
            return {"source": sources[0]}
        return {"source": {"$in": sources}}

    def _source_counts(self) -> Dict[str, int]:
        """Map each ``source`` metadata value to its number of stored chunks."""
        if self.collection.count() == 0:
            return {}
        results = self.collection.get(include=["metadatas"])
        counts: Dict[str, int] = {}
        for meta in results.get("metadatas") or []:
            if meta and "source" in meta:
                counts[meta["source"]] = counts.get(meta["source"], 0) + 1
        return counts

    def _search(self, embedding: List[float], n: int,
                where: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Run one nearest-neighbour query and flatten Chroma's nested result."""
        results = self.collection.query(
            query_embeddings=[embedding],
            n_results=n,
            include=["documents", "metadatas", "distances"],
            where=where,
        )
        # Chroma returns one inner list per query embedding; we send exactly one.
        docs = (results.get("documents") or [[]])[0]
        metas = (results.get("metadatas") or [[]])[0]
        dists = (results.get("distances") or [[]])[0]
        return [
            {"text": doc, "metadata": meta, "distance": dist}
            for doc, meta, dist in zip(docs, metas, dists)
        ]

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def add_documents(self, chunks: List[Dict[str, Any]], source_name: str,
                      chunk_offset: int = 0) -> int:
        """Embed and upsert document chunks. Returns number of chunks added.

        Args:
            chunks: Dicts with a ``"text"`` entry and an optional ``"metadata"``
                dict.
            source_name: Name of the originating document; used in the chunk
                IDs and as the default ``source`` metadata value.
            chunk_offset: First chunk's ID index (used when calling in
                sub-batches so that IDs remain globally unique across calls
                for the same source).
        """
        if not chunks:
            return 0

        texts = [chunk["text"] for chunk in chunks]
        # Build unique IDs: source + absolute chunk index
        ids = [
            f"{source_name}__chunk_{index}"
            for index in range(chunk_offset, chunk_offset + len(chunks))
        ]
        metadatas = [
            self._clean_metadata(chunk.get("metadata"), source_name)
            for chunk in chunks
        ]
        embeddings = self._embed(texts)

        step = self.UPSERT_BATCH_SIZE
        for start in range(0, len(texts), step):
            stop = start + step
            self.collection.upsert(
                ids=ids[start:stop],
                embeddings=embeddings[start:stop],
                documents=texts[start:stop],
                metadatas=metadatas[start:stop],
            )
        logger.info(f"Added {len(chunks)} chunks for '{source_name}'")
        return len(chunks)

    def remove_document(self, source_name: str) -> int:
        """Remove all chunks whose ``source`` is ``source_name``. Returns count removed."""
        # include=[] fetches IDs only; documents/metadatas are not needed here.
        results = self.collection.get(where={"source": source_name}, include=[])
        ids = results.get("ids") or []
        if ids:
            self.collection.delete(ids=ids)
            logger.info(f"Removed {len(ids)} chunks for '{source_name}'")
        return len(ids)

    def clear_all(self) -> int:
        """Remove every chunk from the collection. Returns count removed."""
        count = self.collection.count()
        if count == 0:
            return 0
        all_ids = self.collection.get(include=[])["ids"]
        if all_ids:
            self.collection.delete(ids=all_ids)
        logger.info(f"Cleared {count} chunks from collection")
        return count

    def query(self, query_text: str, n_results: int = 5,
              source_filter: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Semantic search. Returns list of {text, metadata, distance}.

        Args:
            query_text: Text to embed and search for.
            n_results: Maximum number of hits; values <= 0 yield ``[]``.
            source_filter: If provided, restrict results to chunks from these
                sources only.
        """
        if n_results <= 0:
            return []
        count = self.collection.count()
        if count == 0:
            return []

        where = self._source_where(source_filter)
        if where is None:
            available = count
        else:
            # Count only filtered chunks to avoid asking for more than available
            available = len(self.collection.get(where=where, include=[])["ids"])

        n = min(n_results, available)
        if n <= 0:
            return []
        embedding = self._embed([query_text])[0]
        return self._search(embedding, n, where)

    def query_per_source(self, query_text: str, n_per_source: int = 2) -> List[Dict[str, Any]]:
        """Fetch top n_per_source chunks from every source independently.

        Ensures all documents are represented regardless of collection size.
        Results are concatenated in sorted source order.
        """
        if n_per_source <= 0:
            return []
        counts = self._source_counts()
        if not counts:
            return []

        # The query text is the same for every source, so embed it once.
        embedding = self._embed([query_text])[0]
        results: List[Dict[str, Any]] = []
        for source in sorted(counts):
            n = min(n_per_source, counts[source])
            results.extend(self._search(embedding, n, {"source": source}))
        return results

    def list_sources(self) -> List[str]:
        """List all unique source document names, sorted."""
        return sorted(self._source_counts())

    def total_chunks(self) -> int:
        """Return the number of chunks currently stored in the collection."""
        return self.collection.count()