Spaces:
Sleeping
Sleeping
| import chromadb | |
| from chromadb.config import Settings | |
| from langchain_chroma import Chroma | |
| from langchain_core.documents import Document | |
| from typing import List, Dict, Optional | |
| import os | |
class VectorStoreManager:
    """Manage a persistent ChromaDB collection via LangChain's ``Chroma`` wrapper.

    The manager owns one ``chromadb.PersistentClient`` (``self.client``) and
    one LangChain ``Chroma`` wrapper (``self.vector_store``) bound to a single
    named collection. All status and error reporting is done with ``print``.
    """

    def __init__(
        self,
        persist_dir: str = "./chroma_db",
        collection_name: str = "pdf_documents",
        embeddings=None,
    ):
        """Create the storage directory, the Chroma client and the wrapper.

        Args:
            persist_dir: Directory for persistent storage; created if missing.
            collection_name: Name of the collection to bind to.
            embeddings: LangChain embeddings instance, passed through to
                ``Chroma`` as ``embedding_function``.
        """
        self.persist_dir = persist_dir
        self.collection_name = collection_name
        self.embeddings = embeddings

        os.makedirs(persist_dir, exist_ok=True)

        # Low-level persistent client; also used directly for counting and
        # deleting the collection.
        self.client = chromadb.PersistentClient(path=persist_dir)
        self.vector_store = self._build_vector_store()

        print(f"Vector store initialized: {persist_dir}/{collection_name}")

    def _build_vector_store(self) -> "Chroma":
        """Return a LangChain ``Chroma`` wrapper bound to this manager's settings."""
        return Chroma(
            client=self.client,
            collection_name=self.collection_name,
            embedding_function=self.embeddings,
            persist_directory=self.persist_dir,
        )

    def add_documents(self, documents: "List[Document]", batch_size: int = 50) -> int:
        """Add documents to the vector store in batches (best effort).

        A batch that fails is reported with ``print`` and skipped so the
        remaining batches are still attempted.

        Args:
            documents: List of LangChain Document objects.
            batch_size: Number of documents per batch; must be positive.

        Returns:
            Number of documents in the batches that were added successfully.

        Raises:
            ValueError: If ``batch_size`` is not a positive number. Without
                this check a zero step makes ``range`` fail with an unrelated
                message and a negative step silently adds nothing.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
        if not documents:
            return 0

        added = 0
        starts = range(0, len(documents), batch_size)
        for batch_no, start in enumerate(starts, start=1):
            batch = documents[start:start + batch_size]
            try:
                self.vector_store.add_documents(batch)
            except Exception as e:
                # Deliberate best effort: report and continue with next batch.
                print(f"Error adding documents: {e}")
            else:
                added += len(batch)
                print(f"Added {len(batch)} documents (batch {batch_no})")
        return added

    def search(self, query: str, k: int = 5) -> List[Dict]:
        """Search for documents similar to ``query``.

        Args:
            query: Search query.
            k: Number of results to return.

        Returns:
            One dict per hit with the keys ``"content"`` (page text),
            ``"metadata"`` and ``"similarity"``.

            NOTE: ``"similarity"`` holds the raw score returned by
            ``similarity_search_with_score``. LangChain documents this score
            for Chroma as a distance (lower means more similar); the key name
            is kept unchanged for backward compatibility.
        """
        hits = self.vector_store.similarity_search_with_score(query, k=k)
        return [
            {
                "content": doc.page_content,
                "metadata": doc.metadata,
                "similarity": score,
            }
            for doc, score in hits
        ]

    def get_retriever(self, search_kwargs: Optional[Dict] = None):
        """Return a retriever for a RAG chain.

        Args:
            search_kwargs: Options forwarded to ``as_retriever``; defaults to
                ``{"k": 5}`` when omitted.
        """
        if search_kwargs is None:
            search_kwargs = {"k": 5}
        return self.vector_store.as_retriever(search_kwargs=search_kwargs)

    def collection_count(self) -> int:
        """Return the number of entries in the collection.

        Any error while looking up or counting the collection is printed and
        reported as ``0``.
        """
        try:
            collection = self.client.get_collection(self.collection_name)
            return collection.count()
        except Exception as e:
            print(f"Error getting collection count: {e}")
            return 0

    def clear_collection(self):
        """Delete the collection and rebind ``self.vector_store`` to a fresh one.

        Errors are printed rather than raised; if deletion fails the existing
        wrapper is left in place.
        """
        try:
            self.client.delete_collection(self.collection_name)
            self.vector_store = self._build_vector_store()
            print(f"Collection cleared: {self.collection_name}")
        except Exception as e:
            print(f"Error clearing collection: {e}")