""" Contain Wrapper Class for ChormaDB client, that can process and store documents and retrive document chunks. """ # for chromaDB __import__("pysqlite3") import sys sys.modules["sqlite3"] = sys.modules.pop("pysqlite3") from io import BytesIO from typing import List from typing_extensions import Literal import uuid import warnings import chromadb import re from .utils import ( generate_file_id, chunk_document, generate_embedding, extract_content_from_docx, extract_content_from_pdf, ) class AdvancedClient: def __init__(self, vector_database_path: str = "vectorDB") -> None: self.client = chromadb.PersistentClient(path=vector_database_path) self.exsisting_collections = [ collection.name for collection in self.client.list_collections() ] self.selected_collections: List[str] = [] def create_or_get_collection( self, file_names: List[str], file_types: List[Literal["pdf", "docx"]], file_datas, ): collections = [] for data in zip(file_names, file_types, file_datas): file_name, file_type, file_data = data file_id = generate_file_id(file_bytes=file_data) file_exisis = file_id in self.exsisting_collections if file_exisis: collection = file_id else: collection = self.client.create_collection(name=file_id) file_buffer = BytesIO(file_data) if file_type == "pdf": document, pil_images = extract_content_from_pdf(file_buffer) chunks = chunk_document(document) ids = [f"{uuid.uuid4()}_id_{x}" for x in range(1, len(chunks) + 1)] embeddings = generate_embedding( chunks, embedding_model="znbang/bge:small-en-v1.5-q8_0" ) metadatas = [] for chunk in chunks: imgs_found = re.findall( pattern=r"", string=chunk ) chunk_imgs = [] if len(imgs_found) > 0: for img in imgs_found: chunk_imgs.append(pil_images[int(img)]) metadatas.append( {"images": str(chunk_imgs), "file_name": file_name} ) elif file_type == "docx": document = extract_content_from_docx(file_buffer) chunks = chunk_document(document) ids = [f"{uuid.uuid4()}_id_{x}" for x in range(1, len(chunks) + 1)] embeddings = generate_embedding( chunks, embedding_model="znbang/bge:small-en-v1.5-q8_0" ) metadatas = [{"file_name": file_name} for _ in chunks] else: raise Exception( f"Given format '.{file_type}' is currently not supported." ) collection.add( ids=ids, embeddings=embeddings, # type: ignore documents=chunks, metadatas=metadatas, # type: ignore ) collection = file_id collections.append(collection) self.selected_collections = collections def retrieve_chunks(self, query: str, number_of_chunks: int = 3): if len(self.selected_collections) == 0: warnings.warn( message=f"No collection is selected using all the exsisting collections, total collections : {len(self.exsisting_collections)}" ) collections = [self.client.get_collection("UNION")] self.selected_collections = [collection.name for collection in collections] else: collections = [ self.client.get_collection(collection_name) for collection_name in self.selected_collections ] query_emb = generate_embedding( [query], embedding_model="znbang/bge:small-en-v1.5-q8_0" ) retrieved_docs = [] for collection in collections: results = collection.query( query_embeddings=query_emb, n_results=5, include=["documents", "metadatas", "distances"], ) for i in range(len(results["ids"][0])): retrieved_docs.append( { "document": results["documents"][0][i], "metadata": results["metadatas"][0][i], "distance": results["distances"][0][i], "collection": collection.name, } ) retrieved_docs = sorted(retrieved_docs, key=lambda x: x["distance"]) return retrieved_docs[:number_of_chunks]