# ragent-chatbot/vector_db/qdrant_db.py
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import json
import hashlib
import pandas as pd
from config import Config
from utils.nltk import NLTK
from typing import List, Dict
from dotenv import load_dotenv
from qdrant_client import QdrantClient
from utils.normalizer import Normalizer
from qdrant_client.models import ScoredPoint
from langchain_core.documents import Document
from vector_db.chunker import DocumentChunker
from vector_db.data_embedder import BAAIEmbedder
from qdrant_client.models import Distance, VectorParams, PointStruct
from qdrant_client.http.models import Filter, FieldCondition, MatchText
from qdrant_client.models import TextIndexParams, TextIndexType, TokenizerType
from langchain_community.document_loaders import (
PDFPlumberLoader,
UnstructuredWordDocumentLoader,
UnstructuredPowerPointLoader,
UnstructuredExcelLoader,
TextLoader,
CSVLoader,
JSONLoader
)
load_dotenv()
class QdrantDBClient:
def __init__(self):
self.collection_name = Config.COLLECTION_NAME
self.client = QdrantClient(url=os.getenv('QDRANT_URL'), api_key=os.getenv('QDRANT_API_KEY')) # Qdrant - Cloud
#self.client = QdrantClient(path=Config.QDRANT_PERSIST_PATH) # Qdrant - Local
self.embedder = BAAIEmbedder()
self.chunker = DocumentChunker()
self.normalizer = Normalizer()
self.nltk = NLTK()
if not self.client.collection_exists(self.collection_name):
self.client.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(
size=self.embedder.model.get_sentence_embedding_dimension(),
distance=Distance.COSINE,
)
)
# Optional performance optimization
self.client.update_collection(
collection_name=self.collection_name,
optimizers_config={"default_segment_number": 2}
)
# Add BM25 support on 'tokenized_text' field
self.client.create_payload_index(
collection_name=self.collection_name,
field_name="tokenized_text",
field_schema=TextIndexParams(
type=TextIndexType.TEXT,
tokenizer=TokenizerType.WHITESPACE,
min_token_len=1,
max_token_len=20,
lowercase=False
)
)
def tokenize_for_bm25(self, text: str) -> str:
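        """
        Prepare text for Qdrant's whitespace full-text index: normalize, split on
        whitespace, and drop NLTK stopwords.

        Illustrative example (assuming the normalizer lowercases and strips punctuation):
            "What is the full form of K12HSN?" -> "full form k12hsn"
        """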
norm_text = self.normalizer.normalize_text(text)
tokens = norm_text.split()
filtered_tokens = [t for t in tokens if t.lower() not in self.nltk.stopwords]
return " ".join(filtered_tokens)
def get_jq_schema(self, file_path: str) -> str:
"""
Dynamically determines the jq_schema based on whether the JSON root is a list or a dict.
Handles:
        - Root list: [ {...}, {...} ]
- Root dict with list key: { "key": [ {...}, {...} ] }
Raises:
ValueError: If no valid list is found.
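        Examples (illustrative):
            [{"a": 1}, {"a": 2}]              -> ".[]"
            {"items": [{"a": 1}, {"a": 2}]}   -> ".items[]"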
"""
with open(file_path, "r", encoding="utf-8") as f:
data = json.load(f)
if isinstance(data, list):
return ".[]"
elif isinstance(data, dict):
for key, value in data.items():
if isinstance(value, list):
return f".{key}[]"
raise ValueError("No list found in the top-level JSON object.")
else:
raise ValueError("Unsupported JSON structure: must be list or dict")
def load_excel_with_headers(self, file_path):
df = pd.read_excel(file_path)
docs = []
for i, row in df.iterrows():
text = "\n".join([f"{col}: {row[col]}" for col in df.columns])
metadata = {"source": file_path, "row_index": i}
docs.append(Document(page_content=text, metadata=metadata))
return docs
def load_and_chunk_docs(self, file_path: str) -> List[dict]:
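        """
        Load a file with the loader matching its extension, tag each Document with
        its source file name, and return the chunks produced by DocumentChunker.
        Unsupported extensions return an empty list.
        """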
        ext = os.path.splitext(file_path)[1].lower()  # lowercase so ".PDF" matches the ".pdf" branch
if ext == ".pdf":
docs = PDFPlumberLoader(file_path).load()
elif ext == ".docx":
docs = UnstructuredWordDocumentLoader(file_path).load()
elif ext == ".xlsx":
#docs = UnstructuredExcelLoader(file_path).load()
docs = self.load_excel_with_headers(file_path)
elif ext == ".pptx":
docs = UnstructuredPowerPointLoader(file_path).load()
elif ext == ".txt":
docs = TextLoader(file_path, encoding="utf-8").load()
elif ext == ".csv":
docs = CSVLoader(file_path).load()
elif ext == ".json":
docs = JSONLoader(file_path, jq_schema=self.get_jq_schema(file_path), text_content=False).load()
else:
return []
# Add source metadata to each Document
for doc in docs:
doc.metadata["source"] = os.path.basename(file_path)
return self.chunker.split_documents(docs)
def hash_text(self, text: str) -> str:
return hashlib.md5(text.encode('utf-8')).hexdigest()
def insert_chunks(self, chunk_dicts: List[dict]):
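        """
        Embed and upsert chunks into Qdrant. Duplicate chunks (by MD5 of the
        normalized text) are skipped, and points are upserted in batches of
        Config.BATCH_SIZE. Each payload stores the normalized text plus a
        stopword-filtered 'tokenized_text' field for full-text matching.
        """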
seen_hashes = set()
all_points = []
texts = [self.normalizer.normalize_text(d["text"]) for d in chunk_dicts]
embeddings = self.embedder.embed_documents(texts)
for i, chunk in enumerate(chunk_dicts):
text = self.normalizer.normalize_text(chunk["text"])
chunk_hash = self.hash_text(text)
if chunk_hash in seen_hashes:
continue
seen_hashes.add(chunk_hash)
tokenized_text = self.tokenize_for_bm25(text)
all_points.append(
PointStruct(
id=chunk["id"],
vector=embeddings[i],
payload={
"text": text,
"tokenized_text": tokenized_text,
**chunk["metadata"]
}
)
)
for i in range(0, len(all_points), Config.BATCH_SIZE):
self.client.upsert(collection_name=self.collection_name, points=all_points[i:i + Config.BATCH_SIZE])
def search(self, query: str, top_k: int = Config.TOP_K) -> List[Document]:
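        """
        Hybrid retrieval: a keyword pass over 'tokenized_text' (via scroll with a
        full-text filter) plus a dense vector pass via query_points, merged by chunk
        text and re-scored as ALPHA * bm25_score + (1 - ALPHA) * vector_score.
        Returns LangChain Documents sorted by the combined score.
        """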
query = self.normalizer.normalize_text(query)
query_embedding = self.embedder.embed_query(query)
query_tokens = self.tokenize_for_bm25(query).split()
# print(f"\n🔍 Query: {query}")
# print(f"🔑 Query Tokens: {query_tokens}")
# BM25 Search
bm25_results = self.client.scroll(
collection_name=self.collection_name,
scroll_filter=Filter(
should=[
FieldCondition(
key="tokenized_text",
match=MatchText(text=token)
) for token in query_tokens
]
),
limit=top_k
)[0]
bm25_dict = {
pt.payload.get("text", ""): {
"source": "BM25",
"bm25_score": getattr(pt, "score", 0.0), # Handle missing scores
"vector_score": 0.0,
"metadata": pt.payload or {}
}
for pt in bm25_results
}
# print(f"\n### BM25 Results ({len(bm25_dict)}):")
# for i, (text, info) in enumerate(bm25_dict.items(), 1):
# print(f"{i}. {text[:100]}... | BM25 Score: {info['bm25_score']:.4f}")
# Vector Search (using query_points instead of deprecated search)
vector_results: List[ScoredPoint] = self.client.query_points(
collection_name=self.collection_name,
query=query_embedding,
limit=top_k,
with_payload=True,
with_vectors=False
).points
vector_dict = {
pt.payload.get("text", ""): {
"source": "Vector",
"bm25_score": 0.0,
"vector_score": getattr(pt, "score", 0.0), # Handle missing scores
"metadata": pt.payload or {}
}
for pt in vector_results
}
# print(f"\n### Vector Results ({len(vector_dict)}):")
# for i, (text, info) in enumerate(vector_dict.items(), 1):
# print(f"{i}. {text[:100]}... | Vector Score: {info['vector_score']:.4f}")
# Merge & Deduplicate Results
combined_results: Dict[str, Dict] = {}
for text, info in bm25_dict.items():
combined_results[text] = {
"source": info["source"],
"bm25_score": info["bm25_score"],
"vector_score": 0.0,
"metadata": info["metadata"]
}
for text, info in vector_dict.items():
if text in combined_results:
combined_results[text]["source"] = "Hybrid"
combined_results[text]["vector_score"] = info["vector_score"]
else:
combined_results[text] = {
"source": info["source"],
"bm25_score": 0.0,
"vector_score": info["vector_score"],
"metadata": info["metadata"]
}
# Compute Hybrid Score
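        # Worked example (illustrative, assuming Config.ALPHA = 0.5): a chunk found only by the
        # vector pass with vector_score = 0.8 gets final_score = 0.5 * 0.0 + 0.5 * 0.8 = 0.4.
        # Keyword-only matches keep bm25_score = 0.0 here, since scroll() does not return scores.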
for text in combined_results:
combined_results[text]["final_score"] = (
Config.ALPHA * combined_results[text]["bm25_score"]
+ (1 - Config.ALPHA) * combined_results[text]["vector_score"]
)
# Sort and return as LangChain Documents
sorted_results = sorted(combined_results.items(), key=lambda x: x[1]["final_score"], reverse=True)
# print(f"\n### Combined Results (Sorted by Final Score):")
# for i, (text, info) in enumerate(sorted_results, 1):
# print(f"{i}. {text[:100]}... | Final Score: {info['final_score']:.4f} | "
# f"BM25: {info['bm25_score']:.4f} | Vector: {info['vector_score']:.4f} | Source: {info['source']}")
return [
Document(
page_content=text,
metadata={
**info["metadata"],
"source": info["source"],
"bm25_score": info["bm25_score"],
"vector_score": info["vector_score"],
"final_score": info["final_score"]
}
)
for text, info in sorted_results # Don't Remove zero-score docs
#for text, info in sorted_results if info["final_score"] > 0 # Remove zero-score docs
]
def export_all_documents(self, output_dir: str = Config.STORED_CHUNK_DIR):
"""Export all inserted documents from Qdrant grouped by source."""
os.makedirs(output_dir, exist_ok=True)
file_text_map = {}
next_offset = None
while True:
points, next_offset = self.client.scroll(
collection_name=self.collection_name,
with_payload=True,
with_vectors=False,
limit=1000, # You can tune this batch size
offset=next_offset
)
for pt in points:
payload = pt.payload or {}
source = payload.get("source", "unknown_file.txt")
text = payload.get("text", "")
if not text.strip():
continue
file_text_map.setdefault(source, []).append((text, payload.get("chunk_order", 0)))
if next_offset is None:
break
# Write all collected texts grouped by file name
for source, chunks in file_text_map.items():
file_name = os.path.splitext(os.path.basename(source))[0]
file_path = os.path.join(output_dir, f"{file_name}.txt")
# Sort by chunk_order
sorted_chunks = sorted(chunks, key=lambda x: x[1])
with open(file_path, "w", encoding="utf-8") as f:
for chunk_text, chunk_order in sorted_chunks:
f.write(f"### Chunk Order: {chunk_order}\n")
f.write(chunk_text.strip() + "\n\n---\n\n")
print(f"### Exported {len(file_text_map)} source files to '{output_dir}'")
def clear_qdrant_db(self):
if self.client.collection_exists(self.collection_name):
self.client.delete_collection(collection_name=self.collection_name) # deletes full collection
print("### All data is removed")
if __name__ == "__main__":
qdrant_db_client = QdrantDBClient()
data_dir = Config.DATA_DIR
for filename in os.listdir(data_dir):
file_path = os.path.join(data_dir, filename)
ext = os.path.splitext(filename)[1].lower()
if os.path.isfile(file_path) and ext in Config.FILE_EXTENSIONS:
print(f"📄 Processing: {filename}")
chunk_dicts = qdrant_db_client.load_and_chunk_docs(file_path)
qdrant_db_client.insert_chunks(chunk_dicts)
print(f"### Total documents in collection: {qdrant_db_client.client.count(qdrant_db_client.collection_name)}")
qdrant_db_client.export_all_documents()
#qdrant_db_client.clear_qdrant_db()
query = "What is the full form of K12HSN?"
docs = qdrant_db_client.search(query)
print(f"\n### Retrieved {len(docs)} results:")
for i, doc in enumerate(docs, 1):
print(f"\n{i}. {doc.page_content[:]}...")