""" bm25_backends.py (Improved Version) Just what it sounds like """ from typing import List, Dict, Tuple import math import time import numpy as np from collections import defaultdict, Counter # Check for bm25s availability try: import bm25s as _bm25s _BM25S_AVAILABLE = True _BM25S_ERR = "" except Exception as _e: _bm25s = None _BM25S_AVAILABLE = False _BM25S_ERR = str(_e) class AbstractBM25Backend: """Abstract base class for a BM25 implementation.""" def __init__(self, tokenizer): self.tokenizer = tokenizer self.doc_ids: List[str] = [] def build(self, ids: List[str], texts: List[str]): raise NotImplementedError def search(self, query: str, topk: int = 300) -> List[str]: raise NotImplementedError @property def name(self) -> str: return self.__class__.__name__ class BM25SBackend(AbstractBM25Backend): """ High-performance and reliable wrapper for the 'bm25s' library. - Uses the fast `retrieve` method for efficient top-k search. - Uses stable sorting (`lexsort`) for deterministic tie-breaking. - Allows configurable k1 and b parameters. """ def __init__(self, tokenizer, k1: float = 1.3, b: float = 0.7): if not _BM25S_AVAILABLE: raise ImportError(f"bm25s library not available: {_BM25S_ERR}") super().__init__(tokenizer) self.k1 = k1 self.b = b self._bm25 = None @property def name(self) -> str: return f"BM25SBackend(k1={self.k1}, b={self.b})" def build(self, ids: List[str], texts: List[str]): from bm25s import BM25 self.doc_ids = list(ids) t0 = time.time() tokenized_corpus = [self.tokenizer(t) for t in texts] self._bm25 = BM25(k1=self.k1, b=self.b) self._bm25.index(tokenized_corpus) print(f"[{self.name}] Indexed {len(self.doc_ids):,} documents in {time.time() - t0:.2f}s") def search(self, query: str, topk: int = 300) -> List[str]: tokenized_query = self.tokenizer(query) if not tokenized_query or self._bm25 is None: return [] k = min(topk, len(self.doc_ids)) if k == 0: return [] # bm25s API compatibility: newer accepts positional list + k; older may need positional only try: doc_indices, scores = self._bm25.retrieve([tokenized_query], k=k) except TypeError: try: doc_indices, scores = self._bm25.retrieve([tokenized_query], k) except TypeError: # very old API uses 'topk' name doc_indices, scores = self._bm25.retrieve([tokenized_query], topk=k) doc_indices, scores = doc_indices[0], scores[0] mask = np.isfinite(scores) & (scores > 0) doc_indices = doc_indices[mask] scores = scores[mask] if len(doc_indices) == 0: return [] order = np.lexsort((doc_indices, -scores)) # stable: by -score then doc idx final_indices = doc_indices[order] return [self.doc_ids[int(i)] for i in final_indices] # The pure-Python fallback remains the same, as it was already reliable. class DeterministicBM25Backend(AbstractBM25Backend): """Pure-Python deterministic BM25. Slower but a good reference.""" def __init__(self, tokenizer, k1: float = 1.3, b: float = 0.7): super().__init__(tokenizer) self.k1 = k1 self.b = b self.N = 0 self.avgdl = 0.0 self.doc_lens = None self.vocab = {} self.postings = {} self.idf = None @property def name(self) -> str: return f"DeterministicBM25Backend(k1={self.k1}, b={self.b})" def build(self, ids: List[str], texts: List[str]): self.doc_ids=list(ids) self.N=len(ids) lens=np.zeros(self.N,dtype=np.int32) tmp=defaultdict(list) t0=time.time() for i, text in enumerate(texts): terms=self.tokenizer(text); lens[i]=len(terms) if not terms: continue ctr=Counter(terms) for t,tf in ctr.items(): tid=self.vocab.setdefault(t, len(self.vocab)) tmp[tid].append((i, tf)) self.doc_lens=lens self.avgdl=float(np.maximum(1,lens).mean()) V=len(self.vocab) self.idf=np.zeros(V,dtype=np.float32) self.postings={} for tid, pairs in tmp.items(): docs=np.array([d for d,_ in pairs],dtype=np.int32) tfs =np.array([tf for _,tf in pairs],dtype=np.float32) df=float(len(docs)) idf=math.log((self.N-df+0.5)/(df+0.5)+1.0) self.idf[tid]=idf self.postings[tid]=(docs,tfs) print(f"[{self.name}] Indexed {self.N:,} documents in {time.time() - t0:.2f}s") def search(self, query: str, topk: int = 300) -> List[str]: terms=self.tokenizer(query) if not terms: return [] seen: Dict[int,float] = {} for t in terms: tid=self.vocab.get(t) if tid is None: continue idf=self.idf[tid] docs,tfs=self.postings[tid] denom=tfs + self.k1*(1-self.b + self.b*(self.doc_lens[docs]/self.avgdl)) contrib = idf * (tfs*(self.k1+1)) / denom for d, c in zip(docs, contrib): seen[d]=seen.get(d,0.0)+float(c) if not seen: return [] idx=np.fromiter(seen.keys(),dtype=np.int32) scs=np.fromiter(seen.values(),dtype=np.float32) k=min(topk,len(scs)) # stable top-k: argsort with secondary key by doc index order = np.lexsort((idx, -scs)) # sort by -score, then doc idx order = order[:k] idx = idx[order] return [self.doc_ids[i] for i in idx] def get_bm25_backend(use_bm25s: bool, tokenizer, k1=1.3, b=0.7, logger=print) -> AbstractBM25Backend: """ Factory function to get the best available BM25 backend. Prefers the fast and reliable BM25SBackend, with a pure-Python fallback. """ if use_bm25s: if _BM25S_AVAILABLE: try: be = BM25SBackend(tokenizer, k1=k1, b=b) if logger: logger(f"[BM25] Using high-performance BM25S backend.") return be except Exception as e: if logger: logger(f"[BM25] BM25S failed to initialize ({e}); falling back to DeterministicBM25.") else: if logger: logger(f"[BM25] bm25s library not installed; falling back to DeterministicBM25.") if logger: logger(f"[BM25] Using pure-Python DeterministicBM25 backend.") return DeterministicBM25Backend(tokenizer, k1=k1, b=b)