"""Retrieval-methods comparison demo (BM25 vs Dense vs Hybrid/RRF, + optional reranking). A teaching app for EMBL-EBI researchers. Reads pre-built artifacts from ./data/ (made by build_index.py) and only embeds the user's LIVE query at runtime. Read top-to-bottom: artifacts -> embed query (GPU) -> filter -> 3 rankings -> RRF -> UMAP plot | optional: cross-encoder rerank (GPU) Two embedding models are shipped (general MiniLM + biomedical PubMedBERT); a radio switches which one Dense/Hybrid/plot use (BM25 is lexical, so it's unaffected). ZeroGPU note: every model is loaded on CPU at import. GPU is touched ONLY inside the @spaces.GPU functions. Do not move a model to CUDA anywhere else. """ import html import json import os import time import gradio as gr import joblib import numpy as np import pandas as pd import plotly.graph_objects as go import spaces import torch from rank_bm25 import BM25Okapi from sentence_transformers import CrossEncoder, SentenceTransformer import config from text_utils import tokenize # ==================================================================================== # Load pre-built artifacts (fast; no embedding, no network) # ==================================================================================== D = config.DATA_DIR # Shared, model-independent artifacts. META = pd.read_parquet(os.path.join(D, "metadata.parquet")) with open(os.path.join(D, "bm25_tokens.json")) as f: BM25 = BM25Okapi(json.load(f)) TITLES = META["title"].tolist() ABSTRACTS = META["abstract"].tolist() YEARS = META["year"].fillna(0).astype(int).to_numpy() JOURNALS = META["journal"].astype(str).to_numpy() YEAR_MIN = int(YEARS[YEARS > 0].min()) YEAR_MAX = int(YEARS.max()) # 1300+ distinct journals — show only the most common ones in the dropdown. TOP_JOURNALS = META["journal"].value_counts().head(30).index.tolist() # Per-model artifacts + models, keyed by model_key (see config.EMBEDDING_MODELS). 
EMB, COORDS, REDUCER, MODELS = {}, {}, {}, {}
for _key, (_label, _name) in config.EMBEDDING_MODELS.items():
    EMB[_key] = np.load(os.path.join(D, f"embeddings_{_key}.npy"))  # (N, dim), L2-normalised
    COORDS[_key] = np.load(os.path.join(D, f"umap_coords_{_key}.npy"))  # (N, 2)
    # NOTE: joblib.load unpickles arbitrary objects — only ever point this at trusted,
    # locally built artifacts.
    REDUCER[_key] = joblib.load(os.path.join(D, f"umap_{_key}.joblib"))  # .transform new points
    MODELS[_key] = SentenceTransformer(_name, device="cpu")  # CPU only at import!
    # Use get_embedding_dimension when the installed model exposes it, otherwise fall
    # back to get_sentence_embedding_dimension.
    _gd = getattr(MODELS[_key], "get_embedding_dimension", None) or MODELS[_key].get_sentence_embedding_dimension
    # Explicit raise rather than assert: asserts are stripped under `python -O`, and a
    # model/index mismatch must never go unnoticed.
    if _gd() != EMB[_key].shape[1]:
        raise RuntimeError(f"Model/embedding dim mismatch for '{_key}' — rebuild?")

# Cross-encoder reranker — also CPU at import; moved to GPU inside @spaces.GPU.
RERANKER = CrossEncoder(config.RERANKER_MODEL, device="cpu")

# UI label <-> internal key.
MODEL_LABELS = {k: lbl for k, (lbl, _) in config.EMBEDDING_MODELS.items()}
LABEL_TO_KEY = {lbl: k for k, lbl in MODEL_LABELS.items()}
DEFAULT_LABEL = MODEL_LABELS[config.DEFAULT_MODEL_KEY]

# Per-method plot styling. Sizes/widths are tiered so that a doc retrieved by several
# methods nests into concentric rings instead of one trace hiding another.
METHOD_COLORS = {"BM25": "#ff7f0e", "Dense": "#1f77b4", "Hybrid": "#2ca02c", "Reranked": "#111111"} METHOD_SIZES = {"BM25": 20, "Dense": 15, "Hybrid": 10, "Reranked": 13} METHOD_WIDTHS = {"BM25": 4, "Dense": 2.5, "Hybrid": 1, "Reranked": 2.5} # ==================================================================================== # GPU functions — the ONLY place CUDA is touched # ==================================================================================== def _encode(model_key: str, text: str) -> np.ndarray: """Embed one string with the chosen model -> (dim,) float32, L2-normalised.""" device = "cuda" if torch.cuda.is_available() else "cpu" # safe: only under @spaces.GPU model = MODELS[model_key] model.to(device) vec = model.encode([text], normalize_embeddings=True, convert_to_numpy=True, device=device) return vec[0].astype(np.float32) @spaces.GPU def embed_query(model_key: str, text: str) -> np.ndarray: """GPU-backed query embedding for retrieval.""" return _encode(model_key, text) @spaces.GPU def embed_text(model_key: str, text: str) -> np.ndarray: """GPU-backed embedding for the 'embed your own text' feature.""" return _encode(model_key, text) @spaces.GPU def rerank_scores(query: str, texts: list[str]) -> np.ndarray: """Cross-encoder relevance scores for (query, doc) pairs. 
One transformer forward pass PER candidate — this is the expensive, can't-precompute step.""" device = "cuda" if torch.cuda.is_available() else "cpu" RERANKER.model.to(device) # CrossEncoder.device follows the underlying model pairs = [(query, t) for t in texts] scores = RERANKER.predict(pairs, batch_size=16, convert_to_numpy=True, show_progress_bar=False) return np.asarray(scores, dtype=np.float32) # ==================================================================================== # Retrieval (pure CPU / numpy) # ==================================================================================== def candidate_indices(year_lo: int, year_hi: int, journal: str) -> np.ndarray: """Apply the metadata filter BEFORE retrieval. Returns indices of allowed docs. This pre-filtering of the candidate pool is the distinctly vector-DB feature: every method then searches only within these documents. """ mask = (YEARS >= int(year_lo)) & (YEARS <= int(year_hi)) if journal and journal != "All journals": mask &= JOURNALS == journal return np.flatnonzero(mask) def top_k(scores: np.ndarray, cand: np.ndarray, k: int) -> list[tuple[int, float]]: """Top-k (doc_index, score) among candidates, ranked by score descending.""" cand_scores = scores[cand] order = np.argsort(-cand_scores)[:k] return [(int(cand[i]), float(cand_scores[i])) for i in order] def rrf_fuse(bm25: np.ndarray, dense: np.ndarray, cand: np.ndarray, k: int): """Reciprocal Rank Fusion over the FULL candidate rankings, then take top-k. Rank every candidate by each method, sum 1/(RRF_K + rank) across methods, and only then truncate. (Truncating each method first would drop docs ranked low by one method but high by the other.) No score normalisation. 
""" def ranks(scores): order = np.argsort(-scores[cand]) # candidate positions, best first r = np.empty(len(cand), dtype=int) r[order] = np.arange(len(cand)) # 0-based rank per candidate position return r rb, rd = ranks(bm25), ranks(dense) fused = 1.0 / (config.RRF_K + rb) + 1.0 / (config.RRF_K + rd) order = np.argsort(-fused)[:k] return [(int(cand[i]), float(fused[i])) for i in order] def format_results(method: str, results: list[tuple[int, float]], score_label: str) -> str: """Render a ranked list as HTML. Each result is a
— click the title to expand the FULL abstract. This list is the SOURCE OF TRUTH for retrieval. """ if not results: return f"

{method}

No results.

" parts = [f"

{method}

"] for rank, (doc, score) in enumerate(results, start=1): title = html.escape(TITLES[doc]) # abstracts contain <, >, & — must escape abstract = html.escape(ABSTRACTS[doc]) journal = html.escape(str(JOURNALS[doc])) parts.append( "
" f"{rank}. {title}
" f"{score_label}={score:.3f} · {int(YEARS[doc])} · {journal}" "
" f"

{abstract}

" "
" ) return "".join(parts) def make_plot(model_key, query_coord=None, retrieved=None, extra_point=None, cand=None) -> go.Figure: """UMAP scatter for the chosen model, plus query point + connector lines to hits. `cand` (optional) = indices that passed the metadata filter; when given, only those points form the grey background, so the plot reflects the filter. """ coords = COORDS[model_key] fig = go.Figure() bg = np.arange(len(coords)) if cand is None else np.asarray(cand) fig.add_trace( go.Scatter( x=coords[bg, 0], y=coords[bg, 1], mode="markers", marker=dict(size=4, color="lightgrey"), text=[TITLES[i] for i in bg], hoverinfo="text", name=f"corpus ({len(bg)})", ) ) if retrieved and query_coord is not None: for method, results in retrieved.items(): color = METHOD_COLORS[method] xs, ys = [], [] for doc, _ in results: # connector lines query -> each hit xs += [query_coord[0], coords[doc, 0], None] ys += [query_coord[1], coords[doc, 1], None] fig.add_trace( go.Scatter(x=xs, y=ys, mode="lines", line=dict(color=color, width=METHOD_WIDTHS[method]), opacity=0.5, name=f"{method} links", hoverinfo="skip") ) fig.add_trace( go.Scatter( x=[coords[d, 0] for d, _ in results], y=[coords[d, 1] for d, _ in results], mode="markers", marker=dict(size=METHOD_SIZES[method], color=color, symbol="circle-open", line=dict(width=2)), text=[TITLES[d] for d, _ in results], hoverinfo="text", name=f"{method} hits", ) ) if query_coord is not None: fig.add_trace( go.Scatter(x=[query_coord[0]], y=[query_coord[1]], mode="markers", marker=dict(size=18, color="red", symbol="star"), text=["your query"], hoverinfo="text", name="query") ) if extra_point is not None: coord, label = extra_point fig.add_trace( go.Scatter(x=[coord[0]], y=[coord[1]], mode="markers", marker=dict(size=15, color="purple", symbol="diamond"), text=[label], hoverinfo="text", name="your text") ) fig.update_layout( margin=dict(l=10, r=10, t=30, b=10), height=520, xaxis_title="UMAP-1", yaxis_title="UMAP-2", legend=dict(orientation="h", 
yanchor="bottom", y=1.0), ) return fig # ==================================================================================== # Gradio handlers # ==================================================================================== RERANK_HINT = "

Press Rerank to reorder the hybrid shortlist with the cross-encoder.

" def run_search(query, k, year_lo, year_hi, journal, model_label): """Stage 1: BM25 + Dense + Hybrid. Fast. Stashes a shortlist for optional reranking.""" key = LABEL_TO_KEY[model_label] query = (query or "").strip() if not query: return "Enter a query.", "", "", "", make_plot(key), "—", None, "" cand = candidate_indices(year_lo, year_hi, journal) info = f"**Candidate pool: {len(cand)} / {len(META)} abstracts** after filtering." if len(cand) == 0: return "No documents match the filter.", "", "", "", make_plot(key, cand=cand), info, None, "" k = int(k) qvec = embed_query(key, query) # <-- GPU work (live query only) dense_scores = EMB[key] @ qvec # exact cosine (vectors are normalised) bm25_scores = np.asarray(BM25.get_scores(tokenize(query))) bm25_top = top_k(bm25_scores, cand, k) dense_top = top_k(dense_scores, cand, k) hybrid_top = rrf_fuse(bm25_scores, dense_scores, cand, k) # Shortlist for the reranker = hybrid's top-N (first-stage retrieval feeds stage 2). shortlist = [d for d, _ in rrf_fuse(bm25_scores, dense_scores, cand, config.RERANK_CANDIDATES)] query_coord = REDUCER[key].transform(qvec.reshape(1, -1))[0] fig = make_plot(key, query_coord, {"BM25": bm25_top, "Dense": dense_top, "Hybrid": hybrid_top}, cand=cand) state = { "key": key, "query": query, "query_coord": [float(query_coord[0]), float(query_coord[1])], "cand": cand.tolist(), "shortlist": shortlist, "hybrid_top": [[int(d), float(s)] for d, s in hybrid_top], } return ( format_results("BM25", bm25_top, "score"), format_results("Dense", dense_top, "cosine"), format_results("Hybrid", hybrid_top, "RRF"), RERANK_HINT, fig, info, state, "", ) def run_rerank(state, k): """Stage 2: cross-encoder reranks the hybrid shortlist. Slow — and we time it.""" if not state: return "

Run a search first.

", gr.update(), "" key, query, shortlist = state["key"], state["query"], state["shortlist"] texts = [f"{TITLES[i]}. {ABSTRACTS[i]}" for i in shortlist] t0 = time.time() scores = rerank_scores(query, texts) # <-- GPU work, one pass per candidate dt = time.time() - t0 order = np.argsort(-scores)[: int(k)] reranked = [(int(shortlist[i]), float(scores[i])) for i in order] hybrid_top = [(int(d), float(s)) for d, s in state["hybrid_top"]] fig = make_plot(key, state["query_coord"], {"Hybrid": hybrid_top, "Reranked": reranked}, cand=np.asarray(state["cand"])) timer = ( f"⏱️ Reranked **{len(shortlist)}** candidates with the cross-encoder in " f"**{dt:.2f}s**. It ran a transformer on every (query, document) pair at query " f"time — contrast the instant BM25/Dense/Hybrid. The plot shows Hybrid (green) " f"vs Reranked (black)." ) return format_results("Reranked", reranked, "cross-enc"), fig, timer def run_embed_own(text, model_label): key = LABEL_TO_KEY[model_label] text = (text or "").strip() if not text: return make_plot(key) vec = embed_text(key, text) # <-- GPU work coord = REDUCER[key].transform(vec.reshape(1, -1))[0] return make_plot(key, extra_point=(coord, text[:80])) # ==================================================================================== # Example queries — tuned to the cardiovascular slice of this microRNA/disease corpus. # Each is chosen (and verified against the corpus) to make ONE method visibly win. # ==================================================================================== EXACT_ID_QUERY = "miR-208a" # BM25 wins: exact identifier token, matched verbatim. # Dense wins: plain-English paraphrase using none of the corpus jargon. PARAPHRASE_QUERY = "small RNA molecules that worsen scarring after a heart attack" ACRONYM_QUERY = "AMI" # BM25 wins: bare acronym is a lexical token; dense has little to grip. RARE_TERM_QUERY = "miR-499 cardiomyocyte apoptosis" # BM25 pins the specific molecule. 
# Hybrid wins: broad query where lexical + semantic each surface good-but-different papers.
BROAD_CONCEPT_QUERY = "circulating microRNA biomarkers for cardiovascular disease"

# (button label, query text) pairs; one button is created per entry in the UI below.
EXAMPLES = [
    ("Exact ID (BM25)", EXACT_ID_QUERY),
    ("Paraphrase (Dense)", PARAPHRASE_QUERY),
    ("Acronym (BM25)", ACRONYM_QUERY),
    ("Specific molecule (BM25)", RARE_TERM_QUERY),
    ("Broad concept (Hybrid)", BROAD_CONCEPT_QUERY),
]

# Markdown shown under the plot: the 2-D picture is for intuition only.
PLOT_CAVEAT = (
    "⚠️ **The 2D UMAP projection distorts true distances.** UMAP preserves rough local "
    "neighbourhoods but warps global distances, and the live query is placed by an "
    "*approximate* out-of-sample fit — so its position (and the connector lines) are only "
    "indicative. Retrieved points may *not* be the visually-closest dots. "
    "**The ranked lists above are authoritative**; the plot is only for intuition."
)

# ====================================================================================
# UI
# ====================================================================================
# NOTE(review): the nesting of the layout containers below was reconstructed from a
# whitespace-collapsed copy of this file — confirm against the deployed layout.
with gr.Blocks(title="RAG retrieval: BM25 vs Dense vs Hybrid vs Rerank") as demo:
    gr.Markdown(
        "# 🔍 Retrieval methods, side by side\n"
        "Compare **BM25** (lexical), **Dense** (vector cosine), **Hybrid** (Reciprocal "
        "Rank Fusion), and an optional **cross-encoder rerank** over a corpus of "
        f"{len(META):,} Europe PMC abstracts on microRNA & disease."
    )
    with gr.Row():
        query = gr.Textbox(label="Query", placeholder="Type a search query…", scale=4)
        k = gr.Slider(1, 10, value=5, step=1, label="Top-k", scale=1)
    with gr.Row():
        for label, text in EXAMPLES:
            btn = gr.Button(label, size="sm")
            # `t=text` binds the current value as a default, so each button fills in
            # its own query rather than the last one in the loop.
            btn.click(lambda t=text: t, outputs=query)
    model_radio = gr.Radio(
        choices=list(LABEL_TO_KEY.keys()),
        value=DEFAULT_LABEL,
        label="Embedding model — used by Dense, Hybrid and the plot (BM25 is lexical, so it never changes)",
    )
    with gr.Group():
        gr.Markdown("**Metadata filter** — restricts the candidate pool *before* retrieval (all methods).")
        with gr.Row():
            year_lo = gr.Slider(YEAR_MIN, YEAR_MAX, value=YEAR_MIN, step=1, label="Year from")
            year_hi = gr.Slider(YEAR_MIN, YEAR_MAX, value=YEAR_MAX, step=1, label="Year to")
            journal = gr.Dropdown(["All journals"] + TOP_JOURNALS, value="All journals", label="Journal")
    search_btn = gr.Button("Search", variant="primary")
    filter_info = gr.Markdown("—")
    gr.Markdown("_Click any result to expand its full abstract._")
    with gr.Row():
        bm25_out = gr.HTML(label="BM25")
        dense_out = gr.HTML(label="Dense")
        hybrid_out = gr.HTML(label="Hybrid")
        reranked_out = gr.HTML(label="Reranked")
    with gr.Row():
        rerank_btn = gr.Button("Rerank hybrid shortlist (cross-encoder)", variant="secondary")
        rerank_timer = gr.Markdown("")
    gr.Markdown("## Vector space (UMAP projection)")
    # Initial figure: the corpus for the default model, with no query or hits.
    plot = gr.Plot(value=make_plot(config.DEFAULT_MODEL_KEY))
    gr.Markdown(PLOT_CAVEAT)
    with gr.Accordion("Embed your own text", open=False):
        gr.Markdown(
            "Type anything — it gets embedded with the selected model and dropped onto the "
            "map. This *is* the space the database searches through."
        )
        own_text = gr.Textbox(label="Your text", placeholder="e.g. a sentence about gene regulation…")
        own_btn = gr.Button("Embed & plot")
        own_btn.click(run_embed_own, inputs=[own_text, model_radio], outputs=plot)

    # Carries run_search's shortlist/query/plot data over to run_rerank.
    search_state = gr.State()
    search_inputs = [query, k, year_lo, year_hi, journal, model_radio]
    # Order must match the tuple returned by run_search.
    search_outputs = [bm25_out, dense_out, hybrid_out, reranked_out, plot, filter_info, search_state, rerank_timer]
    # Both the Search button and pressing Enter in the query box run the same handler.
    search_btn.click(run_search, inputs=search_inputs, outputs=search_outputs)
    query.submit(run_search, inputs=search_inputs, outputs=search_outputs)
    rerank_btn.click(run_rerank, inputs=[search_state, k], outputs=[reranked_out, plot, rerank_timer])

if __name__ == "__main__":
    demo.launch()