# llm-chat-project/rag/retrieve.py
# Author: DunasAnastasiia
# Initial commit (Xet) — 7c2e31a
from __future__ import annotations
import json
import re
from dataclasses import dataclass
from pathlib import Path
import numpy as np
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder
from rag.config import SETTINGS
_WORD = re.compile(r"[A-Za-z0-9']+")


def tokenize(text: str) -> list[str]:
    """Lowercase *text* and split it into alphanumeric/apostrophe tokens.

    ``None`` and the empty string both yield an empty list.
    """
    if not text:
        return []
    return _WORD.findall(text.lower())
@dataclass
class ChunkRec:
    """One retrieval candidate: a text chunk with its score and provenance."""
    chunk_id: int  # id of the chunk as stored in the chunks artifact
    source_id: str  # identifier of the document the chunk came from
    text: str  # raw chunk text
    score: float  # score from whichever stage produced this record (see `why`)
    why: str  # "bm25", "dense", "rerank"
class Retriever:
    """Hybrid BM25 + dense retriever with optional cross-encoder reranking.

    Loads the chunk and embedding artifacts produced by the indexing
    pipeline from ``SETTINGS.artifacts_dir`` and answers queries via
    :meth:`retrieve`.
    """

    def __init__(self) -> None:
        art = Path(SETTINGS.artifacts_dir)
        self.chunks = self._load_chunks(art / SETTINGS.chunks_jsonl)
        # Row i of the matrix is the embedding of self.chunks[i].
        # NOTE(review): dense search assumes rows are L2-normalized —
        # confirm against the indexing pipeline.
        self.emb = np.load(art / SETTINGS.embeddings_npy)
        # BM25 index built with the same tokenizer used for queries.
        tokenized = [tokenize(c["text"]) for c in self.chunks]
        self.bm25 = BM25Okapi(tokenized)
        # Dense query encoder.
        self.embedder = SentenceTransformer(SETTINGS.embed_model)
        # Cross-encoder reranker, loaded lazily on first rerank request.
        self._reranker: CrossEncoder | None = None

    @staticmethod
    def _load_chunks(path: Path) -> list[dict]:
        """Read one JSON object per non-empty line of *path* (JSONL).

        Blank lines (e.g. a trailing newline) are skipped instead of
        raising ``json.JSONDecodeError``.
        """
        out: list[dict] = []
        with path.open("r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    out.append(json.loads(line))
        return out

    def _top_k(self, scores, k: int, why: str) -> list[ChunkRec]:
        """Build ChunkRecs for the k highest-scoring chunk indices."""
        idx = np.argsort(scores)[::-1][:k]
        out: list[ChunkRec] = []
        for i in idx:
            c = self.chunks[int(i)]
            out.append(
                ChunkRec(
                    c["chunk_id"],
                    c["source_id"],
                    c["text"],
                    float(scores[int(i)]),
                    why,
                )
            )
        return out

    def _bm25_search(self, query: str, k: int) -> list[ChunkRec]:
        """Top-k chunks by BM25 keyword score."""
        return self._top_k(self.bm25.get_scores(tokenize(query)), k, "bm25")

    def _dense_search(self, query: str, k: int) -> list[ChunkRec]:
        """Top-k chunks by dot product of query and chunk embeddings
        (cosine similarity, given normalized embeddings)."""
        q = self.embedder.encode([query], normalize_embeddings=True)
        q = np.asarray(q, dtype=np.float32)[0]
        return self._top_k(self.emb @ q, k, "dense")

    def _get_reranker(self) -> CrossEncoder:
        """Return the cross-encoder, instantiating it on first call."""
        if self._reranker is None:
            self._reranker = CrossEncoder(SETTINGS.rerank_model)
        return self._reranker

    def retrieve(
        self,
        query: str,
        use_bm25: bool = True,
        use_dense: bool = True,
        use_rerank: bool = False,
    ) -> list[ChunkRec]:
        """Retrieve up to ``SETTINGS.top_k_final`` chunks for *query*.

        Candidates from the enabled searches are merged, de-duplicated by
        chunk_id (keeping the best score per chunk), sorted by score, and
        optionally reranked with a cross-encoder, which overwrites
        ``score`` and ``why`` in place on the top
        ``SETTINGS.rerank_top_n`` records.

        NOTE(review): BM25 scores and cosine similarities live on
        different scales, so the merged sort favors whichever source
        yields larger raw numbers; consider rank fusion if this matters.
        """
        cands: list[ChunkRec] = []
        if use_bm25:
            cands.extend(self._bm25_search(query, SETTINGS.top_k_bm25))
        if use_dense:
            cands.extend(self._dense_search(query, SETTINGS.top_k_dense))
        # De-duplicate by chunk_id, keeping the best score per chunk.
        best: dict[int, ChunkRec] = {}
        for r in cands:
            prev = best.get(r.chunk_id)
            if prev is None or r.score > prev.score:
                best[r.chunk_id] = r
        merged = sorted(best.values(), key=lambda x: x.score, reverse=True)
        if use_rerank and merged:
            reranker = self._get_reranker()
            top = merged[: SETTINGS.rerank_top_n]
            rr_scores = reranker.predict([(query, r.text) for r in top])
            for r, s in zip(top, rr_scores):
                r.score = float(s)
                r.why = "rerank"
            top.sort(key=lambda x: x.score, reverse=True)
            return top[: SETTINGS.top_k_final]
        return merged[: SETTINGS.top_k_final]