PDF-Assit_RAG / backend /app /rag /vectorstore.py
Param20h's picture
Changes Pro max
26a0c00 unverified
"""
ChromaDB vector store operations.
Per-user collections for data isolation.
"""
import logging
from typing import List, Dict, Any, Optional
import chromadb
from chromadb.config import Settings as ChromaSettings
from app.config import get_settings
from app.rag.embeddings import get_embedding_model
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Application settings, fetched once when this module is imported.
settings = get_settings()
# ── Singleton ChromaDB client ────────────────────────
# Cached client instance; stays None until get_chroma_client() builds it.
_chroma_client = None
def get_chroma_client() -> chromadb.ClientAPI:
    """
    Return the shared persistent ChromaDB client, building it on first use.

    The client is cached in the module-level ``_chroma_client``; every later
    call returns that same instance. Building the client creates the
    ``settings.CHROMA_PERSIST_DIR`` directory if it does not exist yet and
    disables Chroma's anonymized telemetry.
    """
    global _chroma_client

    # Fast path: the client was already created by an earlier call.
    if _chroma_client is not None:
        return _chroma_client

    import os

    # Make sure the on-disk location exists before handing it to Chroma.
    os.makedirs(settings.CHROMA_PERSIST_DIR, exist_ok=True)
    _chroma_client = chromadb.PersistentClient(
        path=settings.CHROMA_PERSIST_DIR,
        settings=ChromaSettings(anonymized_telemetry=False),
    )
    logger.info(f"ChromaDB initialized at {settings.CHROMA_PERSIST_DIR}")
    return _chroma_client
def get_collection_name(user_id: str) -> str:
    """
    Derive the per-user ChromaDB collection name from a user id.

    Hyphens in ``user_id`` are turned into underscores, the result is
    prefixed with ``user_``, and the name is cut to at most 63 characters.

    NOTE(review): because of the cut, two very long ids that share their
    first 58 characters map to the same name, and the cut can leave a
    trailing underscore — confirm ids are short enough in practice.
    """
    # Length cap this module applies to collection names.
    max_length = 63
    sanitized_id = "_".join(user_id.split("-"))
    full_name = "user_" + sanitized_id
    return full_name[:max_length]
def store_chunks(
    chunks: List[Dict[str, Any]],
    document_id: str,
    filename: str,
    user_id: str,
    *,
    batch_size: int = 50,
) -> int:
    """
    Embed document chunks and upsert them into the user's ChromaDB collection.

    Record ids have the form ``"{document_id}_{chunk_index}"``. Records are
    written with ``upsert`` so that storing the same document again replaces
    its existing chunks instead of colliding with the ids already present.

    Args:
        chunks: Chunk dicts; each must provide the keys ``"text"``,
            ``"page"`` and ``"chunk_index"``.
        document_id: Identifier of the source document; used in record ids
            and stored in each record's metadata.
        filename: Name of the source file, stored in each record's metadata.
        user_id: Owner of the chunks; selects the per-user collection.
        batch_size: Number of chunks embedded and written per round trip.

    Returns:
        The number of chunks written (0 when ``chunks`` is empty).

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    if not chunks:
        return 0

    client = get_chroma_client()
    embedding_model = get_embedding_model()
    collection = client.get_or_create_collection(
        name=get_collection_name(user_id),
        metadata={"hnsw:space": "cosine"},
    )

    # ── Prepare batch data ───────────────────────────
    texts = [chunk["text"] for chunk in chunks]
    ids = [f"{document_id}_{chunk['chunk_index']}" for chunk in chunks]
    metadatas = [
        {
            # The text is kept in metadata as well as in `documents`, so
            # readers that rely on either field keep working.
            "text": chunk["text"],
            "filename": filename,
            "document_id": document_id,
            "page": chunk["page"],
            "chunk_index": chunk["chunk_index"],
        }
        for chunk in chunks
    ]

    # ── Embed and upsert in batches ──────────────────
    total_stored = 0
    for start in range(0, len(texts), batch_size):
        end = start + batch_size
        batch_texts = texts[start:end]

        # Generate embeddings for this batch only, to bound memory per call.
        embeddings = embedding_model.embed_documents(batch_texts)

        collection.upsert(
            ids=ids[start:end],
            embeddings=embeddings,
            metadatas=metadatas[start:end],
            documents=batch_texts,
        )
        total_stored += len(batch_texts)

    logger.info(f"Stored {total_stored} chunks for document {document_id}")
    return total_stored
def query_chunks(
    query_embedding: List[float],
    user_id: str,
    document_id: Optional[str] = None,
    top_k: int = 10,
) -> List[Dict[str, Any]]:
    """
    Query the user's ChromaDB collection for the chunks nearest to an embedding.

    Args:
        query_embedding: Embedding vector of the query.
        user_id: Owner whose collection is searched.
        document_id: When given (and non-empty), restrict results to chunks
            whose metadata ``document_id`` equals this value.
        top_k: Maximum number of chunks to return.

    Returns:
        A list of dicts with the keys ``text``, ``filename``,
        ``document_id``, ``page`` and ``score``, in the order returned by
        ChromaDB. The list is empty when ``top_k`` is smaller than 1, when
        the user's collection cannot be opened, or when nothing matches.
    """
    # Asking for zero (or fewer) results can never yield anything; answer
    # directly instead of passing an invalid n_results to ChromaDB.
    if top_k < 1:
        return []

    client = get_chroma_client()
    collection_name = get_collection_name(user_id)

    try:
        collection = client.get_collection(name=collection_name)
    except Exception:
        # Best effort: a user without a collection simply has no chunks.
        logger.warning(f"Collection {collection_name} not found")
        return []

    # ── Build filter ─────────────────────────────────
    where_filter = {"document_id": {"$eq": document_id}} if document_id else None

    # ── Query ────────────────────────────────────────
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=top_k,
        where=where_filter,
        include=["documents", "metadatas", "distances"],
    )

    # ── Format results ───────────────────────────────
    if not results or not results["documents"] or not results["documents"][0]:
        return []

    # Only one query embedding was sent, so only the first result row matters.
    documents = results["documents"][0]
    if results["metadatas"]:
        metadatas = results["metadatas"][0]
    else:
        metadatas = [None] * len(documents)
    if results["distances"]:
        distances = results["distances"][0]
    else:
        distances = [0] * len(documents)

    chunks = []
    for doc, metadata, distance in zip(documents, metadatas, distances):
        # A record stored without metadata can come back as None.
        metadata = metadata or {}
        # Collections are created with cosine space, so 1 - distance is the
        # cosine similarity; it can drop below 0 for opposing vectors.
        similarity = 1 - distance
        chunks.append({
            "text": doc,
            "filename": metadata.get("filename", ""),
            "document_id": metadata.get("document_id", ""),
            "page": metadata.get("page", 1),
            "score": round(similarity, 4),
        })
    return chunks
def delete_document_chunks(document_id: str, user_id: str):
    """
    Remove every stored chunk of one document from the user's collection.

    Best effort: any error raised while opening the collection, looking up
    the chunk ids or deleting them is logged as a warning and not re-raised.
    """
    client = get_chroma_client()
    collection_name = get_collection_name(user_id)

    try:
        collection = client.get_collection(name=collection_name)

        # Look up the ids of this document's chunks; include=[] requests no
        # additional fields beyond the ids.
        matches = collection.get(
            where={"document_id": {"$eq": document_id}},
            include=[],
        )
        chunk_ids = matches["ids"]

        # Nothing stored for this document: nothing to delete or report.
        if not chunk_ids:
            return

        collection.delete(ids=chunk_ids)
        logger.info(f"Deleted {len(chunk_ids)} chunks for document {document_id}")
    except Exception as exc:
        logger.warning(f"Error deleting chunks: {exc}")
def delete_user_collection(user_id: str):
    """
    Drop the whole ChromaDB collection that belongs to a user.

    Best effort: a failure to delete (for example because the collection
    does not exist) is logged as a warning and not re-raised.
    """
    chroma = get_chroma_client()
    target = get_collection_name(user_id)

    try:
        chroma.delete_collection(name=target)
        logger.info(f"Deleted collection {target}")
    except Exception as exc:
        logger.warning(f"Error deleting collection: {exc}")