File size: 3,896 Bytes
7c2e31a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from __future__ import annotations

import json
import re
from dataclasses import dataclass
from pathlib import Path

import numpy as np
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder

from rag.config import SETTINGS

_WORD = re.compile(r"[A-Za-z0-9']+")


def tokenize(text: str) -> list[str]:
    """Lowercase *text* (treating None/empty as "") and split into word tokens.

    A token is a maximal run of ASCII letters, digits, and apostrophes.
    """
    lowered = text.lower() if text else ""
    return _WORD.findall(lowered)


@dataclass
class ChunkRec:
    """One retrieved chunk together with its retrieval score and provenance."""

    chunk_id: int  # chunk identifier from the artifacts; retrieve() de-duplicates on this
    source_id: str  # identifier of the document the chunk came from
    text: str  # raw chunk text (also what the reranker scores against the query)
    score: float  # BM25 score, cosine similarity, or cross-encoder score, per `why`
    why: str  # "bm25", "dense", "rerank"


class Retriever:
    """Hybrid BM25 + dense retriever with optional cross-encoder reranking.

    Loads pre-built chunk records (JSONL) and their embedding matrix from the
    artifacts directory, then serves lexical (BM25), dense (cosine), and fused
    queries via :meth:`retrieve`.
    """

    def __init__(self) -> None:
        art = Path(SETTINGS.artifacts_dir)
        self.chunks = self._load_chunks(art / SETTINGS.chunks_jsonl)
        # Row i of the embedding matrix corresponds to self.chunks[i].
        self.emb = np.load(art / SETTINGS.embeddings_npy)

        # Lexical index over the chunk texts.
        tokenized = [tokenize(c["text"]) for c in self.chunks]
        self.bm25 = BM25Okapi(tokenized)

        # Dense query encoder.
        # NOTE(review): assumes this is the same model that produced `emb`
        # — verify against the indexing pipeline.
        self.embedder = SentenceTransformer(SETTINGS.embed_model)

        # Cross-encoder reranker, loaded lazily on first rerank request.
        self._reranker: CrossEncoder | None = None

    @staticmethod
    def _load_chunks(path: Path) -> list[dict]:
        """Read one JSON object per line from *path*, skipping blank lines.

        Blank/whitespace-only lines (e.g. a trailing newline pair in a
        hand-edited JSONL file) previously crashed json.loads; they are
        now tolerated.
        """
        out: list[dict] = []
        with path.open("r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    out.append(json.loads(line))
        return out

    def _top_records(self, scores, k: int, why: str) -> list[ChunkRec]:
        """Build ChunkRecs for the *k* highest-scoring chunks.

        Shared by the BM25 and dense searches; *scores* is a 1-D array
        aligned with ``self.chunks``.
        """
        idx = np.argsort(scores)[::-1][:k]
        out: list[ChunkRec] = []
        for i in idx:
            i = int(i)
            c = self.chunks[i]
            out.append(
                ChunkRec(
                    c["chunk_id"],
                    c["source_id"],
                    c["text"],
                    float(scores[i]),
                    why,
                )
            )
        return out

    def _bm25_search(self, query: str, k: int) -> list[ChunkRec]:
        """Top-*k* chunks by BM25 lexical score."""
        scores = self.bm25.get_scores(tokenize(query))
        return self._top_records(scores, k, "bm25")

    def _dense_search(self, query: str, k: int) -> list[ChunkRec]:
        """Top-*k* chunks by cosine similarity against the query embedding."""
        q = self.embedder.encode([query], normalize_embeddings=True)
        q = np.asarray(q, dtype=np.float32)[0]
        # Dot product equals cosine similarity because both sides are
        # L2-normalized (query via normalize_embeddings=True; rows of `emb`
        # presumably normalized at index time — TODO confirm in the indexer).
        scores = self.emb @ q
        return self._top_records(scores, k, "dense")

    def _get_reranker(self) -> CrossEncoder:
        """Instantiate the cross-encoder on first use and cache it."""
        if self._reranker is None:
            self._reranker = CrossEncoder(SETTINGS.rerank_model)
        return self._reranker

    def retrieve(
        self,
        query: str,
        use_bm25: bool = True,
        use_dense: bool = True,
        use_rerank: bool = False,
    ) -> list[ChunkRec]:
        """Return up to ``SETTINGS.top_k_final`` chunks for *query*.

        Candidates from the enabled searches are de-duplicated by chunk_id
        (the higher score wins), sorted, and optionally reranked: the top
        ``SETTINGS.rerank_top_n`` merged candidates are re-scored by the
        cross-encoder, which overwrites each record's score and `why`.

        NOTE(review): BM25 and cosine scores are on different scales, so the
        merged ordering interleaves them as-is; enabling reranking puts the
        final candidates on a single scale.
        """
        cands: list[ChunkRec] = []
        if use_bm25:
            cands.extend(self._bm25_search(query, SETTINGS.top_k_bm25))
        if use_dense:
            cands.extend(self._dense_search(query, SETTINGS.top_k_dense))

        # De-dup by chunk_id keeping the best score per chunk.
        best: dict[int, ChunkRec] = {}
        for r in cands:
            prev = best.get(r.chunk_id)
            if prev is None or r.score > prev.score:
                best[r.chunk_id] = r
        merged = list(best.values())
        merged.sort(key=lambda x: x.score, reverse=True)

        if use_rerank and merged:
            reranker = self._get_reranker()
            top = merged[: SETTINGS.rerank_top_n]
            pairs = [(query, r.text) for r in top]
            rr_scores = reranker.predict(pairs)
            # Overwrite score/provenance in place on the retained records.
            for r, s in zip(top, rr_scores):
                r.score = float(s)
                r.why = "rerank"
            top.sort(key=lambda x: x.score, reverse=True)
            return top[: SETTINGS.top_k_final]

        return merged[: SETTINGS.top_k_final]