from __future__ import annotations

import json
import re
from dataclasses import dataclass
from pathlib import Path

import numpy as np
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder, SentenceTransformer

from rag.config import SETTINGS

# Word tokens: runs of ASCII letters/digits/apostrophes ("don't" stays one token).
# Compiled once at module load so tokenize() is cheap in loops.
_WORD = re.compile(r"[A-Za-z0-9']+")


def tokenize(text: str) -> list[str]:
    """Lowercase *text* and split it into word tokens.

    A falsy value (``None`` or ``""``) yields an empty list instead of raising,
    so callers may pass optional text straight through.
    """
    return _WORD.findall((text or "").lower())


@dataclass
class ChunkRec:
    """One retrieved chunk together with its score and retrieval provenance."""

    chunk_id: int
    source_id: str
    text: str
    score: float
    # Stage that produced `score`: "bm25", "dense", or "rerank".
    why: str


class Retriever:
    """Hybrid BM25 + dense retriever with an optional cross-encoder rerank stage.

    Chunk texts and their precomputed embeddings are loaded once from
    ``SETTINGS.artifacts_dir`` at construction time.
    """

    def __init__(self) -> None:
        art = Path(SETTINGS.artifacts_dir)
        self.chunks = self._load_chunks(art / SETTINGS.chunks_jsonl)
        # Row i of `emb` is the embedding of chunks[i]; rows are presumably
        # L2-normalized so a dot product equals cosine similarity — see
        # _dense_search, which relies on this.
        self.emb = np.load(art / SETTINGS.embeddings_npy)

        # Sparse index: BM25 over the tokenized chunk texts.
        tokenized = [tokenize(c["text"]) for c in self.chunks]
        self.bm25 = BM25Okapi(tokenized)

        # Dense query encoder (must match the model used to build `emb`).
        self.embedder = SentenceTransformer(SETTINGS.embed_model)

        # Cross-encoder reranker is loaded lazily on first use because it is
        # expensive to initialize and only needed when use_rerank=True.
        self._reranker: CrossEncoder | None = None

    @staticmethod
    def _load_chunks(path: Path) -> list[dict]:
        """Parse a JSONL file into a list of dicts, skipping blank lines.

        Blank or whitespace-only lines (e.g. a stray trailing newline) are
        ignored rather than crashing ``json.loads``.
        """
        out: list[dict] = []
        with path.open("r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:  # json.loads("") would raise JSONDecodeError
                    out.append(json.loads(line))
        return out

    def _bm25_search(self, query: str, k: int) -> list[ChunkRec]:
        """Return the top-*k* chunks ranked by BM25 score against *query*."""
        scores = self.bm25.get_scores(tokenize(query))
        # Descending sort of all scores, then take the first k indices.
        idx = np.argsort(scores)[::-1][:k]
        out: list[ChunkRec] = []
        for i in idx:
            c = self.chunks[int(i)]
            out.append(
                ChunkRec(
                    c["chunk_id"],
                    c["source_id"],
                    c["text"],
                    float(scores[int(i)]),
                    "bm25",
                )
            )
        return out

    def _dense_search(self, query: str, k: int) -> list[ChunkRec]:
        """Return the top-*k* chunks by cosine similarity to the encoded query."""
        q = self.embedder.encode([query], normalize_embeddings=True)
        q = np.asarray(q, dtype=np.float32)[0]
        # Dot product == cosine similarity because the query is normalized
        # above and the stored embeddings are expected to be normalized too.
        scores = self.emb @ q
        idx = np.argsort(scores)[::-1][:k]
        out: list[ChunkRec] = []
        for i in idx:
            c = self.chunks[int(i)]
            out.append(
                ChunkRec(
                    c["chunk_id"],
                    c["source_id"],
                    c["text"],
                    float(scores[int(i)]),
                    "dense",
                )
            )
        return out

    def _get_reranker(self) -> CrossEncoder:
        """Load the cross-encoder on first call and cache it on the instance."""
        if self._reranker is None:
            self._reranker = CrossEncoder(SETTINGS.rerank_model)
        return self._reranker

    def retrieve(
        self,
        query: str,
        use_bm25: bool = True,
        use_dense: bool = True,
        use_rerank: bool = False,
    ) -> list[ChunkRec]:
        """Retrieve the best chunks for *query*.

        Candidates from the enabled first-stage retrievers are pooled,
        de-duplicated by ``chunk_id`` (keeping the best score per chunk), and
        sorted by score. When *use_rerank* is true, the top
        ``SETTINGS.rerank_top_n`` candidates are re-scored by the
        cross-encoder and re-sorted. At most ``SETTINGS.top_k_final`` records
        are returned; the list is empty if both first-stage retrievers are
        disabled.

        NOTE(review): BM25 and cosine scores live on different scales, so the
        merged pre-rerank ordering interleaves them heuristically; reranking
        puts everything on one scale.
        """
        cands: list[ChunkRec] = []
        if use_bm25:
            cands.extend(self._bm25_search(query, SETTINGS.top_k_bm25))
        if use_dense:
            cands.extend(self._dense_search(query, SETTINGS.top_k_dense))

        # De-duplicate by chunk_id, keeping the best-scoring record for each.
        best: dict[int, ChunkRec] = {}
        for r in cands:
            prev = best.get(r.chunk_id)
            if prev is None or r.score > prev.score:
                best[r.chunk_id] = r
        merged = sorted(best.values(), key=lambda x: x.score, reverse=True)

        if use_rerank and merged:
            reranker = self._get_reranker()
            top = merged[: SETTINGS.rerank_top_n]
            pairs = [(query, r.text) for r in top]
            rr_scores = reranker.predict(pairs)
            # Overwrite first-stage scores in place; the records are created
            # fresh per call, so this cannot leak state across queries.
            for r, s in zip(top, rr_scores):
                r.score = float(s)
                r.why = "rerank"
            top.sort(key=lambda x: x.score, reverse=True)
            return top[: SETTINGS.top_k_final]

        return merged[: SETTINGS.top_k_final]