from __future__ import annotations
import json
import re
from dataclasses import dataclass
from pathlib import Path
import numpy as np
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder
from rag.config import SETTINGS
_WORD = re.compile(r"[A-Za-z0-9']+")
def tokenize(text: str) -> list[str]:
    return _WORD.findall((text or "").lower())
@dataclass
class ChunkRec:
    chunk_id: int
    source_id: str
    text: str
    score: float
    why: str  # "bm25", "dense", "rerank"
class Retriever:
    def __init__(self) -> None:
        art = Path(SETTINGS.artifacts_dir)
        self.chunks = self._load_chunks(art / SETTINGS.chunks_jsonl)
        self.emb = np.load(art / SETTINGS.embeddings_npy)
        # BM25 index over the tokenized chunk texts
        tokenized = [tokenize(c["text"]) for c in self.chunks]
        self.bm25 = BM25Okapi(tokenized)
        # Dense encoder for query embeddings
        self.embedder = SentenceTransformer(SETTINGS.embed_model)
        # Cross-encoder reranker, loaded lazily on first use
        self._reranker: CrossEncoder | None = None
    @staticmethod
    def _load_chunks(path: Path) -> list[dict]:
        out = []
        with path.open("r", encoding="utf-8") as f:
            for line in f:
                out.append(json.loads(line))
        return out
    def _bm25_search(self, query: str, k: int) -> list[ChunkRec]:
        scores = self.bm25.get_scores(tokenize(query))
        idx = np.argsort(scores)[::-1][:k]
        out: list[ChunkRec] = []
        for i in idx:
            c = self.chunks[int(i)]
            out.append(
                ChunkRec(
                    c["chunk_id"],
                    c["source_id"],
                    c["text"],
                    float(scores[int(i)]),
                    "bm25",
                )
            )
        return out
    def _dense_search(self, query: str, k: int) -> list[ChunkRec]:
        q = self.embedder.encode([query], normalize_embeddings=True)
        q = np.asarray(q, dtype=np.float32)[0]
        # dot product equals cosine similarity because the embeddings are normalized
        scores = self.emb @ q
        idx = np.argsort(scores)[::-1][:k]
        out: list[ChunkRec] = []
        for i in idx:
            c = self.chunks[int(i)]
            out.append(
                ChunkRec(
                    c["chunk_id"],
                    c["source_id"],
                    c["text"],
                    float(scores[int(i)]),
                    "dense",
                )
            )
        return out
    def _get_reranker(self) -> CrossEncoder:
        if self._reranker is None:
            self._reranker = CrossEncoder(SETTINGS.rerank_model)
        return self._reranker
    def retrieve(
        self,
        query: str,
        use_bm25: bool = True,
        use_dense: bool = True,
        use_rerank: bool = False,
    ) -> list[ChunkRec]:
        cands: list[ChunkRec] = []
        if use_bm25:
            cands.extend(self._bm25_search(query, SETTINGS.top_k_bm25))
        if use_dense:
            cands.extend(self._dense_search(query, SETTINGS.top_k_dense))
        # de-dup by chunk_id, keeping the best score per chunk
        best: dict[int, ChunkRec] = {}
        for r in cands:
            prev = best.get(r.chunk_id)
            if prev is None or r.score > prev.score:
                best[r.chunk_id] = r
        merged = list(best.values())
        merged.sort(key=lambda x: x.score, reverse=True)
        if use_rerank and merged:
            # rescore only the top candidates with the cross-encoder
            reranker = self._get_reranker()
            top = merged[: SETTINGS.rerank_top_n]
            pairs = [(query, r.text) for r in top]
            rr_scores = reranker.predict(pairs)
            for r, s in zip(top, rr_scores):
                r.score = float(s)
                r.why = "rerank"
            top.sort(key=lambda x: x.score, reverse=True)
            return top[: SETTINGS.top_k_final]
        return merged[: SETTINGS.top_k_final]
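# Usage sketch, assuming the index artifacts referenced by SETTINGS (the
# chunks JSONL and the matching embeddings .npy) have already been built:
#
#   retriever = Retriever()
#   hits = retriever.retrieve("example query", use_rerank=True)
#   for h in hits:
#       print(h.why, round(h.score, 3), h.source_id, h.text[:80])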