# Hybrid retrieval: BM25 + dense embeddings, with optional cross-encoder reranking.
| from __future__ import annotations | |
| import json | |
| import re | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| import numpy as np | |
| from rank_bm25 import BM25Okapi | |
| from sentence_transformers import SentenceTransformer, CrossEncoder | |
| from rag.config import SETTINGS | |
_WORD = re.compile(r"[A-Za-z0-9']+")


def tokenize(text: str) -> list[str]:
    """Split *text* into lowercase word tokens (runs of alphanumerics/apostrophes).

    ``None`` or an empty string yields an empty list.
    """
    normalized = (text or "").lower()
    return _WORD.findall(normalized)
@dataclass
class ChunkRec:
    """A retrieved chunk together with its provenance and ranking score.

    BUG FIX: the ``@dataclass`` decorator was missing, so the positional
    constructor calls (``ChunkRec(c["chunk_id"], ...)``) elsewhere in this
    module raised ``TypeError``. Kept mutable (not frozen) because
    ``Retriever.retrieve`` overwrites ``score``/``why`` after reranking.
    """

    chunk_id: int   # stable id of the chunk within the corpus
    source_id: str  # id of the originating document
    text: str       # raw chunk text
    score: float    # ranking score (scale depends on `why`)
    why: str        # stage that produced the score: "bm25", "dense", "rerank"
class Retriever:
    """Hybrid retriever: BM25 lexical search + dense cosine search,
    merged by best-score de-duplication, with optional cross-encoder
    reranking of the top candidates.
    """

    def __init__(self) -> None:
        art = Path(SETTINGS.artifacts_dir)
        self.chunks = self._load_chunks(art / SETTINGS.chunks_jsonl)
        self.emb = np.load(art / SETTINGS.embeddings_npy)
        # BM25 index over word tokens of every chunk.
        tokenized = [tokenize(c["text"]) for c in self.chunks]
        self.bm25 = BM25Okapi(tokenized)
        # Dense encoder used to embed incoming queries.
        self.embedder = SentenceTransformer(SETTINGS.embed_model)
        # Cross-encoder is loaded lazily on first use (see _get_reranker).
        self._reranker: CrossEncoder | None = None

    def _load_chunks(self, path: Path) -> list[dict]:
        """Read one JSON object per line from *path*.

        BUG FIX: the original definition was missing ``self``, so the
        ``self._load_chunks(...)`` call in ``__init__`` raised TypeError.
        Blank lines (e.g. a trailing newline) are skipped instead of
        crashing ``json.loads``.
        """
        out: list[dict] = []
        with path.open("r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    out.append(json.loads(line))
        return out

    def _bm25_search(self, query: str, k: int) -> list[ChunkRec]:
        """Return the top-*k* chunks by BM25 lexical score."""
        scores = self.bm25.get_scores(tokenize(query))
        idx = np.argsort(scores)[::-1][:k]
        out: list[ChunkRec] = []
        for i in idx:
            c = self.chunks[int(i)]
            out.append(
                ChunkRec(
                    c["chunk_id"],
                    c["source_id"],
                    c["text"],
                    float(scores[int(i)]),
                    "bm25",
                )
            )
        return out

    def _dense_search(self, query: str, k: int) -> list[ChunkRec]:
        """Return the top-*k* chunks by dense cosine similarity."""
        q = self.embedder.encode([query], normalize_embeddings=True)
        q = np.asarray(q, dtype=np.float32)[0]
        # Dot product equals cosine similarity because the query is
        # L2-normalized here (corpus embeddings presumably normalized at
        # build time — confirm against the indexing pipeline).
        scores = self.emb @ q
        idx = np.argsort(scores)[::-1][:k]
        out: list[ChunkRec] = []
        for i in idx:
            c = self.chunks[int(i)]
            out.append(
                ChunkRec(
                    c["chunk_id"],
                    c["source_id"],
                    c["text"],
                    float(scores[int(i)]),
                    "dense",
                )
            )
        return out

    def _get_reranker(self) -> CrossEncoder:
        """Lazily construct and cache the cross-encoder reranker."""
        if self._reranker is None:
            self._reranker = CrossEncoder(SETTINGS.rerank_model)
        return self._reranker

    def retrieve(
        self,
        query: str,
        use_bm25: bool = True,
        use_dense: bool = True,
        use_rerank: bool = False,
    ) -> list[ChunkRec]:
        """Retrieve up to ``SETTINGS.top_k_final`` chunks for *query*.

        Candidates from the enabled searchers are de-duplicated by
        ``chunk_id`` keeping the best score per chunk. If *use_rerank* is
        true, the top ``SETTINGS.rerank_top_n`` merged candidates are
        re-scored by the cross-encoder (their ``score``/``why`` fields are
        overwritten in place).
        """
        cands: list[ChunkRec] = []
        if use_bm25:
            cands.extend(self._bm25_search(query, SETTINGS.top_k_bm25))
        if use_dense:
            cands.extend(self._dense_search(query, SETTINGS.top_k_dense))
        # De-dup by chunk_id, keeping the best score per chunk. NOTE:
        # BM25 and cosine scores live on different scales, so cross-source
        # comparison is heuristic; reranking fixes this for the top-n.
        best: dict[int, ChunkRec] = {}
        for r in cands:
            prev = best.get(r.chunk_id)
            if prev is None or r.score > prev.score:
                best[r.chunk_id] = r
        merged = list(best.values())
        merged.sort(key=lambda x: x.score, reverse=True)
        if use_rerank and merged:
            reranker = self._get_reranker()
            top = merged[: SETTINGS.rerank_top_n]
            pairs = [(query, r.text) for r in top]
            rr_scores = reranker.predict(pairs)
            for r, s in zip(top, rr_scores):
                r.score = float(s)
                r.why = "rerank"
            top.sort(key=lambda x: x.score, reverse=True)
            return top[: SETTINGS.top_k_final]
        return merged[: SETTINGS.top_k_final]