Spaces:
Configuration error
Configuration error
| import shutil | |
| from pathlib import Path | |
| from typing import List, Optional, Tuple | |
| import tqdm | |
| from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings | |
| from langchain_community.vectorstores import Chroma | |
| from loguru import logger | |
| from app.config.models.configs import Config | |
| from app.parsers.splitter import Document | |
| from app.utils import torch_device | |
class ChromaDenseVectorDB:
    """Dense vector store backed by Chroma with SentenceTransformer embeddings.

    Handles generating/persisting embeddings for parsed documents and
    retrieving documents by id or by semantic similarity.
    """

    def __init__(self, persist_folder: str, config: Config):
        """Create the store.

        Args:
            persist_folder: Directory where the Chroma collection is persisted.
            config: Application config; ``config.embeddings.embedding_model.model_name``
                selects the SentenceTransformer model, and
                ``config.semantic_search.max_k`` bounds similarity search results.
        """
        self._persist_folder = persist_folder
        self._config = config
        logger.info(f"Embedding model config: {config}")
        self._embeddings = SentenceTransformerEmbeddings(
            model_name=config.embeddings.embedding_model.model_name,
            model_kwargs={"device": torch_device()},
        )
        self.batch_size = 200  # documents embedded per Chroma write
        self._retriever = None
        self._vectordb = None

    @property
    def retriever(self):
        # Must be a property: this class accesses it as an attribute
        # (``self.retriever.vectorstore`` below). Lazily built on first use.
        if self._retriever is None:
            self._retriever = self._load_retriever()
        return self._retriever

    @property
    def vectordb(self):
        # Must be a property: accessed as ``self.vectordb.as_retriever``.
        # Lazy singleton handle over the persisted Chroma collection.
        if self._vectordb is None:
            self._vectordb = Chroma(
                persist_directory=self._persist_folder,
                embedding_function=self._embeddings,
            )
        return self._vectordb

    def generate_embeddings(
        self,
        docs: List[Document],
        clear_persist_folder: bool = True,
    ):
        """Embed *docs* in batches and persist them to the Chroma store.

        Args:
            docs: Documents to embed; each must carry ``metadata["document_id"]``.
            clear_persist_folder: If True, wipe any existing persisted
                collection before writing.
        """
        if clear_persist_folder:
            pf = Path(self._persist_folder)
            if pf.exists() and pf.is_dir():
                logger.warning(f"Deleting the content of: {pf}")
                shutil.rmtree(pf)

        logger.info("Generating and persisting the embeddings..")

        vectordb = None
        for group in tqdm.tqdm(
            chunker(docs, size=self.batch_size),
            # Ceiling division so a partial final batch is counted.
            total=-(-len(docs) // self.batch_size),
        ):
            ids = [d.metadata["document_id"] for d in group]
            if vectordb is None:
                # First batch creates the collection on disk.
                vectordb = Chroma.from_documents(
                    documents=group,
                    embedding=self._embeddings,
                    ids=ids,
                    persist_directory=self._persist_folder,
                )
            else:
                vectordb.add_texts(
                    texts=[doc.page_content for doc in group],
                    embedding=self._embeddings,
                    ids=ids,
                    metadatas=[doc.metadata for doc in group],
                )
        logger.info("Generated embeddings. Persisting...")
        if vectordb is not None:
            vectordb.persist()

    def _load_retriever(self, **kwargs):
        # Thin wrapper so retriever construction options stay in one place.
        return self.vectordb.as_retriever(**kwargs)

    def get_documents_by_id(self, document_ids: List[str]) -> List[Document]:
        """Fetch documents by their Chroma ids.

        Args:
            document_ids: Ids previously assigned via ``metadata["document_id"]``.

        Returns:
            Matching documents, rebuilt from the stored text and metadata.
        """
        results = self.retriever.vectorstore.get(ids=document_ids, include=["metadatas", "documents"])  # type: ignore
        docs = [
            Document(page_content=d, metadata=m)
            for d, m in zip(results["documents"], results["metadatas"])
        ]
        return docs

    def similarity_search_with_relevance_scores(
        self, query: str, filter: Optional[dict]
    ) -> List[Tuple[Document, float]]:
        """Semantic search returning ``(document, relevance_score)`` pairs.

        Args:
            query: Free-text search query.
            filter: Optional metadata filter. Multi-key dicts are converted to
                Chroma's explicit ``$and``-of-``$eq`` form, since Chroma only
                accepts a single bare key/value pair otherwise.

        Returns:
            Up to ``config.semantic_search.max_k`` scored documents.
        """
        if isinstance(filter, dict) and len(filter) > 1:
            filter = {"$and": [{key: {"$eq": value}} for key, value in filter.items()]}
        logger.debug(f"Filter = {filter}")
        return self.retriever.vectorstore.similarity_search_with_relevance_scores(
            query, k=self._config.semantic_search.max_k, filter=filter
        )
def chunker(seq, size):
    """Yield successive slices of *seq* containing at most *size* items."""
    for start in range(0, len(seq), size):
        yield seq[start : start + size]