# pdf_rag/src/retrieval.py
# Last commit: bb05158 "Final Formatting" (author: LightRT)
import os
from dotenv import load_dotenv
from qdrant_client import QdrantClient
from qdrant_client import models
from fastembed import TextEmbedding, SparseTextEmbedding
from fastembed.rerank.cross_encoder import TextCrossEncoder
# Pull variables from a local .env file (if any) into the process environment
# before reading the Qdrant connection settings below.
load_dotenv()

# Qdrant connection settings; each is None when the variable is not set.
qdrant_url = os.environ.get("QDRANT_URL")
qdrant_api_key = os.environ.get("QDRANT_API_KEY")
class Retriever:
    """Hybrid dense + sparse retriever over a Qdrant collection with cross-encoder reranking.

    Candidates are fetched from the collection's ``"dense"`` and ``"sparse"``
    named vectors, restricted to points whose ``user_id`` payload matches the
    caller, fused with Reciprocal Rank Fusion (RRF), and finally re-scored by a
    cross-encoder.
    """

    def __init__(
        self,
        collection_name="pdf_rag",
        dense_model_name="sentence-transformers/all-MiniLM-L6-v2",
        sparse_model_name="Qdrant/bm25",
        reranker_model_name="Xenova/ms-marco-MiniLM-L-6-v2",
    ):
        """Connect to Qdrant and load the embedding and reranking models.

        Args:
            collection_name: Qdrant collection to search.
            dense_model_name: fastembed dense model for the ``"dense"`` vector;
                should be the same model that produced the stored vectors.
            sparse_model_name: fastembed sparse model for the ``"sparse"``
                vector; should be the same model that produced the stored vectors.
            reranker_model_name: fastembed cross-encoder used for reranking.
        """
        self.collection_name = collection_name
        # URL / API key come from the module-level QDRANT_URL / QDRANT_API_KEY settings.
        self.client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
        self.dense_model = TextEmbedding(model_name=dense_model_name)
        self.sparse_model = SparseTextEmbedding(model_name=sparse_model_name)
        self.reranker = TextCrossEncoder(model_name=reranker_model_name)

    def retrieve(self, query: str, user_id: str, top_k: int = 5, candidate_limit: int = 20):
        """Return the ``top_k`` most relevant chunks for ``query`` belonging to ``user_id``.

        Args:
            query: Natural-language search query.
            user_id: Only points whose ``user_id`` payload equals this value
                are considered.
            top_k: Maximum number of reranked results to return.
            candidate_limit: Number of candidates taken from each prefetch
                branch and kept after RRF fusion; this is the pool the
                cross-encoder reranks.

        Returns:
            A list of at most ``top_k`` dicts with keys ``text``, ``source``,
            ``pages``, ``section``, ``original_qdrant_score`` and
            ``rerank_score``, sorted by ``rerank_score`` descending. Empty when
            no point matches.
        """
        dense_query_vector, sparse_query_vector = self._embed_query(query)

        # The same user filter is applied inside each prefetch so neither
        # branch can surface another user's documents.
        user_filter = models.Filter(
            must=[
                models.FieldCondition(
                    key="user_id",
                    match=models.MatchValue(value=user_id),
                )
            ]
        )

        results = self.client.query_points(
            collection_name=self.collection_name,
            prefetch=[
                models.Prefetch(
                    query=dense_query_vector,
                    using="dense",
                    limit=candidate_limit,
                    filter=user_filter,
                ),
                models.Prefetch(
                    query=sparse_query_vector,
                    using="sparse",
                    limit=candidate_limit,
                    filter=user_filter,
                ),
            ],
            query=models.FusionQuery(fusion=models.Fusion.RRF),
            limit=candidate_limit,
        )

        points = results.points
        if not points:
            # Nothing matched; don't run the cross-encoder on an empty batch.
            return []

        texts = [(point.payload or {}).get("text", "") for point in points]
        # rerank yields one score per document, in input order.
        scores = self.reranker.rerank(query, texts)
        return self._build_results(points, scores, top_k)

    def _embed_query(self, query):
        """Encode ``query`` into a dense vector and a Qdrant ``SparseVector``."""
        # query_embed is fastembed's query-side encoder. For Qdrant/bm25 it
        # differs from embed(), which applies document-side BM25 term-frequency
        # weighting; for the dense model it is equivalent to embed().
        dense_vector = next(iter(self.dense_model.query_embed(query)))
        sparse = next(iter(self.sparse_model.query_embed(query)))
        sparse_vector = models.SparseVector(
            indices=sparse.indices,
            values=sparse.values,
        )
        return dense_vector, sparse_vector

    @staticmethod
    def _build_results(points, scores, top_k):
        """Pair each point with its rerank score and return the ``top_k`` best as dicts.

        ``points`` and ``scores`` are consumed in lockstep. The sort is stable,
        so equal scores keep their fused (RRF) order. A missing payload is
        treated as empty, and a missing ``text`` is reported as ``""``, the same
        value that was sent to the reranker.
        """
        ranked = []
        for point, score in zip(points, scores):
            payload = point.payload or {}
            ranked.append({
                "text": payload.get("text", ""),
                "source": payload.get("source"),
                "pages": payload.get("pages"),
                "section": payload.get("section"),
                "original_qdrant_score": point.score,
                "rerank_score": float(score),
            })
        ranked.sort(key=lambda item: item["rerank_score"], reverse=True)
        return ranked[:top_k]