import os
from typing import Any, Dict, List

from dotenv import load_dotenv
from fastembed import SparseTextEmbedding, TextEmbedding
from fastembed.rerank.cross_encoder import TextCrossEncoder
from qdrant_client import QdrantClient, models

# Load variables from a local .env file before reading the Qdrant settings.
load_dotenv()

qdrant_api_key = os.getenv("QDRANT_API_KEY")
qdrant_url = os.getenv("QDRANT_URL")


class Retriever:
    """Hybrid (dense + sparse) retriever over a Qdrant collection.

    Candidates are fetched with a dense and a sparse prefetch, fused with
    reciprocal rank fusion (RRF), and then re-scored with a cross-encoder.
    """

    def __init__(self, collection_name="pdf_rag"):
        """Create the Qdrant client and load the embedding/reranking models.

        Args:
            collection_name: Name of the Qdrant collection to query.
        """
        self.collection_name = collection_name
        self.client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
        self.dense_model = TextEmbedding(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        self.sparse_model = SparseTextEmbedding(model_name="Qdrant/bm25")
        self.reranker = TextCrossEncoder(
            model_name="Xenova/ms-marco-MiniLM-L-6-v2"
        )

    def retrieve(
        self,
        query: str,
        user_id: str,
        prefetch_limit: int = 20,
        top_k: int = 5,
    ) -> List[Dict[str, Any]]:
        """Return the best-matching chunks for ``query`` owned by ``user_id``.

        Args:
            query: Natural-language search query.
            user_id: Only points whose ``user_id`` payload field equals this
                value are searched.
            prefetch_limit: Number of candidates requested from each prefetch
                (dense and sparse) and kept after fusion.
            top_k: Maximum number of reranked results to return.

        Returns:
            A list of at most ``top_k`` dicts sorted by ``rerank_score``
            (highest first). Each dict has the keys ``text``, ``source``,
            ``pages``, ``section``, ``original_qdrant_score`` and
            ``rerank_score``. The list is empty when nothing matched.

        Raises:
            ValueError: If ``prefetch_limit`` or ``top_k`` is less than 1.
        """
        if prefetch_limit < 1 or top_k < 1:
            raise ValueError("prefetch_limit and top_k must be positive integers")

        # Use the query-side embedding entry point: for BM25 the query is
        # weighted differently from a document, so embed() is the wrong call.
        dense_query_vector = next(iter(self.dense_model.query_embed([query])))
        sparse_query = next(iter(self.sparse_model.query_embed([query])))
        sparse_query_vector = models.SparseVector(
            indices=sparse_query.indices,
            values=sparse_query.values,
        )

        # Restrict both prefetches to the requesting user's points.
        user_filter = models.Filter(
            must=[
                models.FieldCondition(
                    key="user_id",
                    match=models.MatchValue(value=user_id),
                )
            ]
        )

        results = self.client.query_points(
            collection_name=self.collection_name,
            prefetch=[
                models.Prefetch(
                    query=dense_query_vector,
                    using="dense",
                    limit=prefetch_limit,
                    filter=user_filter,
                ),
                models.Prefetch(
                    query=sparse_query_vector,
                    using="sparse",
                    limit=prefetch_limit,
                    filter=user_filter,
                ),
            ],
            query=models.FusionQuery(fusion=models.Fusion.RRF),
            limit=prefetch_limit,
        )

        points = results.points
        if not points:
            # Nothing to rerank; skip the cross-encoder call entirely.
            return []

        # A point may carry no payload at all; treat that as an empty dict.
        payloads = [point.payload or {} for point in points]
        # The reranker needs strings, so a missing/None text becomes "".
        texts = [payload.get("text") or "" for payload in payloads]
        rerank_scores = self.reranker.rerank(query, texts)

        reranked_results = [
            {
                "text": payload.get("text"),
                "source": payload.get("source"),
                "pages": payload.get("pages"),
                "section": payload.get("section"),
                "original_qdrant_score": point.score,
                "rerank_score": float(score),
            }
            for point, payload, score in zip(points, payloads, rerank_scores)
        ]
        reranked_results.sort(key=lambda item: item["rerank_score"], reverse=True)
        return reranked_results[:top_k]