Semantic-Retrieval-2nd-place / bm25_backends.py
yarden077's picture
uploading 2nd place model
0f5ecaf verified
"""
bm25_backends.py (Improved Version)
Just what it sounds like
"""
from typing import List, Dict, Tuple
import math
import time
import numpy as np
from collections import defaultdict, Counter
# Check for bm25s availability
try:
import bm25s as _bm25s
_BM25S_AVAILABLE = True
_BM25S_ERR = ""
except Exception as _e:
_bm25s = None
_BM25S_AVAILABLE = False
_BM25S_ERR = str(_e)
class AbstractBM25Backend:
    """Common interface shared by every BM25 backend in this module.

    Subclasses are expected to override `build` and `search`; the versions
    defined here only raise `NotImplementedError`.
    """

    def __init__(self, tokenizer):
        """Remember the tokenizer and start with an empty document-id list."""
        self.doc_ids: List[str] = []
        self.tokenizer = tokenizer

    @property
    def name(self) -> str:
        """Display name of the backend; defaults to the concrete class name."""
        return type(self).__name__

    def build(self, ids: List[str], texts: List[str]):
        """Index `texts`, identified by the parallel list `ids`."""
        raise NotImplementedError

    def search(self, query: str, topk: int = 300) -> List[str]:
        """Return up to `topk` document ids ranked for `query`."""
        raise NotImplementedError
class BM25SBackend(AbstractBM25Backend):
    """
    Wrapper around the 'bm25s' library exposing the common backend interface.

    - Uses the library's `retrieve` method for top-k search.
    - Re-sorts the retrieved hits with `np.lexsort` (score descending, then
      document index ascending) so ties among them are ordered
      deterministically.
    - Allows configurable k1 and b parameters.

    Args:
        tokenizer: Callable turning a string into a list of tokens.
        k1: BM25 term-frequency saturation parameter.
        b: BM25 document-length normalisation parameter.

    Raises:
        ImportError: On construction, if `bm25s` could not be imported.
    """

    def __init__(self, tokenizer, k1: float = 1.3, b: float = 0.7):
        if not _BM25S_AVAILABLE:
            raise ImportError(f"bm25s library not available: {_BM25S_ERR}")
        super().__init__(tokenizer)
        self.k1 = k1
        self.b = b
        self._bm25 = None  # set by build(); None means "not indexed yet"

    @property
    def name(self) -> str:
        """Backend name including the BM25 hyper-parameters."""
        return f"BM25SBackend(k1={self.k1}, b={self.b})"

    def build(self, ids: List[str], texts: List[str]):
        """Tokenize `texts` and index them; `ids` are the parallel document ids."""
        from bm25s import BM25

        self.doc_ids = list(ids)
        t0 = time.time()
        tokenized_corpus = [self.tokenizer(t) for t in texts]
        self._bm25 = BM25(k1=self.k1, b=self.b)
        self._bm25.index(tokenized_corpus)
        print(f"[{self.name}] Indexed {len(self.doc_ids):,} documents in {time.time() - t0:.2f}s")

    def search(self, query: str, topk: int = 300) -> List[str]:
        """Return up to `topk` document ids for `query`, best match first.

        Returns an empty list when the query tokenizes to nothing, the index
        has not been built, `topk` is not positive, or no document receives a
        finite positive score.
        """
        tokenized_query = self.tokenizer(query)
        if not tokenized_query or self._bm25 is None:
            return []
        k = min(topk, len(self.doc_ids))
        if k <= 0:
            # Covers an empty index as well as zero/negative topk.
            return []
        # `k` must be passed by keyword: the second positional parameter of
        # BM25.retrieve is `corpus`, so a positional `k` would be treated as
        # the corpus. Any TypeError raised by the library is left to surface
        # unchanged instead of being masked by retries with other spellings.
        doc_indices, scores = self._bm25.retrieve([tokenized_query], k=k)
        # A single query was submitted, so take row 0 of each result array.
        doc_indices, scores = doc_indices[0], scores[0]
        # Drop padding / non-matching hits.
        mask = np.isfinite(scores) & (scores > 0)
        doc_indices = doc_indices[mask]
        scores = scores[mask]
        if len(doc_indices) == 0:
            return []
        # lexsort uses the LAST key as primary: -score first, then doc index.
        order = np.lexsort((doc_indices, -scores))
        final_indices = doc_indices[order]
        return [self.doc_ids[int(i)] for i in final_indices]
# The pure-Python fallback remains the same, as it was already reliable.
class DeterministicBM25Backend(AbstractBM25Backend):
    """Pure-Python/NumPy deterministic BM25. Slower but a good reference.

    Keeps an inverted index (term id -> arrays of document indices and term
    frequencies) and scores with
    ``idf * tf * (k1 + 1) / (tf + k1 * (1 - b + b * dl / avgdl))`` where
    ``idf = log((N - df + 0.5) / (df + 0.5) + 1)``. Score ties are broken by
    ascending document index, so rankings are reproducible.

    Args:
        tokenizer: Callable turning a string into a sequence of hashable terms.
        k1: BM25 term-frequency saturation parameter.
        b: BM25 document-length normalisation parameter.
    """

    def __init__(self, tokenizer, k1: float = 1.3, b: float = 0.7):
        super().__init__(tokenizer)
        self.k1 = k1
        self.b = b
        self.N = 0            # number of indexed documents
        self.avgdl = 0.0      # mean document length (empty docs count as 1)
        self.doc_lens = None  # int32 array of per-document token counts
        self.vocab = {}       # term -> term id
        self.postings = {}    # term id -> (doc index array, tf array)
        self.idf = None       # float32 array indexed by term id

    @property
    def name(self) -> str:
        """Backend name including the BM25 hyper-parameters."""
        return f"DeterministicBM25Backend(k1={self.k1}, b={self.b})"

    def build(self, ids: List[str], texts: List[str]):
        """Index `texts`, identified by the parallel list `ids`.

        Every call starts from a clean slate, so the backend can be rebuilt
        on a different corpus without terms from the previous one lingering.
        """
        self.doc_ids = list(ids)
        self.N = len(self.doc_ids)
        # Reset the vocabulary: otherwise terms from an earlier build keep
        # ids that have no postings and make search() fail with KeyError.
        self.vocab = {}
        lens = np.zeros(self.N, dtype=np.int32)
        raw_postings = defaultdict(list)
        t0 = time.time()
        for i, text in enumerate(texts):
            terms = self.tokenizer(text)
            lens[i] = len(terms)
            if not terms:
                continue
            for term, tf in Counter(terms).items():
                tid = self.vocab.setdefault(term, len(self.vocab))
                raw_postings[tid].append((i, tf))
        self.doc_lens = lens
        # Empty documents count as length 1; an empty corpus has no mean to
        # take, so avoid the NaN (and RuntimeWarning) numpy would produce.
        self.avgdl = float(np.maximum(1, lens).mean()) if self.N else 0.0
        self.idf = np.zeros(len(self.vocab), dtype=np.float32)
        self.postings = {}
        for tid, pairs in raw_postings.items():
            docs = np.array([d for d, _ in pairs], dtype=np.int32)
            tfs = np.array([tf for _, tf in pairs], dtype=np.float32)
            df = float(len(docs))
            self.idf[tid] = math.log((self.N - df + 0.5) / (df + 0.5) + 1.0)
            self.postings[tid] = (docs, tfs)
        print(f"[{self.name}] Indexed {self.N:,} documents in {time.time() - t0:.2f}s")

    def search(self, query: str, topk: int = 300) -> List[str]:
        """Return up to `topk` document ids for `query`, best match first.

        Returns an empty list when the query tokenizes to nothing, `topk` is
        not positive, or no query term occurs in the index. A term repeated
        in the query contributes once per occurrence.
        """
        terms = self.tokenizer(query)
        if not terms or topk <= 0:
            return []
        totals: Dict[int, float] = {}
        for term in terms:
            tid = self.vocab.get(term)
            if tid is None:
                continue  # out-of-vocabulary term
            idf = self.idf[tid]
            docs, tfs = self.postings[tid]
            denom = tfs + self.k1 * (1 - self.b + self.b * (self.doc_lens[docs] / self.avgdl))
            contrib = idf * (tfs * (self.k1 + 1)) / denom
            # tolist() yields plain Python ints/floats for the accumulator.
            for d, c in zip(docs.tolist(), contrib.tolist()):
                totals[d] = totals.get(d, 0.0) + c
        if not totals:
            return []
        idx = np.fromiter(totals.keys(), dtype=np.int32)
        scs = np.fromiter(totals.values(), dtype=np.float32)
        k = min(topk, len(scs))
        # lexsort uses the LAST key as primary: -score first, then doc index.
        order = np.lexsort((idx, -scs))[:k]
        return [self.doc_ids[i] for i in idx[order].tolist()]
def get_bm25_backend(use_bm25s: bool, tokenizer, k1=1.3, b=0.7, logger=print) -> AbstractBM25Backend:
    """
    Build a BM25 backend, preferring bm25s when requested and usable.

    When `use_bm25s` is true and the library imported successfully, a
    BM25SBackend is returned. In every other case (not requested, library
    missing, or construction failed) a DeterministicBM25Backend is returned.
    Progress messages go to `logger`; pass a falsy value to silence them.
    """
    def _say(message):
        # Logging is optional: a falsy logger disables all messages.
        if logger:
            logger(message)

    if use_bm25s and not _BM25S_AVAILABLE:
        _say("[BM25] bm25s library not installed; falling back to DeterministicBM25.")
    elif use_bm25s:
        try:
            backend = BM25SBackend(tokenizer, k1=k1, b=b)
            _say("[BM25] Using high-performance BM25S backend.")
            return backend
        except Exception as e:
            _say(f"[BM25] BM25S failed to initialize ({e}); falling back to DeterministicBM25.")

    _say("[BM25] Using pure-Python DeterministicBM25 backend.")
    return DeterministicBM25Backend(tokenizer, k1=k1, b=b)