import os, re, math, unicodedata, time, json, hashlib, importlib.util
from collections import defaultdict, Counter
from typing import List, Tuple, Dict, Optional

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
from sentence_transformers import SentenceTransformer
import sys, pathlib

HERE = pathlib.Path(__file__).resolve().parent
if str(HERE) not in sys.path:
    sys.path.insert(0, str(HERE))

# BM25 hyperparameters.
BM25_K1 = 1.3
BM25_B = 0.7
# Reciprocal-rank-fusion constant and cross-encoder (BGE) limits.
RRF_K = 35
CE_MAXLEN = 640
CE_BATCH = 128

# Per-retriever candidate depths; CE_POOL caps the rerank pool.
TOP_BM25 = TOP_E5 = TOP_GEMMA = CE_POOL = 190

# Weighted-RRF fusion weights (the dense retrievers count for more than BM25).
WRRF_BM25_W = 1.0
WRRF_E5_W = 1.2
WRRF_GEMMA_W = 1.4

# Final blend: the RRF-driven boost in predict() is scaled by
# (1.0 - FINAL_SCORE_BGE_WEIGHT).
FINAL_SCORE_BGE_WEIGHT = 0.07
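# Illustrative WRRF arithmetic (comments only, nothing here executes): with
# RRF_K = 35, a passage ranked first (rank 0) in Gemma's list contributes
# WRRF_GEMMA_W * 1/(35 + 0 + 1) ≈ 0.0389 to its fused score, while a passage
# ranked tenth by BM25 adds 1.0 * 1/(35 + 9 + 1) ≈ 0.0222; the per-list
# contributions are summed in HybridSearcher._wrrf_fuse3 below.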

USE_CACHE = True

# Local model directories, resolved via _resolve_model_path() below.
BGE_DIR = r"models/bge-reranker-hsrc-pairwise-rrf-V1.4"
E5_DIR = r"models/e5-large-ft_v6"
# Fill in a directory to enable the on-disk embedding cache (left empty here).
E5_EVAL_CACHE_DIR = r"".strip() if USE_CACHE else r""

GEMMA_DIR = r"models/multilingual-e5-large"
GEMMA_EVAL_CACHE_DIR = r"".strip() if USE_CACHE else r""

PREFER_BF16_GEMMA = True
GEMMA_MAX_TOK = 512

_EVAL_SILENT = os.environ.get("EVAL_STD_MODE", "").strip() == "1"


def _log(msg: str):
    if not _EVAL_SILENT:
        print(msg, flush=True)


try:
    from .text_utils import (
        tok_he, norm_bm25,
        norm_e5_query, norm_e5_passage,
        norm_gemma_query, norm_gemma_passage,
        norm_bge_query, norm_bge_passage
    )
    _log("[Init] Loaded text_utils (relative import).")
except (ImportError, ModuleNotFoundError):
    try:
        from text_utils import (
            tok_he, norm_bm25,
            norm_e5_query, norm_e5_passage,
            norm_gemma_query, norm_gemma_passage,
            norm_bge_query, norm_bge_passage
        )
        _log("[Init] Loaded text_utils (sys.path import).")
    except (ImportError, ModuleNotFoundError):
        try:
            spec_path = HERE / "text_utils.py"
            if not spec_path.is_file():
                raise FileNotFoundError(f"{spec_path} not found.")
            spec = importlib.util.spec_from_file_location("text_utils", spec_path)
            text_utils_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(text_utils_module)
            tok_he = text_utils_module.tok_he
            norm_bm25 = text_utils_module.norm_bm25
            norm_e5_query = text_utils_module.norm_e5_query
            norm_e5_passage = text_utils_module.norm_e5_passage
            # Gemma normalizers are optional in older text_utils versions.
            norm_gemma_query = getattr(text_utils_module, "norm_gemma_query", text_utils_module.norm_e5_query)
            norm_gemma_passage = getattr(text_utils_module, "norm_gemma_passage", text_utils_module.norm_e5_passage)
            norm_bge_query = text_utils_module.norm_bge_query
            norm_bge_passage = text_utils_module.norm_bge_passage
            _log("[Init] Loaded text_utils (dynamic import).")
        except Exception:
            _log("[Init] `text_utils` not found. Using generic Hebrew-friendly normalizers for all components.")
            HEB_PREFIXES = ("ו", "ה", "ב", "ל", "כ", "מ", "ש")
            STOPWORDS = set()

            def _generic_norm_he(s: str) -> str:
                if not s:
                    return ""
                s = unicodedata.normalize("NFKC", s)
                # Strip Hebrew niqqud/cantillation marks.
                s = re.sub(r"[\u0591-\u05BD\u05BF-\u05C7]", "", s)
                s = (s.replace("״", '"').replace("׳", "'")
                      .replace("”", '"').replace("“", '"')
                      .replace("–", "-").replace("—", "-"))
                return re.sub(r"\s+", " ", s).strip()

            norm_bm25 = norm_e5_query = norm_e5_passage = _generic_norm_he
            norm_gemma_query = norm_gemma_passage = _generic_norm_he
            norm_bge_query = norm_bge_passage = _generic_norm_he

            def tok_he(text: str) -> List[str]:
                s = norm_bm25(text)
                toks = re.findall(r"[A-Za-z0-9\u0590-\u05FF]+", s)
                out = []
                for t in toks:
                    # Emit both the prefix-stripped and the original form.
                    if len(t) > 3 and t[0] in HEB_PREFIXES:
                        out.append(t[1:])
                    out.append(t)
                return [t for t in out if t not in STOPWORDS]
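
# Illustrative fallback-tokenizer behavior (hypothetical input, comments only):
# tok_he("ובבית הספר") -> ["בבית", "ובבית", "ספר", "הספר"]; tokens longer than
# three characters that start with a Hebrew particle are emitted both with and
# without the prefix, so BM25 can match either surface form.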

get_bm25_backend = None
_HAS_BM25_BACKENDS = False
try:
    from .bm25_backends import get_bm25_backend
    _HAS_BM25_BACKENDS = True
    _log("[Init] Loaded bm25_backends (relative import).")
except (ImportError, ModuleNotFoundError):
    try:
        from bm25_backends import get_bm25_backend
        _HAS_BM25_BACKENDS = True
        _log("[Init] Loaded bm25_backends (sys.path import).")
    except (ImportError, ModuleNotFoundError):
        try:
            spec_path = HERE / "bm25_backends.py"
            if not spec_path.is_file():
                raise FileNotFoundError(f"{spec_path} not found.")
            spec = importlib.util.spec_from_file_location("bm25_backends", spec_path)
            bm25_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(bm25_module)
            get_bm25_backend = bm25_module.get_bm25_backend
            _HAS_BM25_BACKENDS = True
            _log("[Init] Loaded bm25_backends (dynamic import).")
        except Exception as e:
            _log(f"[Init] Could not load bm25_backends.py ({e}). Will use built-in fallbacks.")


class _LocalBM25SBackend:
    """Minimal local wrapper around bm25s for when bm25_backends.py is missing."""

    def __init__(self, tokenizer, k1: float = 1.3, b: float = 0.7, logger=_log):
        import bm25s
        self._BM25 = bm25s.BM25
        self.tokenizer = tokenizer
        self.k1, self.b = k1, b
        self._bm25 = None
        self.doc_ids: List[str] = []
        self._logger = logger

    @property
    def name(self) -> str:
        return f"LocalBM25S(k1={self.k1}, b={self.b})"

    def build(self, ids: List[str], texts: List[str]):
        t0 = time.time()
        self.doc_ids = list(ids)
        tokenized = [self.tokenizer(t) for t in texts]
        self._bm25 = self._BM25(k1=self.k1, b=self.b)
        self._bm25.index(tokenized)
        if self._logger:
            self._logger(f"[{self.name}] Indexed {len(self.doc_ids):,} docs in {time.time()-t0:.2f}s")

    def search(self, query: str, topk: int = 300) -> List[str]:
        terms = self.tokenizer(query)
        if not terms or self._bm25 is None:
            return []
        k = min(topk, len(self.doc_ids))
        idxs, scores = self._bm25.retrieve([terms], k=k)
        idxs, scores = idxs[0], scores[0]
        mask = np.isfinite(scores) & (scores > 0)
        idxs, scores = idxs[mask], scores[mask]
        if idxs.size == 0:
            return []
        # Sort by descending score, ties broken by ascending index, for determinism.
        order = np.lexsort((idxs, -scores))
        idxs = idxs[order]
        return [self.doc_ids[int(i)] for i in idxs]


class _DeterministicBM25Backend:
    """Embedded pure-Python deterministic BM25. Guaranteed fallback."""

    def __init__(self, tokenizer, k1: float = 1.3, b: float = 0.7, logger=_log):
        self.tokenizer = tokenizer
        self.k1 = k1
        self.b = b
        self.doc_ids: List[str] = []
        self.N = 0
        self.avgdl = 0.0
        self.doc_lens = None
        self.vocab: Dict[str, int] = {}
        self.postings: Dict[int, Tuple[np.ndarray, np.ndarray]] = {}
        self.idf = None
        self._logger = logger

    @property
    def name(self) -> str:
        return f"DeterministicBM25(k1={self.k1}, b={self.b})"

    def build(self, ids: List[str], texts: List[str]):
        self.doc_ids = list(ids)
        self.N = len(ids)
        lens = np.zeros(self.N, dtype=np.int32)
        tmp = defaultdict(list)
        t0 = time.time()
        for i, text in enumerate(texts):
            terms = self.tokenizer(text)
            lens[i] = len(terms)
            if not terms:
                continue
            ctr = Counter(terms)
            for t, tf in ctr.items():
                tid = self.vocab.setdefault(t, len(self.vocab))
                tmp[tid].append((i, tf))
        self.doc_lens = lens
        self.avgdl = float(np.maximum(1, lens).mean())
        V = len(self.vocab)
        self.idf = np.zeros(V, dtype=np.float32)
        self.postings = {}
        for tid, pairs in tmp.items():
            docs = np.array([d for d, _ in pairs], dtype=np.int32)
            tfs = np.array([tf for _, tf in pairs], dtype=np.float32)
            df = float(len(docs))
            # Lucene-style smoothed IDF: log((N - df + 0.5)/(df + 0.5) + 1).
            idf = math.log((self.N - df + 0.5) / (df + 0.5) + 1.0)
            self.idf[tid] = idf
            self.postings[tid] = (docs, tfs)
        if self._logger:
            self._logger(f"[{self.name}] Indexed {self.N:,} docs in {time.time()-t0:.2f}s")

    def search(self, query: str, topk: int = 300) -> List[str]:
        terms = self.tokenizer(query)
        if not terms:
            return []
        seen: Dict[int, float] = {}
        for t in terms:
            tid = self.vocab.get(t)
            if tid is None:
                continue
            idf = float(self.idf[tid])
            docs, tfs = self.postings[tid]
            denom = tfs + self.k1 * (1 - self.b + self.b * (self.doc_lens[docs] / self.avgdl))
            contrib = idf * (tfs * (self.k1 + 1)) / denom
            for d, c in zip(docs, contrib):
                seen[d] = seen.get(d, 0.0) + float(c)
        if not seen:
            return []
        idx = np.fromiter(seen.keys(), dtype=np.int32)
        scs = np.fromiter(seen.values(), dtype=np.float32)
        k = min(topk, len(scs))
        order = np.lexsort((idx, -scs))[:k]
        idx = idx[order]
        return [self.doc_ids[i] for i in idx]
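
# Usage sketch for the embedded fallback (hypothetical two-document corpus,
# comments only; nothing here runs at import time):
#   be = _DeterministicBM25Backend(tok_he, k1=BM25_K1, b=BM25_B, logger=None)
#   be.build(["d1", "d2"], ["שלום עולם", "עולם אחר"])
#   be.search("עולם", topk=10)   # -> both ids, ordered by BM25 score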


class BM25Index:
    """Unified BM25 wrapper. Returns List[str] of doc IDs."""

    def __init__(self, k1=1.3, b=0.70, logger=_log):
        self.k1, self.b = k1, b
        self.doc_ids: List[str] = []
        self._be = None
        self._backend_name = "unset"
        self._logger = logger

    def build(self, ids: List[str], texts_norm: List[str]):
        # Preference order: bm25_backends.py, then direct bm25s, then the
        # embedded pure-Python implementation.
        if _HAS_BM25_BACKENDS and callable(get_bm25_backend):
            try:
                self._be = get_bm25_backend(use_bm25s=True, tokenizer=tok_he, k1=self.k1, b=self.b, logger=self._logger)
                self._be.build(ids, texts_norm)
                self.doc_ids = list(self._be.doc_ids)
                self._backend_name = f"{self._be.name} (bm25_backends.py)"
                if self._logger:
                    self._logger(f"[BM25] Using backend: {self._backend_name}")
                return
            except Exception as e:
                if self._logger:
                    self._logger(f"[BM25] bm25_backends failed ({e}). Trying direct bm25s...")
        try:
            self._be = _LocalBM25SBackend(tok_he, k1=self.k1, b=self.b, logger=self._logger)
            self._be.build(ids, texts_norm)
            self.doc_ids = list(self._be.doc_ids)
            self._backend_name = f"{self._be.name} (direct)"
            if self._logger:
                self._logger(f"[BM25] Using backend: {self._backend_name}")
            return
        except Exception as e:
            if self._logger:
                self._logger(f"[BM25] bm25s unavailable ({e}). Falling back to pure-Python.")
        self._be = _DeterministicBM25Backend(tok_he, k1=self.k1, b=self.b, logger=self._logger)
        self._be.build(ids, texts_norm)
        self.doc_ids = list(self._be.doc_ids)
        self._backend_name = f"{self._be.name} (embedded)"
        if self._logger:
            self._logger(f"[BM25] Using backend: {self._backend_name}")

    def search(self, query: str, topk: int = 200) -> List[str]:
        if self._be is None:
            return []
        return self._be.search(query, topk=topk)


def _resolve_model_path(primary_path: str, fallback_names: List[str]) -> str:
    """
    Resolve a model path: check primary_path, then HERE/models, HERE, CWD,
    and CWD/models for each fallback name. If nothing exists on disk, return
    the first fallback name (treated as an HF hub id/path).
    """
    if primary_path and pathlib.Path(primary_path).is_dir():
        return primary_path
    base_dirs = [HERE / "models", HERE, pathlib.Path.cwd(), pathlib.Path.cwd() / "models"]
    for base in base_dirs:
        for name in fallback_names:
            candidate = base / name
            if candidate.is_dir():
                return str(candidate)
    return fallback_names[0]


def model_name_key(s: str) -> str:
    """Lower-cased final path component, used to compare cached model identities."""
    if not s:
        return ""
    s = s.strip().rstrip("/\\")
    last = re.split(r"[\\/]+", s)[-1] or s
    return last.lower()
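
# Examples (comments only):
#   model_name_key(r"models/e5-large-ft_v6/")    -> "e5-large-ft_v6"
#   model_name_key("google/embeddinggemma-300m") -> "embeddinggemma-300m"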


class E5Embedder:
    def __init__(self, device=None):
        fallback_names = ["e5-large-ft_v4", "multilingual-e5-large"]
        all_fallbacks = [pathlib.Path(E5_DIR).name] + fallback_names if E5_DIR else fallback_names
        self.model_path = _resolve_model_path(E5_DIR, all_fallbacks)
        self.model_name = model_name_key(self.model_path)
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        _log(f"[E5] Loading encoder from: {self.model_path} (device={self.device})")
        self.tok = AutoTokenizer.from_pretrained(self.model_path)
        self.mdl = AutoModel.from_pretrained(self.model_path, torch_dtype=torch.bfloat16 if self.device == "cuda" else None).to(self.device)
        self.mdl.eval()

    @torch.inference_mode()
    def encode(self, texts: List[str], is_query=False, batch=64, progress_desc="E5 encode"):
        # E5 models expect "query: "/"passage: " prefixes.
        pref = "query: " if is_query else "passage: "
        out = []
        n = len(texts)
        if n == 0:
            # Use the model's true hidden size (1024 for e5-large) rather than
            # a hard-coded 768.
            return np.zeros((0, self.mdl.config.hidden_size), dtype=np.float32)
        total_batches = (n + batch - 1) // batch
        t0 = time.time()
        for bi in range(total_batches):
            i = bi * batch
            chunk = texts[i:i + batch]
            enc = self.tok([pref + t.strip() for t in chunk], padding=True, truncation=True, max_length=512, return_tensors="pt").to(self.device)
            hs = self.mdl(**enc).last_hidden_state
            # Mean-pool over non-padding tokens, then L2-normalize.
            mask = enc["attention_mask"].unsqueeze(-1).expand(hs.size()).float()
            embs = (hs * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
            embs = torch.nn.functional.normalize(embs, p=2, dim=1)
            out.append(embs.detach().cpu().to(dtype=torch.float32))
            if not _EVAL_SILENT:
                if (bi + 1) % 50 == 0 or bi == 0 or (bi + 1) == total_batches:
                    pct = 100.0 * (bi + 1) / total_batches
                    elapsed = time.time() - t0
                    ips = (i + len(chunk)) / max(elapsed, 1e-6)
                    print(f"[{progress_desc}] batch {bi+1}/{total_batches} ({pct:.1f}%) ~{ips:.0f} items/s")
            del enc, hs, embs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return torch.cat(out, dim=0).numpy()
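
# Usage sketch (hypothetical strings, comments only): encode() L2-normalizes,
# so cosine similarity is a plain dot product:
#   e5 = E5Embedder()
#   qv = e5.encode(["מה שעות הפתיחה?"], is_query=True)[0]
#   pv = e5.encode(["החנות פתוחה מ-9:00 עד 17:00"], is_query=False)[0]
#   sim = float(qv @ pv)   # in [-1, 1]; higher means more similar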


class GemmaEmbedder:
    """
    SentenceTransformer wrapper (falls back to 'google/embeddinggemma-300m'),
    BF16 if available. Returns L2-normalized 768-dim numpy arrays. Prefers the
    model's encode_query/encode_document prompting; only applies a manual
    prefix when those methods are unavailable.
    """

    def __init__(self, device=None):
        fallback_names = ["google/embeddinggemma-300m", "embeddinggemma-300m"]
        all_fallbacks = [pathlib.Path(GEMMA_DIR).name] + fallback_names if GEMMA_DIR else fallback_names
        self.model_path = _resolve_model_path(GEMMA_DIR, all_fallbacks)
        self.model_name = model_name_key(self.model_path)
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        use_bf16 = bool(PREFER_BF16_GEMMA)
        if self.device == "cuda":
            try:
                use_bf16 = use_bf16 and torch.cuda.is_bf16_supported()
            except Exception:
                major, _ = torch.cuda.get_device_capability()
                use_bf16 = use_bf16 and (major >= 8)
        # fp16 is only safe on CUDA; fall back to fp32 on CPU.
        dtype = torch.bfloat16 if use_bf16 else (torch.float16 if self.device == "cuda" else torch.float32)
        _log(f"[Gemma] Loading encoder from: {self.model_path} (device={self.device}, dtype={str(dtype).replace('torch.', '')})")
        self.mdl = SentenceTransformer(
            self.model_path,
            device=self.device,
            model_kwargs={"torch_dtype": dtype},
        )
        try:
            self.mdl.max_seq_length = int(GEMMA_MAX_TOK)
        except Exception:
            pass
        self.dim = 768
        self.mdl.eval()

    @torch.inference_mode()
    def encode(self, texts: List[str], is_query=False, batch=64, progress_desc="Gemma encode", max_length: Optional[int] = None):
        if not texts:
            return np.zeros((0, self.dim), dtype=np.float32)

        # Temporarily override max_seq_length if a per-call limit was given.
        old_len = getattr(self.mdl, "max_seq_length", None)
        if isinstance(max_length, int) and max_length > 0:
            try:
                self.mdl.max_seq_length = max_length
            except Exception:
                pass

        show = not _EVAL_SILENT
        try:
            if is_query and hasattr(self.mdl, "encode_query"):
                embs = self.mdl.encode_query(
                    texts, batch_size=batch, convert_to_numpy=True,
                    normalize_embeddings=True, show_progress_bar=show
                )
            elif (not is_query) and hasattr(self.mdl, "encode_document"):
                embs = self.mdl.encode_document(
                    texts, batch_size=batch, convert_to_numpy=True,
                    normalize_embeddings=True, show_progress_bar=show
                )
            else:
                # Fallback: manual E5-style prefix via the `prompt` argument.
                prompt = "query: " if is_query else "passage: "
                try:
                    embs = self.mdl.encode(
                        texts, batch_size=batch, convert_to_numpy=True,
                        normalize_embeddings=True, show_progress_bar=show,
                        prompt=prompt
                    )
                except TypeError:
                    # Older sentence-transformers without `prompt` support.
                    embs = self.mdl.encode(
                        texts, batch_size=batch, convert_to_numpy=True,
                        normalize_embeddings=True, show_progress_bar=show
                    )
        finally:
            if old_len is not None:
                try:
                    self.mdl.max_seq_length = old_len
                except Exception:
                    pass

        embs = np.asarray(embs)
        if embs.ndim == 1:
            embs = embs[None, :]
        return embs.astype(np.float32)


class BGEReranker:
    def __init__(self, device=None):
        fallback_names = ["bge-reranker-hsrc-pairwise-rrf-V1.4", "bge-v2-m3", "bge-m3"]
        all_fallbacks = [pathlib.Path(BGE_DIR).name] + fallback_names if BGE_DIR else fallback_names
        self.model_path = _resolve_model_path(BGE_DIR, all_fallbacks)
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        _log(f"[BGE] Loading reranker from: {self.model_path} (device={self.device})")
        self.tok = AutoTokenizer.from_pretrained(self.model_path)
        self.mdl = AutoModelForSequenceClassification.from_pretrained(
            self.model_path, torch_dtype=torch.float16 if self.device == "cuda" else None, trust_remote_code=True
        ).to(self.device)
        self.mdl.eval()

    @torch.inference_mode()
    def score_pairs(self, q: str, passages: List[str], batch=32, max_len=512) -> List[float]:
        out = []
        for i in range(0, len(passages), batch):
            chunk = passages[i:i + batch]
            enc = self.tok([q] * len(chunk), chunk,
                           truncation="only_second", max_length=max_len, padding=True, return_tensors="pt").to(self.device)
            logits = self.mdl(**enc).logits
            # Handle 1-D, single-logit, and two-class heads uniformly.
            if logits.ndim == 1:
                s = logits
            elif logits.shape[1] == 1:
                s = logits.squeeze(-1)
            else:
                s = logits[:, 1]
            out += s.detach().float().cpu().tolist()
            del enc, logits
        return [float(x) for x in out]
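
# Usage sketch (hypothetical pair, comments only): higher logits mean a better
# query-passage match; the raw scores are uncalibrated and are min-max
# normalized later in predict():
#   rr = BGEReranker()
#   rr.score_pairs("שאלה לדוגמה", ["פסקה רלוונטית", "פסקה אחרת"])  # -> [s0, s1]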


class HybridSearcher:
    """
    Stage-1 retrieval: WRRF(BM25, E5, Gemma) → candidate ids + WRRF scores.
    Stage-2 reranking is done outside, in predict().
    """

    def __init__(self, bm25: BM25Index,
                 e5: E5Embedder, e5_corpus: np.ndarray,
                 gemma: GemmaEmbedder, gemma_corpus: np.ndarray,
                 id2text: Dict[str, str], id2norm: Dict[str, str]):
        self.bm25 = bm25
        self.e5 = e5
        self.e5_corpus = e5_corpus
        self.gemma = gemma
        self.gemma_corpus = gemma_corpus
        self.id2text = id2text
        self.id2norm = id2norm
        # One-entry memo for repeated probes of the same query.
        self._last_q: Optional[str] = None
        self._last_fused: List[Tuple[str, float]] = []

    def _wrrf_fuse3(self, bm_ids: List[str], e5_ids: List[str], gm_ids: List[str], k=60,
                    w_bm25=1.0, w_e5=1.0, w_gm=1.0) -> List[Tuple[str, float]]:
        rankA = {pid: i for i, pid in enumerate(bm_ids)}
        rankB = {pid: i for i, pid in enumerate(e5_ids)}
        rankC = {pid: i for i, pid in enumerate(gm_ids)}
        scores = defaultdict(float)
        for pid, r in rankA.items():
            scores[pid] += w_bm25 * (1.0 / (k + r + 1))
        for pid, r in rankB.items():
            scores[pid] += w_e5 * (1.0 / (k + r + 1))
        for pid, r in rankC.items():
            scores[pid] += w_gm * (1.0 / (k + r + 1))
        return sorted(scores.items(), key=lambda x: -x[1])

    def search(self, query: str, topk: int = 200) -> List[Tuple[str, float]]:
        if self._last_q == query and self._last_fused:
            return self._last_fused[:topk]

        # Lexical candidates.
        bm_ids = self.bm25.search(query, topk=TOP_BM25)

        # Dense E5 candidates: corpus rows are L2-normalized, so the dot
        # product is cosine similarity. Row order matches bm25.doc_ids.
        q_norm_e5 = norm_e5_query(query)
        qe = self.e5.encode([q_norm_e5], is_query=True, batch=1, progress_desc="E5 query")[0]
        sims_e5 = self.e5_corpus @ qe
        k2 = min(TOP_E5, len(sims_e5))
        top_idx_e5 = np.argpartition(-sims_e5, k2 - 1)[:k2]
        top_idx_e5 = top_idx_e5[np.argsort(-sims_e5[top_idx_e5])]
        e5_ids = [self.bm25.doc_ids[i] for i in top_idx_e5]

        # Dense Gemma candidates.
        q_norm_gm = norm_gemma_query(query)
        qg = self.gemma.encode([q_norm_gm], is_query=True, batch=1, progress_desc="Gemma query", max_length=GEMMA_MAX_TOK)[0]
        sims_gm = self.gemma_corpus @ qg
        k3 = min(TOP_GEMMA, len(sims_gm))
        top_idx_gm = np.argpartition(-sims_gm, k3 - 1)[:k3]
        top_idx_gm = top_idx_gm[np.argsort(-sims_gm[top_idx_gm])]
        gm_ids = [self.bm25.doc_ids[i] for i in top_idx_gm]

        fused_with_scores = self._wrrf_fuse3(
            bm_ids, e5_ids, gm_ids, k=RRF_K,
            w_bm25=WRRF_BM25_W, w_e5=WRRF_E5_W, w_gm=WRRF_GEMMA_W
        )

        out = fused_with_scores[:topk]
        self._last_q = query
        self._last_fused = out[:]
        return out
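
# Fusion sketch (hypothetical ranks, comments only): with k=RRF_K=35 and the
# weights above, a passage ranked 0th by BM25 and 1st by E5 but absent from
# Gemma's list scores 1.0/36 + 1.2/37 ≈ 0.0602. Absence from a list simply
# contributes nothing; candidates are never penalized.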


_STATE = {}


def _sha1_ids(ids: List[str]) -> str:
    """Stable fingerprint of the corpus id list, used to validate embedding caches."""
    h = hashlib.sha1()
    for pid in ids:
        h.update(pid.encode("utf-8"))
        h.update(b"\n")
    return h.hexdigest()


def _normalize_min_max(scores: List[float]) -> List[float]:
    """Scale a list of scores to the [0, 1] range; degenerate lists map to 0.5."""
    if not scores or len(scores) < 2:
        return [0.5] * len(scores)
    min_s, max_s = min(scores), max(scores)
    delta = max_s - min_s
    if delta < 1e-9:
        return [0.5] * len(scores)
    return [(s - min_s) / delta for s in scores]
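
# Examples (comments only):
#   _normalize_min_max([2.0, 4.0, 6.0]) -> [0.0, 0.5, 1.0]
#   _normalize_min_max([7.0])           -> [0.5]   (degenerate input)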


def preprocess(corpus_dict: Dict[str, Dict]) -> Dict:
    ids, texts = [], []
    bm25_norms = []
    # Per-component normalizations of each passage.
    e5_passage_norms = []
    gm_passage_norms = []

    for pid, obj in corpus_dict.items():
        t = obj.get("passage") or obj.get("text") or ""
        pid = str(pid)
        ids.append(pid)
        texts.append(t)
        bm25_norms.append(norm_bm25(t))
        e5_passage_norms.append(norm_e5_passage(t))
        gm_passage_norms.append(norm_gemma_passage(t))

    _log("=" * 60)
    _log("PREPROCESS: Building BM25 + E5 + Gemma embeddings + loading BGE")
    _log("=" * 60)

    bm25 = BM25Index(k1=BM25_K1, b=BM25_B, logger=_log)
    bm25.build(ids, bm25_norms)

    # E5 corpus embeddings, with an optional on-disk cache keyed by the corpus
    # fingerprint and model name.
    e5 = E5Embedder()
    e5_mat = None
    cache_note_e5 = None
    if E5_EVAL_CACHE_DIR:
        os.makedirs(E5_EVAL_CACHE_DIR, exist_ok=True)
        meta_p = os.path.join(E5_EVAL_CACHE_DIR, "e5_meta.json")
        npy_p = os.path.join(E5_EVAL_CACHE_DIR, "e5_corpus.npy")
        sha = _sha1_ids(ids)
        if os.path.isfile(meta_p) and os.path.isfile(npy_p):
            try:
                with open(meta_p, "r", encoding="utf-8") as f:
                    m = json.load(f)
                if m.get("sha1_ids") == sha and model_name_key(m.get("model_path", "")) == e5.model_name and m.get("num_docs") == len(ids):
                    _log(f"[E5] Loading cached corpus embeddings from {npy_p}")
                    e5_mat = np.load(npy_p, mmap_mode=None)
                    cache_note_e5 = "loaded"
            except Exception as e:
                _log(f"[E5] Cache read failed: {e} — recomputing.")
    if e5_mat is None:
        _log("[E5] Computing corpus embeddings...")
        t0 = time.time()
        e5_mat = e5.encode(e5_passage_norms, is_query=False, batch=64, progress_desc="E5 corpus")
        _log(f"[E5] Done in {time.time()-t0:.1f}s — shape={e5_mat.shape}")
        if E5_EVAL_CACHE_DIR:
            try:
                np.save(os.path.join(E5_EVAL_CACHE_DIR, "e5_corpus.npy"), e5_mat)
                meta = {"sha1_ids": _sha1_ids(ids), "num_docs": len(ids), "model_path": e5.model_path, "dim": int(e5_mat.shape[1]), "created": time.time()}
                with open(os.path.join(E5_EVAL_CACHE_DIR, "e5_meta.json"), "w", encoding="utf-8") as f:
                    json.dump(meta, f, ensure_ascii=False, indent=2)
                cache_note_e5 = "saved"
                _log(f"[E5] Saved cache to {E5_EVAL_CACHE_DIR}")
            except Exception as e:
                _log(f"[E5] Cache save failed: {e}")
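
    # Resulting cache layout when E5_EVAL_CACHE_DIR is set (the path below is
    # hypothetical, e.g. "cache/e5"):
    #   cache/e5/e5_corpus.npy  : float32 [num_docs, dim], L2-normalized rows
    #   cache/e5/e5_meta.json   : {"sha1_ids", "num_docs", "model_path", "dim", "created"}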

    # Gemma corpus embeddings, cached the same way.
    gemma = GemmaEmbedder()
    gemma_mat = None
    cache_note_gm = None
    if GEMMA_EVAL_CACHE_DIR:
        os.makedirs(GEMMA_EVAL_CACHE_DIR, exist_ok=True)
        meta_p_gm = os.path.join(GEMMA_EVAL_CACHE_DIR, "gemma_meta.json")
        npy_p_gm = os.path.join(GEMMA_EVAL_CACHE_DIR, "gemma_corpus.npy")
        sha = _sha1_ids(ids)
        if os.path.isfile(meta_p_gm) and os.path.isfile(npy_p_gm):
            try:
                with open(meta_p_gm, "r", encoding="utf-8") as f:
                    m = json.load(f)
                if m.get("sha1_ids") == sha and model_name_key(m.get("model_path", "")) == gemma.model_name and m.get("num_docs") == len(ids):
                    _log(f"[Gemma] Loading cached corpus embeddings from {npy_p_gm}")
                    gemma_mat = np.load(npy_p_gm, mmap_mode=None)
                    cache_note_gm = "loaded"
            except Exception as e:
                _log(f"[Gemma] Cache read failed: {e} — recomputing.")
    if gemma_mat is None:
        _log("[Gemma] Computing corpus embeddings...")
        t0 = time.time()
        gemma_mat = gemma.encode(gm_passage_norms, is_query=False, batch=64, progress_desc="Gemma corpus", max_length=GEMMA_MAX_TOK)
        _log(f"[Gemma] Done in {time.time()-t0:.1f}s — shape={gemma_mat.shape}")
        if GEMMA_EVAL_CACHE_DIR:
            try:
                np.save(os.path.join(GEMMA_EVAL_CACHE_DIR, "gemma_corpus.npy"), gemma_mat)
                meta_gm = {"sha1_ids": _sha1_ids(ids), "num_docs": len(ids), "model_path": gemma.model_path, "dim": int(gemma_mat.shape[1]), "created": time.time()}
                with open(os.path.join(GEMMA_EVAL_CACHE_DIR, "gemma_meta.json"), "w", encoding="utf-8") as f:
                    json.dump(meta_gm, f, ensure_ascii=False, indent=2)
                cache_note_gm = "saved"
                _log(f"[Gemma] Saved cache to {GEMMA_EVAL_CACHE_DIR}")
            except Exception as e:
                _log(f"[Gemma] Cache save failed: {e}")

    rr = BGEReranker()

    id2text = dict(zip(ids, texts))
    id2norm = dict(zip(ids, bm25_norms))

    hybrid = HybridSearcher(bm25, e5, e5_mat, gemma, gemma_mat, id2text, id2norm)
    _STATE.update({
        "bm25": bm25, "id2text": id2text, "id2norm": id2norm,
        "e5": e5, "e5_corpus": e5_mat,
        "gemma": gemma, "gemma_corpus": gemma_mat,
        "reranker": rr, "hybrid": hybrid
    })

    reranker_params = {
        "CE_POOL": CE_POOL, "CE_MAXLEN": CE_MAXLEN, "CE_BATCH": CE_BATCH,
        "FINAL_SCORE_BGE_WEIGHT": FINAL_SCORE_BGE_WEIGHT
    }

    meta = {
        "stage1_name": "WRRF(BM25, E5, Gemma)",
        "stage1_params": {
            "TOP_BM25": TOP_BM25, "TOP_E5": TOP_E5, "TOP_GEMMA": TOP_GEMMA, "RRF_K": RRF_K,
            "WRRF_WEIGHTS": {"bm25": WRRF_BM25_W, "e5": WRRF_E5_W, "gemma": WRRF_GEMMA_W}
        },
        "reranker_name": "BGE + Hybrid Fusion (Conditional Boost)",
        "reranker_params": reranker_params,
        "candidate_pool_cap": CE_POOL,
        "stage1_search_key": "bm25",
        "bm25_backend": getattr(bm25, "_backend_name", "unknown"),
        "e5_model_path": e5.model_path,
        "gemma_model_path": gemma.model_path,
        "bge_model_path": rr.model_path,
        "cache_dir_e5": E5_EVAL_CACHE_DIR or None,
        "cache_dir_gemma": GEMMA_EVAL_CACHE_DIR or None,
        "e5_cache": cache_note_e5 or ("unused" if not E5_EVAL_CACHE_DIR else "miss"),
        "gemma_cache": cache_note_gm or ("unused" if not GEMMA_EVAL_CACHE_DIR else "miss"),
    }

    _log("✓ PREPROCESS complete.")
    # Note: the "bm25" key in the returned dict holds the HybridSearcher (the
    # stage-1 entry point); the raw BM25Index lives in _STATE["bm25"].
    return {
        "bm25": hybrid, "id2text": id2text, "id2norm": id2norm,
        "reranker": rr, "num_documents": len(ids), "_eval": meta
    }


def predict(query: Dict, pre: Dict):
    q = query.get("query", "")
    if not q:
        return []
    # pre["bm25"] holds the HybridSearcher (see preprocess's return dict).
    hyb = _STATE.get("hybrid") or pre["bm25"]
    rr = _STATE.get("reranker") or pre["reranker"]
    id2text = _STATE.get("id2text") or pre["id2text"]

    # Stage 1: fused candidate pool.
    cand_id_scores = hyb.search(q, topk=CE_POOL)
    if not cand_id_scores:
        return []
    cand_ids, rrf_scores = zip(*cand_id_scores)
    passages = [id2text[pid] for pid in cand_ids]

    # Stage 2: cross-encoder scores on BGE-normalized text.
    q_norm_bge = norm_bge_query(q)
    passages_norm_bge = [norm_bge_passage(p) for p in passages]
    bge_scores = rr.score_pairs(q_norm_bge, passages_norm_bge, batch=CE_BATCH, max_len=CE_MAXLEN)

    # Final blend ("conditional boost"): the normalized BGE score plus an
    # RRF-driven boost that shrinks as the BGE score approaches 1.
    norm_bge = _normalize_min_max(bge_scores)
    norm_rrf = _normalize_min_max(list(rrf_scores))
    final_scores = []
    w_rrf = 1.0 - FINAL_SCORE_BGE_WEIGHT
    for bge_score, rrf_score in zip(norm_bge, norm_rrf):
        boost = w_rrf * rrf_score * (1.0 - bge_score)
        final_scores.append(bge_score + boost)

    out = [{"paragraph_uuid": pid, "score": float(s)}
           for pid, s in sorted(zip(cand_ids, final_scores), key=lambda x: -x[1])]
    return out
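
# End-to-end usage sketch (hypothetical corpus and query, comments only):
#   pre = preprocess({"p1": {"passage": "טקסט ראשון"}, "p2": {"text": "טקסט שני"}})
#   hits = predict({"query": "שאילתה לדוגמה"}, pre)
#   # hits: [{"paragraph_uuid": ..., "score": ...}, ...] sorted by score desc.
# Blend arithmetic with FINAL_SCORE_BGE_WEIGHT = 0.07: a candidate with
# normalized BGE 0.8 and normalized RRF 1.0 gets 0.8 + 0.93 * 1.0 * 0.2 = 0.986.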