vector_demo / app.py
afg1's picture
Add cross-encoder rerank button + dual embedding models (general/biomedical) with live switch
a2c6853 verified
Raw
History Blame Contribute Delete
19.3 kB
"""Retrieval-methods comparison demo (BM25 vs Dense vs Hybrid/RRF, + optional reranking).
A teaching app for EMBL-EBI researchers. Reads pre-built artifacts from ./data/ (made by
build_index.py) and only embeds the user's LIVE query at runtime. Read top-to-bottom:
artifacts -> embed query (GPU) -> filter -> 3 rankings -> RRF -> UMAP plot
                                                           |
                                                           +-> optional: cross-encoder rerank (GPU)
Two embedding models are shipped (general MiniLM + biomedical PubMedBERT); a radio switches
which one Dense/Hybrid/plot use (BM25 is lexical, so it's unaffected).
ZeroGPU note: every model is loaded on CPU at import. GPU is touched ONLY inside the
@spaces.GPU functions. Do not move a model to CUDA anywhere else.
"""
import html
import json
import os
import time
import gradio as gr
import joblib
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import spaces
import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder, SentenceTransformer
import config
from text_utils import tokenize
# ====================================================================================
# Load pre-built artifacts (fast; no embedding, no network)
# ====================================================================================
D = config.DATA_DIR
# Shared, model-independent artifacts.
META = pd.read_parquet(os.path.join(D, "metadata.parquet"))
# Pre-tokenised corpus for BM25. Its scores are indexed by META row position elsewhere
# in this file, so the token lists must be in META order. JSON is UTF-8 by spec, so
# don't depend on the platform's default text encoding.
with open(os.path.join(D, "bm25_tokens.json"), encoding="utf-8") as f:
    BM25 = BM25Okapi(json.load(f))
TITLES = META["title"].tolist()
ABSTRACTS = META["abstract"].tolist()
# Missing years are filled with 0; YEAR_MIN below ignores those placeholder zeros.
YEARS = META["year"].fillna(0).astype(int).to_numpy()
JOURNALS = META["journal"].astype(str).to_numpy()
YEAR_MIN = int(YEARS[YEARS > 0].min())
YEAR_MAX = int(YEARS.max())
# 1300+ distinct journals — show only the most common ones in the dropdown.
TOP_JOURNALS = META["journal"].value_counts().head(30).index.tolist()
# Per-model artifacts + models, keyed by model_key (see config.EMBEDDING_MODELS).
EMB, COORDS, REDUCER, MODELS = {}, {}, {}, {}
for _key, (_label, _name) in config.EMBEDDING_MODELS.items():
    EMB[_key] = np.load(os.path.join(D, f"embeddings_{_key}.npy"))  # (N, dim), L2-normalised
    COORDS[_key] = np.load(os.path.join(D, f"umap_coords_{_key}.npy"))  # (N, 2)
    # joblib.load unpickles: only ever point this at trusted, locally built artifacts.
    REDUCER[_key] = joblib.load(os.path.join(D, f"umap_{_key}.joblib"))  # .transform new points
    MODELS[_key] = SentenceTransformer(_name, device="cpu")  # CPU only at import!
    # Use `get_embedding_dimension` when this sentence-transformers version has it,
    # otherwise fall back to `get_sentence_embedding_dimension`.
    _gd = getattr(MODELS[_key], "get_embedding_dimension", None) or MODELS[_key].get_sentence_embedding_dimension
    # Explicit checks (not `assert`) so they still run under `python -O`.
    if _gd() != EMB[_key].shape[1]:
        raise RuntimeError(f"Model/embedding dim mismatch for '{_key}' — rebuild?")
    # Rows of EMB/COORDS are addressed by META row index throughout the app, so a
    # stale artifact would silently misalign results (or index out of range).
    if EMB[_key].shape[0] != len(META) or COORDS[_key].shape[0] != len(META):
        raise RuntimeError(f"Artifact row count for '{_key}' does not match metadata — rebuild?")
# Cross-encoder reranker: loaded on CPU here like the embedding models; rerank_scores()
# moves it to the GPU, and only inside its @spaces.GPU call.
RERANKER = CrossEncoder(config.RERANKER_MODEL, device="cpu")
# Radio-button label <-> internal model key, in both directions.
MODEL_LABELS = {mkey: mlabel for mkey, (mlabel, _unused) in config.EMBEDDING_MODELS.items()}
LABEL_TO_KEY = dict(zip(MODEL_LABELS.values(), MODEL_LABELS.keys()))
DEFAULT_LABEL = MODEL_LABELS[config.DEFAULT_MODEL_KEY]
# Plot styling per retrieval method. Marker sizes and line widths are deliberately
# tiered: a document hit by several methods is drawn as concentric rings, so no
# method's trace completely covers another's.
METHOD_COLORS = dict(BM25="#ff7f0e", Dense="#1f77b4", Hybrid="#2ca02c", Reranked="#111111")
METHOD_SIZES = dict(BM25=20, Dense=15, Hybrid=10, Reranked=13)
METHOD_WIDTHS = dict(BM25=4, Dense=2.5, Hybrid=1, Reranked=2.5)
# ====================================================================================
# GPU functions — the ONLY place CUDA is touched
# ====================================================================================
def _encode(model_key: str, text: str) -> np.ndarray:
    """Return the normalised embedding of one string as a float32 vector of shape (dim,).

    Intended to be called only from inside an @spaces.GPU function: it selects CUDA
    when torch reports it available and moves the chosen model there.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"  # safe: only under @spaces.GPU
    encoder = MODELS[model_key]
    encoder.to(target)
    batch = encoder.encode(
        [text],
        normalize_embeddings=True,
        convert_to_numpy=True,
        device=target,
    )
    first_row = batch[0]
    return first_row.astype(np.float32)
@spaces.GPU
def embed_query(model_key: str, text: str) -> np.ndarray:
    """Embed the live search query for retrieval (runs under @spaces.GPU)."""
    query_vector = _encode(model_key, text)
    return query_vector


@spaces.GPU
def embed_text(model_key: str, text: str) -> np.ndarray:
    """Embed free text from the 'embed your own text' panel (runs under @spaces.GPU)."""
    text_vector = _encode(model_key, text)
    return text_vector
@spaces.GPU
def rerank_scores(query: str, texts: list[str]) -> np.ndarray:
    """Cross-encoder relevance scores for (query, doc) pairs.

    One transformer forward pass PER candidate — this is the expensive,
    can't-precompute step.

    Args:
        query: the user's search string.
        texts: candidate document texts, in shortlist order.

    Returns:
        float32 array with one score per entry of `texts`, in the same order; callers
        treat higher as more relevant. Empty input returns an empty array without
        touching the model.
    """
    if not texts:
        # Don't depend on how CrossEncoder.predict treats an empty batch.
        return np.empty(0, dtype=np.float32)
    device = "cuda" if torch.cuda.is_available() else "cpu"  # safe: only under @spaces.GPU
    RERANKER.model.to(device)  # CrossEncoder.device follows the underlying model
    pairs = [(query, t) for t in texts]
    scores = RERANKER.predict(pairs, batch_size=16, convert_to_numpy=True, show_progress_bar=False)
    return np.asarray(scores, dtype=np.float32)
# ====================================================================================
# Retrieval (pure CPU / numpy)
# ====================================================================================
def candidate_indices(year_lo: int, year_hi: int, journal: str) -> np.ndarray:
    """Apply the metadata filter BEFORE retrieval. Returns indices of allowed docs.

    This pre-filtering of the candidate pool is the distinctly vector-DB feature: every
    method then searches only within these documents.

    Args:
        year_lo, year_hi: inclusive year bounds. They come from two independent
            sliders and may arrive reversed; the range is normalised so a reversed
            selection still means "between these two years" rather than an empty pool.
        journal: exact journal name; empty/None or "All journals" disables the
            journal filter.

    Returns:
        Ascending integer array of row indices into META (and the arrays aligned with it).

    Note: documents with a missing year are stored as year 0 (see YEARS), so any
    range whose lower bound is >= 1 excludes them.
    """
    lo, hi = sorted((int(year_lo), int(year_hi)))
    mask = (YEARS >= lo) & (YEARS <= hi)
    if journal and journal != "All journals":
        mask &= JOURNALS == journal
    return np.flatnonzero(mask)
def top_k(scores: np.ndarray, cand: np.ndarray, k: int) -> list[tuple[int, float]]:
    """Top-k (doc_index, score) among candidates, ranked by score descending.

    Args:
        scores: full-corpus score array, indexed by doc index.
        cand: doc indices allowed by the filter (array or list of ints).
        k: number of results to return; a non-positive k gives an empty list.

    Returns:
        Up to k (doc_index, score) pairs. Ties keep candidate order (stable sort), so
        blocks of equal scores — e.g. documents that typically all score 0 under BM25
        because they share no query term — come back deterministically.
    """
    if k <= 0:
        return []
    # intp: an empty list would otherwise become a float array, which can't index.
    cand = np.asarray(cand, dtype=np.intp)
    cand_scores = np.asarray(scores)[cand]
    order = np.argsort(-cand_scores, kind="stable")[:k]
    return [(int(cand[i]), float(cand_scores[i])) for i in order]
def rrf_fuse(bm25: np.ndarray, dense: np.ndarray, cand: np.ndarray, k: int, rrf_k=None):
    """Reciprocal Rank Fusion over the FULL candidate rankings, then take top-k.

    Rank every candidate by each method, sum 1/(RRF_K + rank) across methods, and only
    then truncate. (Truncating each method first would drop docs ranked low by one method
    but high by the other.) No score normalisation.

    Tied scores share a rank ("competition" ranking: every member of a tie block gets
    the 0-based rank of the block's first position). Without this, a large block of
    equal scores — typically BM25's zeros for documents containing no query term —
    receives arbitrary, sort-order-dependent ranks. Far down the list those arbitrary
    rank gaps can outweigh the gap between adjacent top ranks of the other method.

    Args:
        bm25, dense: full-corpus score arrays (higher = better), indexed by doc index.
        cand: doc indices to rank (the filtered candidate pool).
        k: number of fused results to return (non-positive -> empty list).
        rrf_k: RRF smoothing constant; defaults to config.RRF_K.

    Returns:
        List of (doc_index, fused_score), best first; exact ties keep candidate order.
    """
    if rrf_k is None:
        rrf_k = config.RRF_K
    cand = np.asarray(cand, dtype=np.intp)
    if len(cand) == 0 or k <= 0:
        return []

    def ranks(scores):
        s = np.asarray(scores)[cand]
        order = np.argsort(-s, kind="stable")  # candidate positions, best first
        sorted_s = s[order]
        # True where a new run of equal scores begins in the sorted list.
        starts_run = np.concatenate(([True], sorted_s[1:] != sorted_s[:-1]))
        run_first_pos = np.flatnonzero(starts_run)  # shared 0-based rank of each run
        run_id = np.cumsum(starts_run) - 1  # which run each sorted position belongs to
        r = np.empty(len(cand), dtype=int)
        r[order] = run_first_pos[run_id]  # 0-based (tie-aware) rank per candidate position
        return r

    rb, rd = ranks(bm25), ranks(dense)
    fused = 1.0 / (rrf_k + rb) + 1.0 / (rrf_k + rd)
    order = np.argsort(-fused, kind="stable")[:k]
    return [(int(cand[i]), float(fused[i])) for i in order]
def format_results(method: str, results: list[tuple[int, float]], score_label: str) -> str:
    """Render a ranked list as HTML. Each result is a <details> — click the title to
    expand the FULL abstract. This list is the SOURCE OF TRUTH for retrieval.

    Args:
        method: heading text (an internal method name, inserted unescaped).
        results: (doc_index, score) pairs, already in display order.
        score_label: name shown before each score, e.g. "cosine".

    Returns:
        An HTML string. Titles, abstracts and journals are HTML-escaped. A year stored
        as 0 (the fill value for a missing year) is shown as "n.d." instead of "0".
    """
    if not results:
        return f"<h3>{method}</h3><p><em>No results.</em></p>"
    parts = [f"<h3>{method}</h3>"]
    for rank, (doc, score) in enumerate(results, start=1):
        title = html.escape(TITLES[doc])  # abstracts contain <, >, & — must escape
        abstract = html.escape(ABSTRACTS[doc])
        journal = html.escape(str(JOURNALS[doc]))
        year = int(YEARS[doc])
        year_text = str(year) if year > 0 else "n.d."
        parts.append(
            "<details style='margin-bottom:10px;border-bottom:1px solid #ddd;padding-bottom:6px;'>"
            f"<summary style='cursor:pointer;'><b>{rank}. {title}</b><br>"
            f"<code>{score_label}={score:.3f}</code> · {year_text} · <em>{journal}</em>"
            "</summary>"
            f"<p style='margin-top:6px;font-size:0.9em;line-height:1.4;'>{abstract}</p>"
            "</details>"
        )
    return "".join(parts)
def make_plot(model_key, query_coord=None, retrieved=None, extra_point=None, cand=None) -> go.Figure:
    """UMAP scatter for the chosen model, plus query point + connector lines to hits.

    Args:
        model_key: key into COORDS selecting which model's 2-D projection to draw.
        query_coord: projected (x, y) of the live query, or None for no query.
        retrieved: optional {method: [(doc_index, score), ...]}. Each method gets a
            line trace (query -> every hit) and an open-circle marker trace. It is drawn
            only when query_coord is also given. Method names must be keys of
            METHOD_COLORS / METHOD_SIZES / METHOD_WIDTHS.
        extra_point: optional ((x, y), label) for the "embed your own text" marker.
        cand: optional indices that passed the metadata filter; when given, only those
            points form the grey background, so the plot reflects the filter.

    Returns:
        A plotly Figure.
    """
    coords = COORDS[model_key]
    fig = go.Figure()
    # Force an integer index array: np.asarray([]) is float64, which numpy refuses as
    # an index, so an empty filter result passed as a plain list would otherwise crash.
    bg = np.arange(len(coords)) if cand is None else np.asarray(cand, dtype=np.intp)
    fig.add_trace(
        go.Scatter(
            x=coords[bg, 0], y=coords[bg, 1], mode="markers",
            marker=dict(size=4, color="lightgrey"),
            text=[TITLES[i] for i in bg], hoverinfo="text", name=f"corpus ({len(bg)})",
        )
    )
    if retrieved and query_coord is not None:
        qx, qy = query_coord[0], query_coord[1]
        for method, results in retrieved.items():
            color = METHOD_COLORS[method]
            hits = [doc for doc, _ in results]
            # Connector lines query -> each hit; the None entries break the line
            # between segments so each hit gets its own spoke.
            xs, ys = [], []
            for doc in hits:
                xs += [qx, coords[doc, 0], None]
                ys += [qy, coords[doc, 1], None]
            fig.add_trace(
                go.Scatter(x=xs, y=ys, mode="lines",
                           line=dict(color=color, width=METHOD_WIDTHS[method]),
                           opacity=0.5, name=f"{method} links", hoverinfo="skip")
            )
            fig.add_trace(
                go.Scatter(
                    x=[coords[d, 0] for d in hits],
                    y=[coords[d, 1] for d in hits],
                    mode="markers",
                    marker=dict(size=METHOD_SIZES[method], color=color, symbol="circle-open",
                                line=dict(width=2)),
                    text=[TITLES[d] for d in hits], hoverinfo="text", name=f"{method} hits",
                )
            )
    if query_coord is not None:
        fig.add_trace(
            go.Scatter(x=[query_coord[0]], y=[query_coord[1]], mode="markers",
                       marker=dict(size=18, color="red", symbol="star"),
                       text=["your query"], hoverinfo="text", name="query")
        )
    if extra_point is not None:
        coord, label = extra_point
        fig.add_trace(
            go.Scatter(x=[coord[0]], y=[coord[1]], mode="markers",
                       marker=dict(size=15, color="purple", symbol="diamond"),
                       text=[label], hoverinfo="text", name="your text")
        )
    fig.update_layout(
        margin=dict(l=10, r=10, t=30, b=10), height=520,
        xaxis_title="UMAP-1", yaxis_title="UMAP-2",
        legend=dict(orientation="h", yanchor="bottom", y=1.0),
    )
    return fig
# ====================================================================================
# Gradio handlers
# ====================================================================================
RERANK_HINT = "<p><em>Press <b>Rerank</b> to reorder the hybrid shortlist with the cross-encoder.</em></p>"


def run_search(query, k, year_lo, year_hi, journal, model_label):
    """Stage 1: BM25 + Dense + Hybrid. Fast. Stashes a shortlist for optional reranking.

    Returns the 8 values wired to `search_outputs` in the UI, in order: BM25 HTML,
    Dense HTML, Hybrid HTML, Reranked panel (hint text), plot, filter-info markdown,
    search state (dict, or None when nothing was searched) and rerank-timer text.
    """
    key = LABEL_TO_KEY[model_label]
    query = (query or "").strip()
    if not query:
        return "Enter a query.", "", "", "", make_plot(key), "—", None, ""
    cand = candidate_indices(year_lo, year_hi, journal)
    info = f"**Candidate pool: {len(cand)} / {len(META)} abstracts** after filtering."
    if len(cand) == 0:
        return "No documents match the filter.", "", "", "", make_plot(key, cand=cand), info, None, ""
    k = int(k)
    qvec = embed_query(key, query)  # <-- GPU work (live query only)
    dense_scores = EMB[key] @ qvec  # exact cosine (vectors are normalised)
    bm25_scores = np.asarray(BM25.get_scores(tokenize(query)))
    bm25_top = top_k(bm25_scores, cand, k)
    dense_top = top_k(dense_scores, cand, k)
    # One RRF pass serves both the displayed Hybrid top-k and the reranker's shortlist
    # (hybrid's top-N): rrf_fuse truncates a single fused ordering, so only the
    # cut-off differs between the two.
    fused = rrf_fuse(bm25_scores, dense_scores, cand, max(k, config.RERANK_CANDIDATES))
    hybrid_top = fused[:k]
    shortlist = [d for d, _ in fused[: config.RERANK_CANDIDATES]]
    query_coord = REDUCER[key].transform(qvec.reshape(1, -1))[0]
    fig = make_plot(key, query_coord, {"BM25": bm25_top, "Dense": dense_top, "Hybrid": hybrid_top}, cand=cand)
    # Plain-Python values only, so the state stays serialisable.
    state = {
        "key": key, "query": query,
        "query_coord": [float(query_coord[0]), float(query_coord[1])],
        "cand": cand.tolist(), "shortlist": shortlist,
        "hybrid_top": [[int(d), float(s)] for d, s in hybrid_top],
    }
    return (
        format_results("BM25", bm25_top, "score"),
        format_results("Dense", dense_top, "cosine"),
        format_results("Hybrid", hybrid_top, "RRF"),
        RERANK_HINT,
        fig, info, state, "",
    )
def run_rerank(state, k):
    """Stage 2: cross-encoder reranks the hybrid shortlist. Slow — and we time it.

    Args:
        state: dict stored by run_search (keys: key, query, query_coord, cand,
            shortlist, hybrid_top), or None/empty if no search has run yet.
        k: how many reranked results to display.

    Returns:
        (Reranked HTML, figure or gr.update() meaning "leave the plot as is",
        timer markdown).
    """
    if not state:
        return "<p><em>Run a search first.</em></p>", gr.update(), ""
    key, query, shortlist = state["key"], state["query"], state["shortlist"]
    if not shortlist:
        # Nothing to rerank — don't spend a GPU call on an empty batch.
        return format_results("Reranked", [], "cross-enc"), gr.update(), ""
    texts = [f"{TITLES[i]}. {ABSTRACTS[i]}" for i in shortlist]
    # perf_counter is monotonic and high-resolution, the right clock for measuring a
    # duration (time.time can jump when the system clock is adjusted).
    t0 = time.perf_counter()
    scores = rerank_scores(query, texts)  # <-- GPU work, one pass per candidate
    dt = time.perf_counter() - t0  # wall time of the whole call; may include GPU allocation
    order = np.argsort(-scores)[: int(k)]
    reranked = [(int(shortlist[i]), float(scores[i])) for i in order]
    hybrid_top = [(int(d), float(s)) for d, s in state["hybrid_top"]]
    fig = make_plot(key, state["query_coord"], {"Hybrid": hybrid_top, "Reranked": reranked},
                    cand=np.asarray(state["cand"]))
    timer = (
        f"⏱️ Reranked **{len(shortlist)}** candidates with the cross-encoder in "
        f"**{dt:.2f}s**. It ran a transformer on every (query, document) pair at query "
        f"time — contrast the instant BM25/Dense/Hybrid. The plot shows Hybrid (green) "
        f"vs Reranked (black)."
    )
    return format_results("Reranked", reranked, "cross-enc"), fig, timer
def run_embed_own(text, model_label):
    """Embed user-supplied text with the selected model and drop it onto that model's map.

    Blank (or whitespace-only) input just redraws the plain map for the selected model.
    The hover label is the first 80 characters of the stripped text.
    """
    key = LABEL_TO_KEY[model_label]
    cleaned = (text or "").strip()
    if cleaned:
        vec = embed_text(key, cleaned)  # <-- GPU work
        xy = REDUCER[key].transform(vec.reshape(1, -1))[0]
        return make_plot(key, extra_point=(xy, cleaned[:80]))
    return make_plot(key)
# ====================================================================================
# Example queries — tuned to the cardiovascular slice of this microRNA/disease corpus.
# Each is chosen (and verified against the corpus) to make ONE method visibly win.
# ====================================================================================
# BM25 wins: the exact identifier token is matched verbatim.
EXACT_ID_QUERY = "miR-208a"
# Dense wins: a plain-English paraphrase that uses none of the corpus jargon.
PARAPHRASE_QUERY = "small RNA molecules that worsen scarring after a heart attack"
# BM25 wins: a bare acronym is a lexical token; the dense model has little to grip.
ACRONYM_QUERY = "AMI"
# BM25 pins the specific molecule.
RARE_TERM_QUERY = "miR-499 cardiomyocyte apoptosis"
# Hybrid wins: a broad query where lexical and semantic retrieval each surface
# good-but-different papers.
BROAD_CONCEPT_QUERY = "circulating microRNA biomarkers for cardiovascular disease"

# (button label, query) pairs, rendered as one-click example buttons in the UI.
EXAMPLES = [
    ("Exact ID (BM25)",          EXACT_ID_QUERY),
    ("Paraphrase (Dense)",       PARAPHRASE_QUERY),
    ("Acronym (BM25)",           ACRONYM_QUERY),
    ("Specific molecule (BM25)", RARE_TERM_QUERY),
    ("Broad concept (Hybrid)",   BROAD_CONCEPT_QUERY),
]

# Markdown caveat shown under the UMAP plot.
PLOT_CAVEAT = "".join((
    "⚠️ **The 2D UMAP projection distorts true distances.** ",
    "UMAP preserves rough local neighbourhoods but warps global distances, ",
    "and the live query is placed by an *approximate* out-of-sample fit — ",
    "so its position (and the connector lines) are only indicative. ",
    "Retrieved points may *not* be the visually-closest dots. ",
    "**The ranked lists above are authoritative**; the plot is only for intuition.",
))
# ====================================================================================
# UI
# ====================================================================================
with gr.Blocks(title="RAG retrieval: BM25 vs Dense vs Hybrid vs Rerank") as demo:
    gr.Markdown(
        "# 🔍 Retrieval methods, side by side\n"
        "Compare **BM25** (lexical), **Dense** (vector cosine), **Hybrid** (Reciprocal "
        "Rank Fusion), and an optional **cross-encoder rerank** over a corpus of "
        f"{len(META):,} Europe PMC abstracts on microRNA & disease."
    )
    with gr.Row():
        query = gr.Textbox(label="Query", placeholder="Type a search query…", scale=4)
        k = gr.Slider(1, 10, value=5, step=1, label="Top-k", scale=1)
    with gr.Row():
        for label, text in EXAMPLES:
            btn = gr.Button(label, size="sm")
            # Bind the query via a default arg so each button keeps its own text
            # (a plain closure would see only the loop's last value).
            btn.click(lambda t=text: t, outputs=query)
    model_radio = gr.Radio(
        choices=list(LABEL_TO_KEY.keys()), value=DEFAULT_LABEL,
        label="Embedding model — used by Dense, Hybrid and the plot (BM25 is lexical, so it never changes)",
    )
    with gr.Group():
        gr.Markdown("**Metadata filter** — restricts the candidate pool *before* retrieval (all methods).")
        with gr.Row():
            year_lo = gr.Slider(YEAR_MIN, YEAR_MAX, value=YEAR_MIN, step=1, label="Year from")
            year_hi = gr.Slider(YEAR_MIN, YEAR_MAX, value=YEAR_MAX, step=1, label="Year to")
            journal = gr.Dropdown(["All journals"] + TOP_JOURNALS, value="All journals", label="Journal")
    search_btn = gr.Button("Search", variant="primary")
    filter_info = gr.Markdown("—")
    gr.Markdown("_Click any result to expand its full abstract._")
    with gr.Row():
        bm25_out = gr.HTML(label="BM25")
        dense_out = gr.HTML(label="Dense")
        hybrid_out = gr.HTML(label="Hybrid")
        reranked_out = gr.HTML(label="Reranked")
    with gr.Row():
        rerank_btn = gr.Button("Rerank hybrid shortlist (cross-encoder)", variant="secondary")
        rerank_timer = gr.Markdown("")
    gr.Markdown("## Vector space (UMAP projection)")
    plot = gr.Plot(value=make_plot(config.DEFAULT_MODEL_KEY))
    gr.Markdown(PLOT_CAVEAT)
    with gr.Accordion("Embed your own text", open=False):
        gr.Markdown(
            "Type anything — it gets embedded with the selected model and dropped onto the "
            "map. This *is* the space the database searches through."
        )
        own_text = gr.Textbox(label="Your text", placeholder="e.g. a sentence about gene regulation…")
        own_btn = gr.Button("Embed & plot")
    own_btn.click(run_embed_own, inputs=[own_text, model_radio], outputs=plot)
    search_state = gr.State()
    search_inputs = [query, k, year_lo, year_hi, journal, model_radio]
    search_outputs = [bm25_out, dense_out, hybrid_out, reranked_out, plot, filter_info, search_state, rerank_timer]
    search_btn.click(run_search, inputs=search_inputs, outputs=search_outputs)
    query.submit(run_search, inputs=search_inputs, outputs=search_outputs)
    # Live model switch: re-run the search so the lists, the plot and the rerank state
    # all follow the newly selected model. With an empty query, run_search simply
    # redraws that model's map. Without this, the plot kept showing the previous
    # model's map until Search was pressed again.
    model_radio.change(run_search, inputs=search_inputs, outputs=search_outputs)
    rerank_btn.click(run_rerank, inputs=[search_state, k], outputs=[reranked_out, plot, rerank_timer])

if __name__ == "__main__":
    demo.launch()