Spaces:
Running
Running
| """ | |
| ChromaDB vector store operations. | |
| Per-user collections for data isolation. | |
| """ | |
| import logging | |
| from typing import List, Dict, Any, Optional | |
| import chromadb | |
| from chromadb.config import Settings as ChromaSettings | |
| from app.config import get_settings | |
| from app.rag.embeddings import get_embedding_model | |
logger = logging.getLogger(__name__)
settings = get_settings()

# -- Singleton ChromaDB client ---------------------------------------------
# Populated lazily by get_chroma_client(); stays None until its first call.
_chroma_client: Optional[chromadb.ClientAPI] = None
def get_chroma_client() -> chromadb.ClientAPI:
    """
    Return the module-wide persistent ChromaDB client, building it on first use.

    On the first call this creates ``settings.CHROMA_PERSIST_DIR`` (if missing),
    opens a ``chromadb.PersistentClient`` there with anonymized telemetry
    disabled, and caches it in ``_chroma_client``. Later calls return the
    cached client. Initialization is not guarded by a lock.
    """
    global _chroma_client
    if _chroma_client is not None:
        return _chroma_client

    import os

    persist_dir = settings.CHROMA_PERSIST_DIR
    os.makedirs(persist_dir, exist_ok=True)
    _chroma_client = chromadb.PersistentClient(
        path=persist_dir,
        settings=ChromaSettings(anonymized_telemetry=False),
    )
    logger.info(f"ChromaDB initialized at {persist_dir}")
    return _chroma_client
def get_collection_name(user_id: str) -> str:
    """
    Return the ChromaDB collection name that holds ``user_id``'s chunks.

    Chroma collection names must be 3-63 characters long, start and end with
    an alphanumeric character, use only ``[A-Za-z0-9._-]`` and contain no
    ``..``.

    The readable legacy form ``user_<id>`` (hyphens replaced by underscores)
    is returned whenever it already satisfies those rules, so existing
    collections (e.g. for UUID user IDs) keep their names. Otherwise the
    name is derived from a SHA-256 digest of the full ``user_id``.

    The previous behaviour simply truncated to 63 characters. That mapped
    distinct long IDs sharing a prefix to the same collection, which broke
    per-user isolation. It could also yield names Chroma rejects, such as a
    trailing underscore, ``@``, or ``..``.

    Args:
        user_id: Identifier of the owning user.

    Returns:
        A Chroma-valid collection name that is unique per ``user_id``.
    """
    import hashlib
    import re

    legacy_name = f"user_{user_id.replace('-', '_')}"
    if (
        len(legacy_name) <= 63
        and re.fullmatch(r"[A-Za-z0-9][A-Za-z0-9._-]*[A-Za-z0-9]", legacy_name)
        and ".." not in legacy_name
    ):
        return legacy_name

    # "userh_" differs from every legacy "user_..." name at index 4, so hashed
    # names can never collide with legacy ones. 40 hex chars = 160 digest bits.
    digest = hashlib.sha256(user_id.encode("utf-8")).hexdigest()[:40]
    return f"userh_{digest}"
def store_chunks(
    chunks: List[Dict[str, Any]],
    document_id: str,
    filename: str,
    user_id: str,
) -> int:
    """
    Embed document chunks and write them to the user's ChromaDB collection.

    The collection is created on demand with cosine distance
    (``hnsw:space = "cosine"``). Each chunk dict must provide ``text``,
    ``page`` and ``chunk_index``. Record IDs are
    ``f"{document_id}_{chunk_index}"``.

    Writes use ``upsert`` rather than ``add``. Chroma's ``add`` ignores IDs
    that already exist, so re-ingesting a document would silently keep the
    old text and embeddings while still counting them as stored.

    NOTE: chunks left over from a previous, longer version of the document
    (higher ``chunk_index``) are not removed here. Call
    ``delete_document_chunks`` first if that matters.

    Args:
        chunks: Chunk dicts with ``text``, ``page`` and ``chunk_index`` keys.
        document_id: Identifier of the source document.
        filename: Original filename, stored in each chunk's metadata.
        user_id: Owner; selects the per-user collection.

    Returns:
        The number of chunks written (0 when ``chunks`` is empty).
    """
    if not chunks:
        return 0

    client = get_chroma_client()
    embedding_model = get_embedding_model()
    collection = client.get_or_create_collection(
        name=get_collection_name(user_id),
        metadata={"hnsw:space": "cosine"},
    )

    # -- Prepare batch data ----------------------------------------------------
    texts = [chunk["text"] for chunk in chunks]
    ids = [f"{document_id}_{chunk['chunk_index']}" for chunk in chunks]
    metadatas = [
        {
            "text": chunk["text"],
            "filename": filename,
            "document_id": document_id,
            "page": chunk["page"],
            "chunk_index": chunk["chunk_index"],
        }
        for chunk in chunks
    ]

    # -- Embed and upsert in batches (bounds the size of each embed call) -----
    batch_size = 50
    total_stored = 0
    for start in range(0, len(texts), batch_size):
        end = start + batch_size
        batch_texts = texts[start:end]
        embeddings = embedding_model.embed_documents(batch_texts)
        collection.upsert(
            ids=ids[start:end],
            embeddings=embeddings,
            metadatas=metadatas[start:end],
            documents=batch_texts,
        )
        total_stored += len(batch_texts)

    logger.info(f"Stored {total_stored} chunks for document {document_id}")
    return total_stored
def query_chunks(
    query_embedding: List[float],
    user_id: str,
    document_id: Optional[str] = None,
    top_k: int = 10,
) -> List[Dict[str, Any]]:
    """
    Query the user's ChromaDB collection for the chunks nearest an embedding.

    Args:
        query_embedding: Embedding vector of the query.
        user_id: Owner whose collection is searched.
        document_id: When truthy, only chunks whose ``document_id`` metadata
            equals this value are considered.
        top_k: Maximum number of results requested from Chroma.

    Returns:
        A list, in Chroma's result order, of dicts with these keys:
        ``text``, ``filename``, ``document_id``, ``page`` (default 1) and
        ``score``.

        ``score`` is ``1 - distance`` rounded to 4 places. In a cosine-space
        collection (as ``store_chunks`` creates) Chroma's distance is
        ``1 - cosine_similarity``. So ``score`` is the cosine similarity and
        lies in [-1, 1], not [0, 1].

        Returns ``[]`` when the collection cannot be opened or nothing
        matches.
    """
    client = get_chroma_client()
    collection_name = get_collection_name(user_id)
    try:
        collection = client.get_collection(name=collection_name)
    except Exception:
        # Broad catch: any failure to open the collection is reported as
        # "not found" and yields no results.
        logger.warning(f"Collection {collection_name} not found")
        return []

    # -- Build filter ----------------------------------------------------------
    where_filter = {"document_id": {"$eq": document_id}} if document_id else None

    # -- Query -----------------------------------------------------------------
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=top_k,
        where=where_filter,
        include=["documents", "metadatas", "distances"],
    )

    # -- Format results --------------------------------------------------------
    chunks: List[Dict[str, Any]] = []
    if not (results and results["documents"] and results["documents"][0]):
        return chunks

    documents = results["documents"][0]
    metadatas = results["metadatas"][0] if results["metadatas"] else []
    distances = results["distances"][0] if results["distances"] else []
    for i, doc in enumerate(documents):
        # Chroma may return None for a record without metadata; treat it as {}.
        metadata = (metadatas[i] if i < len(metadatas) else None) or {}
        distance = distances[i] if i < len(distances) else 0
        similarity = 1 - distance
        chunks.append({
            "text": doc,
            "filename": metadata.get("filename", ""),
            "document_id": metadata.get("document_id", ""),
            "page": metadata.get("page", 1),
            "score": round(similarity, 4),
        })
    return chunks
def delete_document_chunks(document_id: str, user_id: str):
    """
    Remove every chunk of ``document_id`` from the user's collection.

    Matching record IDs are fetched via a ``document_id`` metadata filter and
    then deleted. This is best-effort: any error raised while opening the
    collection, fetching or deleting is logged as a warning and swallowed.
    """
    chroma = get_chroma_client()
    target_name = get_collection_name(user_id)
    try:
        target = chroma.get_collection(name=target_name)
        # include=[] asks only for the IDs, no documents/embeddings/metadata.
        found = target.get(
            where={"document_id": {"$eq": document_id}},
            include=[],
        )
        matching_ids = found["ids"]
        if not matching_ids:
            return
        target.delete(ids=matching_ids)
        logger.info(f"Deleted {len(matching_ids)} chunks for document {document_id}")
    except Exception as e:
        logger.warning(f"Error deleting chunks: {e}")
def delete_user_collection(user_id: str):
    """
    Drop the user's entire ChromaDB collection, i.e. all of their chunks.

    Best-effort: a failure to delete is logged as a warning, not raised.
    """
    chroma = get_chroma_client()
    target_name = get_collection_name(user_id)
    try:
        chroma.delete_collection(name=target_name)
        logger.info(f"Deleted collection {target_name}")
    except Exception as err:
        logger.warning(f"Error deleting collection: {err}")