"""Retrieval: BM25 + Dense (FAISS) + RRF fusion + cross-encoder reranking."""
from __future__ import annotations
import os
import pickle
import re
from typing import Optional
import numpy as np
import pandas as pd
from src.citations import Citation
from src.config import (
BM25_FILE, CHUNKS_FILE, EMBED_MODEL, FAISS_FILE,
RRF_K, RERANK_MODEL, TOP_K_BM25, TOP_K_DENSE, TOP_K_FUSED, TOP_N_FINAL,
)
# ── tokeniser ────────────────────────────────────────────────────────────────
_TOKEN_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_.:]*|\d+")
_CAMEL_RE = re.compile(r"(?<!^)(?=[A-Z])")
_STOP = {"the","a","an","of","to","in","is","are","and","or","this","that","it","be"}
def _tokenize(text: str) -> list[str]:
tokens = _TOKEN_RE.findall(text)
out: list[str] = []
for t in tokens:
tl = t.lower()
if tl in _STOP:
continue
out.append(tl)
parts = _CAMEL_RE.split(t)
if len(parts) > 1:
out.extend(p.lower() for p in parts if p and p.lower() not in _STOP)
for sub in re.split(r"[._:]+", t):
if sub and sub.lower() not in _STOP and sub.lower() != tl:
out.append(sub.lower())
return out
# ── lazy singletons ──────────────────────────────────────────────────────────
_chunks_df: Optional[pd.DataFrame] = None
_bm25_index = None
_faiss_index = None
_embed_model = None
_rerank_model = None
def _load_chunks() -> pd.DataFrame:
    """Return the chunk table, reading CHUNKS_FILE (parquet) on the first call.

    Raises:
        FileNotFoundError: if CHUNKS_FILE has not been built yet.
    """
    global _chunks_df
    if _chunks_df is not None:
        return _chunks_df
    if not os.path.exists(CHUNKS_FILE):
        raise FileNotFoundError(
            f"{CHUNKS_FILE} not found. Run `python build_index.py` first."
        )
    _chunks_df = pd.read_parquet(CHUNKS_FILE)
    return _chunks_df
def _load_bm25():
    """Return the BM25 index, unpickling it from BM25_FILE on the first call.

    Raises:
        FileNotFoundError: if BM25_FILE does not exist.
    """
    global _bm25_index
    if _bm25_index is None:
        if not os.path.exists(BM25_FILE):
            # Same actionable hint as _load_chunks: the index is a build artifact.
            raise FileNotFoundError(
                f"{BM25_FILE} not found. Run `python build_index.py` first."
            )
        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file; only point BM25_FILE at a trusted, locally built index.
        with open(BM25_FILE, "rb") as f:
            _bm25_index = pickle.load(f)
    return _bm25_index
def _load_faiss():
    """Return the FAISS index, reading it from FAISS_FILE on the first call.

    Raises:
        FileNotFoundError: if FAISS_FILE does not exist.
        ImportError: if the ``faiss`` package is not installed.
    """
    global _faiss_index
    if _faiss_index is None:
        # Cheap precondition first, so a missing index is reported without
        # paying for (or failing on) the heavy faiss import.
        if not os.path.exists(FAISS_FILE):
            raise FileNotFoundError(
                f"{FAISS_FILE} not found. Run `python build_index.py` first."
            )
        import faiss  # noqa: PLC0415
        # faiss.read_index is a SWIG binding that expects a plain str path;
        # str() keeps this working if FAISS_FILE is configured as a Path.
        _faiss_index = faiss.read_index(str(FAISS_FILE))
    return _faiss_index
def _load_embed():
    """Return the shared SentenceTransformer for EMBED_MODEL, creating it once."""
    global _embed_model
    if _embed_model is not None:
        return _embed_model
    # Imported lazily so this module can be imported without loading the
    # sentence_transformers stack until an embedder is actually requested.
    from sentence_transformers import SentenceTransformer  # noqa: PLC0415
    _embed_model = SentenceTransformer(EMBED_MODEL)
    return _embed_model
def _load_reranker():
    """Return the shared CrossEncoder for RERANK_MODEL, creating it once."""
    global _rerank_model
    if _rerank_model is not None:
        return _rerank_model
    # Lazy import, as for the embedder.
    from sentence_transformers import CrossEncoder  # noqa: PLC0415
    _rerank_model = CrossEncoder(RERANK_MODEL)
    return _rerank_model
def indices_ready() -> bool:
    """Report whether the chunk table, BM25 index and FAISS index all exist on disk."""
    for artifact in (CHUNKS_FILE, BM25_FILE, FAISS_FILE):
        if not os.path.exists(artifact):
            return False
    return True
# ── retrieval methods ────────────────────────────────────────────────────────
def _bm25_search(query: str, top_k: int) -> list[tuple[int, float]]:
    """Lexical search of *query* against the BM25 index.

    The query is run through ``_tokenize``, re-joined with spaces and
    tokenized again by ``bm25s.tokenize`` before retrieval.

    Returns:
        ``[(chunk_id, score), ...]`` in the order produced by
        ``bm25.retrieve``, keeping only chunks with a positive BM25 score.
    """
    import bm25s  # noqa: PLC0415
    bm25 = _load_bm25()
    query_tokens_arr = bm25s.tokenize([" ".join(_tokenize(query))])
    results, scores = bm25.retrieve(query_tokens_arr, k=top_k)
    # retrieve() returns the top k regardless of score, so when fewer than k
    # chunks share a term with the query the tail is score-0 filler. Dropping
    # it keeps non-matching chunks from earning rank credit in RRF fusion.
    return [
        (int(cid), float(score))
        for cid, score in zip(results[0].tolist(), scores[0].tolist())
        if score > 0.0
    ]
def _dense_search(query: str, top_k: int) -> list[tuple[int, float]]:
    """Embedding search of *query* against the FAISS index.

    Returns:
        ``[(chunk_id, score), ...]`` as ranked by ``index.search``, with
        FAISS's ``-1`` placeholder ids (unfilled result slots) removed.
    """
    encoder = _load_embed()
    faiss_index = _load_faiss()
    # BGE models expect a query prefix
    prompt = f"Represent this sentence for searching relevant passages: {query}"
    embedding = encoder.encode(prompt, normalize_embeddings=True)
    query_matrix = embedding.reshape(1, -1).astype("float32")  # one query row
    similarities, neighbour_ids = faiss_index.search(query_matrix, top_k)

    hits: list[tuple[int, float]] = []
    for idx, sim in zip(neighbour_ids[0], similarities[0]):
        if idx < 0:
            continue
        hits.append((int(idx), float(sim)))
    return hits
def _rrf_fuse(
    bm25_hits: list[tuple[int, float]],
    dense_hits: list[tuple[int, float]],
    k: int = RRF_K,
    top_n: int = TOP_K_FUSED,
) -> list[tuple[int, float]]:
    """Merge two rankings with Reciprocal Rank Fusion.

    Each chunk earns ``1 / (k + position)`` (1-based position) from every
    list it appears in; the original retriever scores are ignored. Results
    are sorted by fused score, highest first, ties kept in first-seen order
    (BM25 hits before dense-only hits), and truncated to *top_n*.
    """
    fused: dict[int, float] = {}
    for ranking in (bm25_hits, dense_hits):
        for position, (chunk_id, _score) in enumerate(ranking, start=1):
            fused[chunk_id] = fused.get(chunk_id, 0.0) + 1.0 / (k + position)
    by_score = sorted(fused.items(), key=lambda item: item[1], reverse=True)
    return by_score[:top_n]
def _rerank(query: str, hits: list[tuple[int, float]], top_n: int, df: pd.DataFrame) -> list[tuple[int, float]]:
    """Rescore *hits* with the cross-encoder and keep the best *top_n*.

    The incoming scores are discarded. Each candidate's text is read via
    ``df.loc[chunk_id, "text"]`` (so *df* is assumed to be indexed by
    chunk id) and paired with *query* for ``CrossEncoder.predict``.
    """
    model = _load_reranker()
    chunk_ids = [chunk_id for chunk_id, _ in hits]
    pairs = [(query, df.loc[chunk_id, "text"]) for chunk_id in chunk_ids]
    relevance = model.predict(pairs)
    best_first = sorted(
        zip(chunk_ids, relevance), key=lambda item: item[1], reverse=True
    )
    return [(int(chunk_id), float(score)) for chunk_id, score in best_first[:top_n]]
# ── public API ───────────────────────────────────────────────────────────────
class HybridRetriever:
    """BM25 and/or dense retrieval, fused with RRF, optionally reranked.

    Args:
        use_bm25: include lexical (BM25) candidates.
        use_dense: include dense (FAISS) candidates.
        use_rerank: rescore the candidates with the cross-encoder.
        top_n: maximum number of citations returned by ``retrieve``.
    """

    # Maximum number of characters of chunk text copied into a citation snippet.
    snippet_chars = 600

    def __init__(
        self,
        use_bm25: bool = True,
        use_dense: bool = True,
        use_rerank: bool = True,
        top_n: int = TOP_N_FINAL,
    ):
        self.use_bm25 = use_bm25
        self.use_dense = use_dense
        self.use_rerank = use_rerank
        self.top_n = top_n

    def retrieve(self, query: str) -> list[Citation]:
        """Return up to ``top_n`` citations for *query*, best first.

        Candidates come from the enabled retrievers (RRF-fused when both are
        enabled). They are optionally reranked and then turned into
        ``Citation`` objects numbered from 1.

        Returns an empty list, without loading any index or model, when both
        retrievers are disabled or *query* is blank.

        Raises:
            FileNotFoundError: if a required index artifact is missing.
        """
        if not (self.use_bm25 or self.use_dense) or not query.strip():
            return []
        df = _load_chunks()
        fused = self._candidates(query)
        if self.use_rerank and fused:
            final = _rerank(query, fused, self.top_n, df)
        else:
            final = fused[: self.top_n]
        return [
            self._make_citation(rank, cid, score, df)
            for rank, (cid, score) in enumerate(final, start=1)
        ]

    def _candidates(self, query: str) -> list[tuple[int, float]]:
        """Run the enabled retrievers and merge them into one ranked list."""
        bm25_hits = _bm25_search(query, TOP_K_BM25) if self.use_bm25 else []
        dense_hits = _dense_search(query, TOP_K_DENSE) if self.use_dense else []
        if self.use_bm25 and self.use_dense:
            return _rrf_fuse(bm25_hits, dense_hits)
        # Exactly one retriever is enabled here; retrieve() handles "neither".
        single = bm25_hits if self.use_bm25 else dense_hits
        return single[:TOP_K_FUSED]

    def _make_citation(
        self, rank: int, cid: int, score: float, df: pd.DataFrame
    ) -> Citation:
        """Build the citation for chunk *cid* at 1-based *rank*."""
        row = df.loc[cid]
        return Citation(
            id=rank,
            chunk_id=int(cid),
            source_url=str(row["source_url"]),
            page_title=str(row["page_title"]),
            # A missing section (absent column, None or NaN from parquet)
            # must become "" rather than the strings "None" / "nan".
            section=self._optional_text(row.get("section")),
            snippet=str(row["text"])[: self.snippet_chars],
            score=float(score),
        )

    @staticmethod
    def _optional_text(value: object) -> str:
        """Return ``str(value)``, or ``""`` when *value* is None or a pandas missing value."""
        if value is None:
            return ""
        try:
            missing = bool(pd.isna(value))
        except (TypeError, ValueError):
            # pd.isna is element-wise on array-likes; treat those as present.
            missing = False
        return "" if missing else str(value)
|