File size: 2,770 Bytes
9cc7f8d
 
 
 
77d7fca
 
 
9cc7f8d
 
 
 
77d7fca
 
bb05158
 
9cc7f8d
bb05158
77d7fca
 
 
 
 
 
bb05158
77d7fca
 
9cc7f8d
bb05158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77d7fca
 
 
9cc7f8d
77d7fca
9cc7f8d
 
 
 
 
 
 
 
 
 
bb05158
77d7fca
 
 
bb05158
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
from dotenv import load_dotenv
from qdrant_client import QdrantClient
from qdrant_client import models
from fastembed import TextEmbedding, SparseTextEmbedding
from fastembed.rerank.cross_encoder import TextCrossEncoder

# Load variables from a .env file (when one is present) into the process
# environment so the lookups below can see them.
load_dotenv()

# Qdrant connection settings, read once at import time. Each value is None
# when the corresponding environment variable is not set.
qdrant_url = os.environ.get("QDRANT_URL")
qdrant_api_key = os.environ.get("QDRANT_API_KEY")


class Retriever:
    """Hybrid (dense + sparse) retriever over a Qdrant collection.

    A query is embedded with a dense model and a sparse model, both searches
    are restricted to the points whose ``user_id`` payload field matches the
    caller, the two candidate lists are fused with Reciprocal Rank Fusion and
    the fused candidates are finally re-scored by a cross-encoder.
    """

    def __init__(self, collection_name="pdf_rag"):
        """Create the Qdrant client and load the embedding/reranking models.

        Args:
            collection_name: Name of the Qdrant collection to query.
        """
        self.collection_name = collection_name
        # Connection settings come from the module-level environment lookups.
        self.client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)

        self.dense_model = TextEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
        self.sparse_model = SparseTextEmbedding(model_name="Qdrant/bm25")

        self.reranker = TextCrossEncoder(model_name="Xenova/ms-marco-MiniLM-L-6-v2")

    def retrieve(self, query: str, user_id: str, top_k: int = 5, candidate_limit: int = 20):
        """Return the best-matching chunks for ``query`` owned by ``user_id``.

        Args:
            query: Free-text search query.
            user_id: Only points whose ``user_id`` payload equals this value
                are considered.
            top_k: Maximum number of reranked results to return.
            candidate_limit: Number of candidates fetched by each prefetch
                (dense and sparse) and kept after fusion, i.e. the number of
                documents handed to the reranker.

        Returns:
            A list of at most ``top_k`` dicts sorted by descending
            ``rerank_score``. Each dict has the keys ``text``, ``source``,
            ``pages``, ``section``, ``original_qdrant_score`` and
            ``rerank_score``. The list is empty when nothing matched or when
            ``top_k`` is not positive.
        """
        if top_k <= 0:
            # Nothing can be returned, so skip the embedding and network work.
            return []

        dense_query_vector = list(self.dense_model.embed([query]))[0]

        sparse_query = list(self.sparse_model.embed([query]))[0]
        sparse_query_vector = models.SparseVector(
            indices=sparse_query.indices,
            values=sparse_query.values,
        )

        # The same filter is applied to both prefetches so that no other
        # user's points can enter the fused candidate list.
        user_filter = models.Filter(
            must=[
                models.FieldCondition(
                    key="user_id",
                    match=models.MatchValue(value=user_id),
                )
            ]
        )

        results = self.client.query_points(
            collection_name=self.collection_name,
            prefetch=[
                models.Prefetch(
                    query=dense_query_vector,
                    using="dense",
                    limit=candidate_limit,
                    filter=user_filter,
                ),
                models.Prefetch(
                    query=sparse_query_vector,
                    using="sparse",
                    limit=candidate_limit,
                    filter=user_filter,
                ),
            ],
            query=models.FusionQuery(fusion=models.Fusion.RRF),
            limit=candidate_limit,
        )

        points = results.points
        if not points:
            # No candidates: there is nothing to rerank.
            return []

        texts = self._candidate_texts(points)
        rerank_scores = list(self.reranker.rerank(query, texts))

        return self._rank_points(points, rerank_scores, top_k)

    @staticmethod
    def _candidate_texts(points):
        """Return the ``text`` payload of each point as a string for reranking.

        A missing payload, a missing ``text`` key or a ``None`` value all
        become the empty string so the reranker always receives strings.
        """
        return [(point.payload or {}).get("text") or "" for point in points]

    @staticmethod
    def _rank_points(points, scores, top_k):
        """Pair points with rerank scores and return the ``top_k`` best.

        Points with equal scores keep their incoming (fused) order because
        the sort is stable.
        """
        ranked = []
        for point, score in zip(points, scores):
            # ``payload`` is optional on a point; treat a missing one as empty.
            payload = point.payload or {}
            ranked.append({
                "text": payload.get("text"),
                "source": payload.get("source"),
                "pages": payload.get("pages"),
                "section": payload.get("section"),
                "original_qdrant_score": point.score,
                "rerank_score": float(score),
            })

        ranked.sort(key=lambda item: item["rerank_score"], reverse=True)
        return ranked[:max(top_k, 0)]