Spaces:
Runtime error
Runtime error
"""Retrieval-methods comparison demo (BM25 vs Dense vs Hybrid/RRF, + optional reranking).

A teaching app for EMBL-EBI researchers. Reads pre-built artifacts from ./data/ (made by
build_index.py) and only embeds the user's LIVE query at runtime. Read top-to-bottom:

    artifacts -> embed query (GPU) -> filter -> 3 rankings -> RRF -> UMAP plot
                                                               |
                                                               +-> optional: cross-encoder rerank (GPU)

Two embedding models are shipped (general MiniLM + biomedical PubMedBERT); a radio switches
which one Dense/Hybrid/plot use (BM25 is lexical, so it's unaffected).

ZeroGPU note: every model is loaded on CPU at import. GPU is touched ONLY inside the
@spaces.GPU functions. Do not move a model to CUDA anywhere else.
"""
| import html | |
| import json | |
| import os | |
| import time | |
| import gradio as gr | |
| import joblib | |
| import numpy as np | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
| import spaces | |
| import torch | |
| from rank_bm25 import BM25Okapi | |
| from sentence_transformers import CrossEncoder, SentenceTransformer | |
| import config | |
| from text_utils import tokenize | |
| # ==================================================================================== | |
| # Load pre-built artifacts (fast; no embedding, no network) | |
| # ==================================================================================== | |
D = config.DATA_DIR

# Shared, model-independent artifacts.
META = pd.read_parquet(os.path.join(D, "metadata.parquet"))
with open(os.path.join(D, "bm25_tokens.json"), encoding="utf-8") as f:
    # BM25Okapi takes the tokenised corpus: one token list per document.
    BM25 = BM25Okapi(json.load(f))
TITLES = META["title"].tolist()
ABSTRACTS = META["abstract"].tolist()
# Missing years become 0; YEAR_MIN below ignores those placeholder zeros.
YEARS = META["year"].fillna(0).astype(int).to_numpy()
JOURNALS = META["journal"].astype(str).to_numpy()
YEAR_MIN = int(YEARS[YEARS > 0].min())
YEAR_MAX = int(YEARS.max())
# 1300+ distinct journals — show only the most common ones in the dropdown.
TOP_JOURNALS = META["journal"].value_counts().head(30).index.tolist()

# Per-model artifacts + models, keyed by model_key (see config.EMBEDDING_MODELS).
EMB, COORDS, REDUCER, MODELS = {}, {}, {}, {}
for _key, (_label, _name) in config.EMBEDDING_MODELS.items():
    EMB[_key] = np.load(os.path.join(D, f"embeddings_{_key}.npy"))  # (N, dim), L2-normalised
    COORDS[_key] = np.load(os.path.join(D, f"umap_coords_{_key}.npy"))  # (N, 2)
    REDUCER[_key] = joblib.load(os.path.join(D, f"umap_{_key}.joblib"))  # .transform new points
    MODELS[_key] = SentenceTransformer(_name, device="cpu")  # CPU only at import!
    # Newer sentence-transformers expose get_embedding_dimension; fall back to the old name.
    _gd = getattr(MODELS[_key], "get_embedding_dimension", None) or MODELS[_key].get_sentence_embedding_dimension
    # Explicit raise, not assert: asserts vanish under `python -O`, and a stale index
    # built with a different model would then fail much later with a confusing matmul error.
    if _gd() != EMB[_key].shape[1]:
        raise RuntimeError(f"Model/embedding dim mismatch for '{_key}' — rebuild?")

# Cross-encoder reranker — also CPU at import; moved to GPU inside @spaces.GPU.
RERANKER = CrossEncoder(config.RERANKER_MODEL, device="cpu")
# Radio-button label <-> internal model key, derived from config.EMBEDDING_MODELS
# (each entry is unpacked as a (label, model_name) pair).
MODEL_LABELS = {key: label for key, (label, _model_name) in config.EMBEDDING_MODELS.items()}
LABEL_TO_KEY = dict(zip(MODEL_LABELS.values(), MODEL_LABELS.keys()))
DEFAULT_LABEL = MODEL_LABELS[config.DEFAULT_MODEL_KEY]

# Per-method plot styling. Marker sizes / line widths are tiered so a document found by
# several methods shows up as concentric rings rather than one trace hiding another.
METHOD_COLORS = dict(BM25="#ff7f0e", Dense="#1f77b4", Hybrid="#2ca02c", Reranked="#111111")
METHOD_SIZES = dict(BM25=20, Dense=15, Hybrid=10, Reranked=13)
METHOD_WIDTHS = dict(BM25=4, Dense=2.5, Hybrid=1, Reranked=2.5)
| # ==================================================================================== | |
| # GPU functions — the ONLY place CUDA is touched | |
| # ==================================================================================== | |
def _encode(model_key: str, text: str) -> np.ndarray:
    """Return the embedding of ``text`` under model ``model_key`` as a (dim,) float32,
    L2-normalised vector.

    Moves the model to CUDA whenever CUDA is visible, so it must only be reached from
    @spaces.GPU-decorated wrappers.
    """
    if torch.cuda.is_available():
        target = "cuda"
    else:
        target = "cpu"
    encoder = MODELS[model_key]
    encoder.to(target)
    batch = encoder.encode(
        [text],
        normalize_embeddings=True,
        convert_to_numpy=True,
        device=target,
    )
    first_row = batch[0]
    return first_row.astype(np.float32)
@spaces.GPU
def embed_query(model_key: str, text: str) -> np.ndarray:
    """GPU-backed query embedding for retrieval.

    Decorated with ``@spaces.GPU``: the module contract is that CUDA is touched only
    inside such functions, and ``_encode`` moves the model to CUDA when it is visible.
    Without the decorator, ``spaces`` is never used and the GPU path is never set up.

    Returns:
        (dim,) float32 L2-normalised embedding of ``text``.
    """
    return _encode(model_key, text)
@spaces.GPU
def embed_text(model_key: str, text: str) -> np.ndarray:
    """GPU-backed embedding for the 'embed your own text' feature.

    Decorated with ``@spaces.GPU`` for the same reason as the query path: ``_encode``
    moves the model to CUDA, which is only allowed inside @spaces.GPU functions.

    Returns:
        (dim,) float32 L2-normalised embedding of ``text``.
    """
    return _encode(model_key, text)
@spaces.GPU
def rerank_scores(query: str, texts: list[str]) -> np.ndarray:
    """Cross-encoder relevance scores for (query, doc) pairs. One transformer forward
    pass PER candidate — this is the expensive, can't-precompute step.

    Decorated with ``@spaces.GPU``: CUDA may only be touched inside such functions, and
    this one moves the reranker to CUDA when it is visible.

    Args:
        query: the user's query string.
        texts: candidate document texts, one per shortlist entry.

    Returns:
        float32 array of scores aligned with ``texts`` (higher = more relevant);
        empty when ``texts`` is empty.
    """
    if not texts:
        # Nothing to score: skip the device move and the predict() call entirely.
        return np.empty(0, dtype=np.float32)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    RERANKER.model.to(device)  # CrossEncoder.device follows the underlying model
    pairs = [(query, t) for t in texts]
    scores = RERANKER.predict(pairs, batch_size=16, convert_to_numpy=True, show_progress_bar=False)
    return np.asarray(scores, dtype=np.float32)
| # ==================================================================================== | |
| # Retrieval (pure CPU / numpy) | |
| # ==================================================================================== | |
def candidate_indices(year_lo: int, year_hi: int, journal: str) -> np.ndarray:
    """Apply the metadata filter BEFORE retrieval. Returns indices of allowed docs.

    This pre-filtering of the candidate pool is the distinctly vector-DB feature: every
    method then searches only within these documents.

    Documents with an unknown year are stored as year 0 (``fillna(0)`` at load time).
    A plain range test therefore excludes them even when the sliders sit at their full
    extent, i.e. when the user has not filtered by year at all. So a full-range
    selection applies no year constraint; any narrower range excludes undated documents.

    Args:
        year_lo: inclusive lower year bound from the UI slider.
        year_hi: inclusive upper year bound from the UI slider.
        journal: exact journal name, or "All journals" / empty for no journal filter.

    Returns:
        Ascending array of allowed document indices (possibly empty).
    """
    lo, hi = int(year_lo), int(year_hi)
    if lo <= YEAR_MIN and hi >= YEAR_MAX:
        # Untouched year range == "no year filter": keep undated (year 0) docs too.
        mask = np.ones(len(YEARS), dtype=bool)
    else:
        mask = (YEARS >= lo) & (YEARS <= hi)
    if journal and journal != "All journals":
        mask &= JOURNALS == journal
    return np.flatnonzero(mask)
def top_k(scores: np.ndarray, cand: np.ndarray, k: int) -> list[tuple[int, float]]:
    """Return the ``k`` best candidates as ``(doc_index, score)`` pairs, highest first.

    ``scores`` covers the whole corpus (indexed by document index); only the positions
    listed in ``cand`` compete. Fewer than ``k`` pairs come back when ``cand`` is shorter.
    """
    pool = scores[cand]
    best_first = np.argsort(-pool)
    picked = best_first[:k]
    return [(int(cand[pos]), float(pool[pos])) for pos in picked]
def rrf_fuse(bm25: np.ndarray, dense: np.ndarray, cand: np.ndarray, k: int, rrf_k=None):
    """Reciprocal Rank Fusion over the FULL candidate rankings, then take top-k.

    Rank every candidate by each method, sum ``1 / (rrf_k + rank)`` across methods, and
    only then truncate. (Truncating each method first would drop docs ranked low by one
    method but high by the other.) No score normalisation: only ranks are used.

    Ranks are 1-based, as in the original RRF formulation (Cormack et al., 2009), so
    ``rrf_k`` means exactly the published constant. Tied scores share the best rank of
    their group (standard "1224" competition ranking). This matters for BM25: every
    candidate sharing no term with the query scores exactly 0. Giving those docs distinct
    ranks would hand out BM25 credit by sort order rather than by evidence.

    Args:
        bm25: BM25 scores for every document, indexed by document index.
        dense: dense (cosine) scores for every document, indexed the same way.
        cand: indices of the documents allowed by the metadata filter.
        k: number of fused results to return.
        rrf_k: RRF smoothing constant; ``None`` (default) uses ``config.RRF_K``.

    Returns:
        ``[(doc_index, fused_score), ...]``, best first. Equal fused scores keep the order
        in which they appear in ``cand``.
    """
    if rrf_k is None:
        rrf_k = config.RRF_K

    def ranks(scores):
        cand_scores = scores[cand]
        order = np.argsort(-cand_scores, kind="stable")  # candidate positions, best first
        neg_sorted = -cand_scores[order]  # ascending, as searchsorted requires
        # Position of the first entry equal to each value, made 1-based -> ties share a rank.
        sorted_ranks = np.searchsorted(neg_sorted, neg_sorted, side="left") + 1
        r = np.empty(len(cand), dtype=np.int64)
        r[order] = sorted_ranks  # scatter back to candidate positions
        return r

    rb, rd = ranks(bm25), ranks(dense)
    fused = 1.0 / (rrf_k + rb) + 1.0 / (rrf_k + rd)
    order = np.argsort(-fused, kind="stable")[:k]
    return [(int(cand[i]), float(fused[i])) for i in order]
def format_results(method: str, results: list[tuple[int, float]], score_label: str) -> str:
    """Render a ranked list as HTML. Each result is a <details> — click the title to
    expand the FULL abstract. This list is the SOURCE OF TRUTH for retrieval.

    Args:
        method: heading shown above the list.
        results: ``(doc_index, score)`` pairs, best first.
        score_label: name printed before each score (e.g. "cosine", "RRF").

    Returns:
        An HTML fragment. Document text is HTML-escaped; ``method`` and ``score_label``
        are inserted as-is, so they must be trusted strings.
    """
    if not results:
        return f"<h3>{method}</h3><p><em>No results.</em></p>"
    parts = [f"<h3>{method}</h3>"]
    for rank, (doc, score) in enumerate(results, start=1):
        title = html.escape(TITLES[doc])  # abstracts contain <, >, & — must escape
        abstract = html.escape(ABSTRACTS[doc])
        journal = html.escape(str(JOURNALS[doc]))
        # Missing years are stored as 0 at load time; show "n.d." instead of a bogus "0".
        year = int(YEARS[doc])
        year_txt = str(year) if year > 0 else "n.d."
        parts.append(
            "<details style='margin-bottom:10px;border-bottom:1px solid #ddd;padding-bottom:6px;'>"
            f"<summary style='cursor:pointer;'><b>{rank}. {title}</b><br>"
            f"<code>{score_label}={score:.3f}</code> · {year_txt} · <em>{journal}</em>"
            "</summary>"
            f"<p style='margin-top:6px;font-size:0.9em;line-height:1.4;'>{abstract}</p>"
            "</details>"
        )
    return "".join(parts)
def make_plot(model_key, query_coord=None, retrieved=None, extra_point=None, cand=None) -> go.Figure:
    """Draw the 2D UMAP map for ``model_key`` as a Plotly figure.

    Layers, bottom to top: the grey corpus background; for each method in ``retrieved``,
    its query->hit connector lines and open-circle hit markers; the query star; and the
    optional "your text" diamond.

    Args:
        model_key: which model's precomputed UMAP coordinates to use.
        query_coord: 2D position of the query, or None for a bare map.
        retrieved: ``{method: [(doc_index, score), ...]}``; drawn only when
            ``query_coord`` is also given. Method names must be keys of the
            METHOD_COLORS / METHOD_SIZES / METHOD_WIDTHS tables.
        extra_point: optional ``(coord, label)`` marking a user-supplied text.
        cand: optional indices that passed the metadata filter; when given, only these
            form the grey background so the plot reflects the filter.
    """
    xy = COORDS[model_key]
    if cand is None:
        shown = np.arange(len(xy))
    else:
        shown = np.asarray(cand)

    traces = [
        go.Scatter(
            x=xy[shown, 0],
            y=xy[shown, 1],
            mode="markers",
            marker=dict(size=4, color="lightgrey"),
            text=[TITLES[idx] for idx in shown],
            hoverinfo="text",
            name=f"corpus ({len(shown)})",
        )
    ]

    if query_coord is not None and retrieved:
        qx, qy = query_coord[0], query_coord[1]
        for method, hits in retrieved.items():
            colour = METHOD_COLORS[method]
            docs = [doc for doc, _score in hits]
            # One segment per hit; the None entries break the polyline between segments.
            link_x = [v for doc in docs for v in (qx, xy[doc, 0], None)]
            link_y = [v for doc in docs for v in (qy, xy[doc, 1], None)]
            traces.append(
                go.Scatter(
                    x=link_x,
                    y=link_y,
                    mode="lines",
                    line=dict(color=colour, width=METHOD_WIDTHS[method]),
                    opacity=0.5,
                    name=f"{method} links",
                    hoverinfo="skip",
                )
            )
            traces.append(
                go.Scatter(
                    x=[xy[doc, 0] for doc in docs],
                    y=[xy[doc, 1] for doc in docs],
                    mode="markers",
                    marker=dict(
                        size=METHOD_SIZES[method],
                        color=colour,
                        symbol="circle-open",
                        line=dict(width=2),
                    ),
                    text=[TITLES[doc] for doc in docs],
                    hoverinfo="text",
                    name=f"{method} hits",
                )
            )

    if query_coord is not None:
        traces.append(
            go.Scatter(
                x=[query_coord[0]],
                y=[query_coord[1]],
                mode="markers",
                marker=dict(size=18, color="red", symbol="star"),
                text=["your query"],
                hoverinfo="text",
                name="query",
            )
        )

    if extra_point is not None:
        point, caption = extra_point
        traces.append(
            go.Scatter(
                x=[point[0]],
                y=[point[1]],
                mode="markers",
                marker=dict(size=15, color="purple", symbol="diamond"),
                text=[caption],
                hoverinfo="text",
                name="your text",
            )
        )

    figure = go.Figure()
    for trace in traces:
        figure.add_trace(trace)
    figure.update_layout(
        margin=dict(l=10, r=10, t=30, b=10),
        height=520,
        xaxis_title="UMAP-1",
        yaxis_title="UMAP-2",
        legend=dict(orientation="h", yanchor="bottom", y=1.0),
    )
    return figure
| # ==================================================================================== | |
| # Gradio handlers | |
| # ==================================================================================== | |
RERANK_HINT = "<p><em>Press <b>Rerank</b> to reorder the hybrid shortlist with the cross-encoder.</em></p>"


def run_search(query, k, year_lo, year_hi, journal, model_label):
    """Stage 1: BM25 + Dense + Hybrid. Fast. Stashes a shortlist for optional reranking.

    Returns the 8 values wired to ``search_outputs`` in the UI, in order:
      1-3. BM25, Dense and Hybrid result HTML;
      4. the Reranked placeholder;
      5. the plot;
      6. the filter-info markdown;
      7. the state dict later consumed by ``run_rerank`` (None when nothing was searched);
      8. an empty string that clears the rerank timer.
    """
    key = LABEL_TO_KEY[model_label]
    query = (query or "").strip()
    if not query:
        return "Enter a query.", "", "", "", make_plot(key), "—", None, ""
    cand = candidate_indices(year_lo, year_hi, journal)
    info = f"**Candidate pool: {len(cand)} / {len(META)} abstracts** after filtering."
    if len(cand) == 0:
        return "No documents match the filter.", "", "", "", make_plot(key, cand=cand), info, None, ""
    k = int(k)
    qvec = embed_query(key, query)  # <-- GPU work (live query only)
    dense_scores = EMB[key] @ qvec  # exact cosine (vectors are normalised)
    bm25_scores = np.asarray(BM25.get_scores(tokenize(query)))
    bm25_top = top_k(bm25_scores, cand, k)
    dense_top = top_k(dense_scores, cand, k)
    # Fuse ONCE, deep enough for both the visible top-k and the reranker shortlist.
    # rrf_fuse ranks the full candidate pool and only then truncates, so both lists are
    # prefixes of the same fused ranking (first-stage retrieval feeds stage 2).
    n_rerank = config.RERANK_CANDIDATES
    fused = rrf_fuse(bm25_scores, dense_scores, cand, max(k, n_rerank))
    hybrid_top = fused[:k]
    shortlist = [d for d, _ in fused[:n_rerank]]
    query_coord = REDUCER[key].transform(qvec.reshape(1, -1))[0]
    fig = make_plot(key, query_coord, {"BM25": bm25_top, "Dense": dense_top, "Hybrid": hybrid_top}, cand=cand)
    # Plain Python types only, so the dict round-trips cleanly through gr.State.
    state = {
        "key": key,
        "query": query,
        "query_coord": [float(query_coord[0]), float(query_coord[1])],
        "cand": cand.tolist(),
        "shortlist": shortlist,
        "hybrid_top": [[int(d), float(s)] for d, s in hybrid_top],
    }
    return (
        format_results("BM25", bm25_top, "score"),
        format_results("Dense", dense_top, "cosine"),
        format_results("Hybrid", hybrid_top, "RRF"),
        RERANK_HINT,
        fig, info, state, "",
    )
def run_rerank(state, k):
    """Stage 2: cross-encoder reranks the hybrid shortlist. Slow — and we time it.

    Args:
        state: the dict stashed by ``run_search`` (falsy until a search has run).
        k: number of reranked results to show. It is read from the Top-k slider when
            Rerank is clicked, so it can differ from the k used by the preceding search.

    Returns:
        ``(reranked_html, plot_or_update, timer_markdown)``.
    """
    if not state:
        return "<p><em>Run a search first.</em></p>", gr.update(), ""
    key, query, shortlist = state["key"], state["query"], state["shortlist"]
    texts = [f"{TITLES[i]}. {ABSTRACTS[i]}" for i in shortlist]
    # perf_counter is the monotonic, high-resolution clock meant for durations
    # (time.time() is wall-clock and can jump). The timed span is the whole
    # rerank_scores call, so it also includes moving the model to the device.
    t0 = time.perf_counter()
    scores = rerank_scores(query, texts)  # <-- GPU work, one pass per candidate
    dt = time.perf_counter() - t0
    order = np.argsort(-scores)[: int(k)]
    reranked = [(int(shortlist[i]), float(scores[i])) for i in order]
    hybrid_top = [(int(d), float(s)) for d, s in state["hybrid_top"]]
    fig = make_plot(key, state["query_coord"], {"Hybrid": hybrid_top, "Reranked": reranked},
                    cand=np.asarray(state["cand"]))
    timer = (
        f"⏱️ Reranked **{len(shortlist)}** candidates with the cross-encoder in "
        f"**{dt:.2f}s**. It ran a transformer on every (query, document) pair at query "
        f"time — contrast the instant BM25/Dense/Hybrid. The plot shows Hybrid (green) "
        f"vs Reranked (black)."
    )
    return format_results("Reranked", reranked, "cross-enc"), fig, timer
def run_embed_own(text, model_label):
    """Embed free text with the selected model and mark it on that model's UMAP map.

    Blank (or whitespace-only) input just redraws the bare map. The hover label is the
    first 80 characters of the stripped text.
    """
    model_key = LABEL_TO_KEY[model_label]
    cleaned = (text or "").strip()
    if cleaned:
        vector = embed_text(model_key, cleaned)  # <-- GPU work
        position = REDUCER[model_key].transform(vector.reshape(1, -1))[0]
        return make_plot(model_key, extra_point=(position, cleaned[:80]))
    return make_plot(model_key)
| # ==================================================================================== | |
| # Example queries — tuned to the cardiovascular slice of this microRNA/disease corpus. | |
| # Each is chosen (and verified against the corpus) to make ONE method visibly win. | |
| # ==================================================================================== | |
EXACT_ID_QUERY = "miR-208a"  # BM25 wins: exact identifier token, matched verbatim.
# Dense wins: plain-English paraphrase using none of the corpus jargon.
PARAPHRASE_QUERY = "small RNA molecules that worsen scarring after a heart attack"
ACRONYM_QUERY = "AMI"  # BM25 wins: bare acronym is a lexical token; dense has little to grip.
RARE_TERM_QUERY = "miR-499 cardiomyocyte apoptosis"  # BM25 pins the specific molecule.
# Hybrid wins: broad query where lexical + semantic each surface good-but-different papers.
BROAD_CONCEPT_QUERY = "circulating microRNA biomarkers for cardiovascular disease"
# (button label, query text) pairs; the UI renders one small button per pair that copies
# the query text into the search box.
EXAMPLES = [
    ("Exact ID (BM25)", EXACT_ID_QUERY),
    ("Paraphrase (Dense)", PARAPHRASE_QUERY),
    ("Acronym (BM25)", ACRONYM_QUERY),
    ("Specific molecule (BM25)", RARE_TERM_QUERY),
    ("Broad concept (Hybrid)", BROAD_CONCEPT_QUERY),
]
# Markdown shown under the UMAP plot: the map is for intuition only; the ranked lists win.
PLOT_CAVEAT = (
    "⚠️ **The 2D UMAP projection distorts true distances.** UMAP preserves rough local "
    "neighbourhoods but warps global distances, and the live query is placed by an "
    "*approximate* out-of-sample fit — so its position (and the connector lines) are only "
    "indicative. Retrieved points may *not* be the visually-closest dots. "
    "**The ranked lists above are authoritative**; the plot is only for intuition."
)
| # ==================================================================================== | |
| # UI | |
| # ==================================================================================== | |
# Page layout, top to bottom:
#   - header;
#   - query row and example-query buttons;
#   - embedding-model switch and metadata filter;
#   - the four result columns and rerank controls;
#   - the UMAP plot and the "embed your own text" panel.
# NOTE(review): the Row/Group nesting below was reconstructed from a whitespace-stripped
# copy of this file — confirm which components sit inside each container.
with gr.Blocks(title="RAG retrieval: BM25 vs Dense vs Hybrid vs Rerank") as demo:
    gr.Markdown(
        "# 🔍 Retrieval methods, side by side\n"
        "Compare **BM25** (lexical), **Dense** (vector cosine), **Hybrid** (Reciprocal "
        "Rank Fusion), and an optional **cross-encoder rerank** over a corpus of "
        f"{len(META):,} Europe PMC abstracts on microRNA & disease."
    )
    with gr.Row():
        query = gr.Textbox(label="Query", placeholder="Type a search query…", scale=4)
        k = gr.Slider(1, 10, value=5, step=1, label="Top-k", scale=1)
    with gr.Row():
        for label, text in EXAMPLES:
            btn = gr.Button(label, size="sm")
            # `t=text` binds this iteration's text now; a bare closure would late-bind
            # and every button would insert the last example.
            btn.click(lambda t=text: t, outputs=query)
    model_radio = gr.Radio(
        choices=list(LABEL_TO_KEY.keys()), value=DEFAULT_LABEL,
        label="Embedding model — used by Dense, Hybrid and the plot (BM25 is lexical, so it never changes)",
    )
    with gr.Group():
        gr.Markdown("**Metadata filter** — restricts the candidate pool *before* retrieval (all methods).")
        with gr.Row():
            year_lo = gr.Slider(YEAR_MIN, YEAR_MAX, value=YEAR_MIN, step=1, label="Year from")
            year_hi = gr.Slider(YEAR_MIN, YEAR_MAX, value=YEAR_MAX, step=1, label="Year to")
            journal = gr.Dropdown(["All journals"] + TOP_JOURNALS, value="All journals", label="Journal")
    search_btn = gr.Button("Search", variant="primary")
    filter_info = gr.Markdown("—")
    gr.Markdown("_Click any result to expand its full abstract._")
    with gr.Row():
        bm25_out = gr.HTML(label="BM25")
        dense_out = gr.HTML(label="Dense")
        hybrid_out = gr.HTML(label="Hybrid")
        reranked_out = gr.HTML(label="Reranked")
    with gr.Row():
        rerank_btn = gr.Button("Rerank hybrid shortlist (cross-encoder)", variant="secondary")
        rerank_timer = gr.Markdown("")
    gr.Markdown("## Vector space (UMAP projection)")
    plot = gr.Plot(value=make_plot(config.DEFAULT_MODEL_KEY))
    gr.Markdown(PLOT_CAVEAT)
    with gr.Accordion("Embed your own text", open=False):
        gr.Markdown(
            "Type anything — it gets embedded with the selected model and dropped onto the "
            "map. This *is* the space the database searches through."
        )
        own_text = gr.Textbox(label="Your text", placeholder="e.g. a sentence about gene regulation…")
        own_btn = gr.Button("Embed & plot")
        own_btn.click(run_embed_own, inputs=[own_text, model_radio], outputs=plot)
    # Holds the state dict returned by run_search (7th output below); run_rerank reads it.
    search_state = gr.State()
    search_inputs = [query, k, year_lo, year_hi, journal, model_radio]
    # Order must match run_search's return tuple.
    search_outputs = [bm25_out, dense_out, hybrid_out, reranked_out, plot, filter_info, search_state, rerank_timer]
    search_btn.click(run_search, inputs=search_inputs, outputs=search_outputs)
    query.submit(run_search, inputs=search_inputs, outputs=search_outputs)
    rerank_btn.click(run_rerank, inputs=[search_state, k], outputs=[reranked_out, plot, rerank_timer])
if __name__ == "__main__":
    demo.launch()