# DocuMind-API / components/Database.py
# Author: ashishbangwal — "internal server error resolved" (commit 586cd83)
"""
Contain Wrapper Class for ChormaDB client, that can process and store documents and retrive document chunks.
"""
# for chromaDB
__import__("pysqlite3")
import sys
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
from io import BytesIO
from typing import List
from typing_extensions import Literal
import uuid
import warnings
import chromadb
import re
from .utils import (
generate_file_id,
chunk_document,
generate_embedding,
extract_content_from_docx,
extract_content_from_pdf,
)
class AdvancedClient:
    """Wrapper around a persistent ChromaDB client.

    Ingests PDF/DOCX files into one collection per file (keyed by a
    content-derived file id) and retrieves the best-matching chunks for a
    query across the currently selected collections.
    """

    # Single embedding model used for both documents and queries — they must
    # match or distances are meaningless.
    _EMBEDDING_MODEL = "znbang/bge:small-en-v1.5-q8_0"
    # Matches inline image placeholders emitted by the PDF extractor,
    # e.g. <img src='3'> — the captured group indexes into the PIL image list.
    _IMG_PATTERN = re.compile(r"<img\s+src='([^']*)'>")

    def __init__(self, vector_database_path: str = "vectorDB") -> None:
        """Open (or create) the persistent vector store at *vector_database_path*."""
        self.client = chromadb.PersistentClient(path=vector_database_path)
        # Names of collections already present on disk (attribute name kept
        # as-is, typo included, for backward compatibility with callers).
        self.exsisting_collections: List[str] = [
            collection.name for collection in self.client.list_collections()
        ]
        # Collections the next retrieve_chunks() call will search.
        self.selected_collections: List[str] = []

    def create_or_get_collection(
        self,
        file_names: List[str],
        file_types: List[Literal["pdf", "docx"]],
        file_datas,
    ) -> None:
        """Ingest each file into its own collection (skipping already-ingested
        files) and select the resulting collections for retrieval.

        Parameters
        ----------
        file_names : display names, stored in chunk metadata.
        file_types : one of "pdf" or "docx" per file.
        file_datas : raw file bytes per file.

        Raises
        ------
        ValueError : if a file type other than "pdf"/"docx" is given.
        """
        collections: List[str] = []
        for file_name, file_type, file_data in zip(file_names, file_types, file_datas):
            file_id = generate_file_id(file_bytes=file_data)
            if file_id in self.exsisting_collections:
                # Already ingested: reuse the existing collection and skip
                # re-processing. (Previously this fell through and crashed
                # calling .add() on the file-id string.)
                collections.append(file_id)
                continue

            collection = self.client.create_collection(name=file_id)
            file_buffer = BytesIO(file_data)
            if file_type == "pdf":
                document, pil_images = extract_content_from_pdf(file_buffer)
                chunks = chunk_document(document)
                # Resolve each chunk's image placeholders to PIL images and
                # stash their string form in the metadata alongside the name.
                metadatas = [
                    {
                        "images": str(
                            [pil_images[int(idx)] for idx in self._IMG_PATTERN.findall(chunk)]
                        ),
                        "file_name": file_name,
                    }
                    for chunk in chunks
                ]
            elif file_type == "docx":
                document = extract_content_from_docx(file_buffer)
                chunks = chunk_document(document)
                metadatas = [{"file_name": file_name} for _ in chunks]
            else:
                raise ValueError(
                    f"Given format '.{file_type}' is currently not supported."
                )

            ids = [f"{uuid.uuid4()}_id_{x}" for x in range(1, len(chunks) + 1)]
            embeddings = generate_embedding(
                chunks, embedding_model=self._EMBEDDING_MODEL
            )
            collection.add(
                ids=ids,
                embeddings=embeddings,  # type: ignore
                documents=chunks,
                metadatas=metadatas,  # type: ignore
            )
            # Record the new collection so a second ingest of the same file in
            # this session reuses it instead of hitting a duplicate-name error.
            self.exsisting_collections.append(file_id)
            collections.append(file_id)
        self.selected_collections = collections

    def retrieve_chunks(self, query: str, number_of_chunks: int = 3):
        """Return up to *number_of_chunks* chunks most similar to *query*,
        sorted by ascending distance across the selected collections.

        Falls back to the "UNION" collection when nothing is selected
        (raises if that collection does not exist — pre-existing behavior).
        """
        if not self.selected_collections:
            warnings.warn(
                message=f"No collection is selected using all the exsisting collections, total collections : {len(self.exsisting_collections)}"
            )
            collections = [self.client.get_collection("UNION")]
            self.selected_collections = [collection.name for collection in collections]
        else:
            collections = [
                self.client.get_collection(collection_name)
                for collection_name in self.selected_collections
            ]
        query_emb = generate_embedding(
            [query], embedding_model=self._EMBEDDING_MODEL
        )
        retrieved_docs = []
        for collection in collections:
            results = collection.query(
                query_embeddings=query_emb,
                # max() keeps the historical 5-per-collection behavior while
                # honoring requests for more than 5 chunks (previously capped).
                n_results=max(5, number_of_chunks),
                include=["documents", "metadatas", "distances"],
            )
            # Chroma returns one inner list per query embedding; we send one
            # query, so index [0] everywhere.
            for doc, meta, dist in zip(
                results["documents"][0],
                results["metadatas"][0],
                results["distances"][0],
            ):
                retrieved_docs.append(
                    {
                        "document": doc,
                        "metadata": meta,
                        "distance": dist,
                        "collection": collection.name,
                    }
                )
        retrieved_docs.sort(key=lambda d: d["distance"])
        return retrieved_docs[:number_of_chunks]