Deploybot
Deploy from stable branch
9ad6332
"""Interactive demo for IBM Granite Embedding R2 models."""
from typing import Any
import gradio as gr
import numpy as np
import spaces
from sentence_transformers import CrossEncoder, SentenceTransformer
# Short model name -> Hugging Face Hub repository id of each selectable embedding model.
MODELS = {
    "granite-embedding-small-english-r2": "ibm-granite/granite-embedding-small-english-r2",
    "granite-embedding-97m-multilingual-r2": "ibm-granite/granite-embedding-97m-multilingual-r2",
}
# Repository ids only, in declaration order; used as the dropdown choices.
MODEL_IDS = [*MODELS.values()]
# Cross-encoder used for the optional second-stage reranking.
RERANK_MODEL_ID = "ibm-granite/granite-embedding-reranker-english-r2"

# Lazily filled caches: the loaded embedding model, the id it was loaded from,
# and the loaded reranker. All start empty and are populated on first use.
embedder = current_embed_model = reranker = None
def get_embedder(model_id: str) -> Any:  # noqa: ANN401
    """Return the embedding model for ``model_id``, loading it when needed.

    A single model is cached in the module globals ``embedder`` /
    ``current_embed_model``. Asking for a different ``model_id`` than the
    cached one constructs a new ``SentenceTransformer`` and replaces the cache.
    """
    global embedder, current_embed_model
    cache_hit = embedder is not None and current_embed_model == model_id
    if not cache_hit:
        embedder = SentenceTransformer(model_id)
        # Record the id only after construction succeeded.
        current_embed_model = model_id
    return embedder
def get_reranker() -> CrossEncoder:
    """Return the shared reranker, constructing it from ``RERANK_MODEL_ID`` on first use."""
    global reranker
    if reranker is not None:
        return reranker
    reranker = CrossEncoder(RERANK_MODEL_ID)
    return reranker
def cos_sim_matrix(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Compute the cosine similarity matrix between two sets of embeddings.

    Each row of ``a`` and ``b`` is scaled to unit length before the dot
    product, so the result is a true cosine similarity even when the inputs
    were not normalized by the caller. Inputs that are already unit-length
    give the same values as a plain ``a @ b.T`` (up to float rounding).

    Args:
        a: Embeddings with one vector per row.
        b: Embeddings with one vector per row and the same width as ``a``.

    Returns:
        Matrix whose entry ``[i, j]`` is the cosine similarity between row
        ``i`` of ``a`` and row ``j`` of ``b``. A zero vector yields similarity 0.
    """

    def _unit_rows(vectors: np.ndarray) -> np.ndarray:
        vectors = np.asarray(vectors)
        norms = np.linalg.norm(vectors, axis=-1, keepdims=True)
        # Divide zero vectors by 1 instead of 0 so they stay zero rather than becoming NaN.
        return vectors / np.where(norms == 0, 1.0, norms)

    return _unit_rows(a) @ _unit_rows(b).T
@spaces.GPU
def rank_passages(
    query: str,
    passages_text: str,
    top_k: int,
    retrieve_n: int,
    use_reranker: bool,
    model_name: str,
) -> list[list[Any]]:
    """Score passages against a query and return the best ones as table rows.

    Args:
        query: Search text; surrounding whitespace is ignored.
        passages_text: Candidate passages, one per line. Blank lines are dropped.
        top_k: Number of rows to return, clamped to ``[1, number of candidates]``.
        retrieve_n: Number of embedding-ranked candidates to keep, clamped to
            ``[1, number of passages]``.
        use_reranker: When true, the kept candidates are re-ordered by the
            cross-encoder from ``get_reranker()``.
        model_name: Model id passed to ``get_embedder``.

    Returns:
        Rows of ``[rank, passage, embed_score, rerank_score]`` with ``rank``
        starting at 1. ``rerank_score`` is ``None`` when reranking is off.
        An empty list is returned when the query or the passage list is empty.
    """
    query = (query or "").strip()
    stripped_lines = (line.strip() for line in (passages_text or "").splitlines())
    passages = [line for line in stripped_lines if line]
    if not (query and passages):
        return []

    # Stage 1: embedding similarity between the query and every passage.
    model = get_embedder(model_name)
    query_vec = model.encode([query], normalize_embeddings=True)
    passage_vecs = model.encode(passages, normalize_embeddings=True)
    sims = cos_sim_matrix(query_vec, passage_vecs)[0]

    # Keep the indices of the highest-scoring passages, best first.
    n_candidates = int(max(1, min(retrieve_n, len(passages))))
    cand_idx = np.argsort(-sims)[:n_candidates].tolist()

    # Stage 2 (optional): re-order the kept candidates by cross-encoder score.
    rerank_by_idx = dict.fromkeys(cand_idx, None)
    if use_reranker:
        rr_scores = get_reranker().predict([(query, passages[i]) for i in cand_idx])
        reranked = [(cand_idx[j], float(rr_scores[j])) for j in np.argsort(-rr_scores)]
        cand_idx = [idx for idx, _ in reranked]
        rerank_by_idx.update(reranked)

    n_shown = int(max(1, min(top_k, len(cand_idx))))
    return [
        [rank, passages[idx], float(sims[idx]), rerank_by_idx.get(idx)]
        for rank, idx in enumerate(cand_idx[:n_shown], start=1)
    ]
theme = gr.themes.Base(primary_hue=gr.themes.colors.blue)

# Header links as (url, label) pairs, rendered as a row of buttons under the title.
_HEADER_LINKS = (
    ("https://huggingface.co/collections/ibm-granite/granite-embedding", "📚 Model Collection"),
    ("https://github.com/ibm-granite/granite-embedding-models", "🔗 Repository"),
    ("https://arxiv.org/abs/2508.21085", "📄 Paper"),
)

# Inline CSS shared by every header link.
_LINK_STYLE = (
    "padding: 4px 8px; border: 1px solid #e0e0e0; border-radius: 4px; "
    "text-decoration: none; color: inherit; font-size: 14px;"
)


def _link_button(url: str, label: str) -> str:
    """Return the HTML anchor for one header link; it opens in a new tab and highlights on hover."""
    return (
        f"<a href='{url}' target='_blank' style='{_LINK_STYLE}' "
        "onmouseover=\"this.style.backgroundColor='#f0f0f0'\" "
        "onmouseout=\"this.style.backgroundColor='transparent'\">"
        f"{label}</a>"
    )


# NOTE(review): the nesting of the layout containers below was reconstructed from
# statement order; confirm which components belong inside the gr.Group.
with gr.Blocks(title="Granite Embedding Demo", theme=theme) as demo:
    gr.HTML(
        "<div style='text-align: center; margin-top: 5px; margin-bottom: 0;'>"
        "<h1 style='font-size: 48px; margin: 0;'>Granite Embedding Demo</h1></div>"
    )
    gr.HTML(
        "<div style='display: flex; justify-content: center; gap: 16px; margin: -5px 0 15px 0;'>"
        + "".join(_link_button(url, label) for url, label in _HEADER_LINKS)
        + "</div>"
    )
    with gr.Accordion("ℹ️ Learn more", open=False):  # noqa: RUF001
        gr.Markdown(
            "**What are embedding models?**\n\n"
            "Embedding models convert words and sentences into numeric "
            "representations, letting us measure meaning rather than just matching text. Search engines use "
            "embeddings to show you results about the same topic, even when the exact words differ.\n\n"
            "**What are reranker models?**\n\n"
            "Rerankers apply a second, more fine-tuned scoring step to re-rank "
            "the top results from embeddings. They don't mean embeddings are inaccurate—they just add another layer "
            "of precision to surface the best answers.\n\n"
            "**What are top-N candidates?**\n\n"
            '"Top-N candidates" are the initial set of results retrieved using '
            "the embedding model. Out of a large collection, we quickly pick the N most similar items to your "
            "query as likely matches. Choosing a larger N means we consider more possibilities, which can improve "
            "accuracy (especially when using a re-ranker), but it also makes things slower. A smaller N is faster, "
            "but risks missing some relevant results.\n\n"
            "**What are top-K results?**\n\n"
            '"Top-K results" are the final results we return to the user. After '
            "optionally re-ranking the top-N candidates, we select the best K items to display. Choosing a smaller "
            "K keeps results focused and easy to explore, while a larger K shows more options but may include less "
            "relevant items. In short: N controls how much we consider, K controls how much you see."
        )
    with gr.Group():
        with gr.Row():
            model_name = gr.Dropdown(
                choices=MODEL_IDS,
                # Read the default from MODELS so the repository id is not duplicated here.
                value=MODELS["granite-embedding-small-english-r2"],
                label="Embedding Model",
                scale=2,
            )
            use_reranker = gr.Checkbox(
                value=True,
                label="Enable reranking",
                info="Powered by granite-embedding-reranker-english-r2",
                scale=1,
            )
        # Default passages: English and Spanish sentences, one per line.
        example_passages = (
            "Radioactivity was discovered in 1896 by Becquerel and independently by Marie Curie, "
            "while working with phosphorescent materials.\n"
            "Albert Einstein was a theoretical physicist known for his theory of relativity.\n"
            "Marie Curie discovered the elements polonium and radium through her research.\n"
            "The sky appears blue due to Rayleigh scattering of sunlight.\n"
            "Isaac Newton formulated the laws of motion and gravitation.\n"
            "A mitochondrion is an organelle responsible for energy production in cells.\n"
            "Python is a popular programming language for data science.\n"
            "La teoría de la relatividad cambió nuestra comprensión del universo.\n"
            "Marie Curie fue la primera mujer en ganar un Premio Nobel.\n"
            "El agua es esencial para la vida en la Tierra."
        )
        with gr.Row():
            query = gr.Textbox(label="Query", value="Who discovered radioactive elements?", lines=10, scale=1)
            passages = gr.Textbox(label="Passages (one per line)", lines=10, value=example_passages, scale=2)
        with gr.Row():
            retrieve_n = gr.Slider(1, 100, value=20, step=1, label="Retrieve top-N candidates (embeddings)")
            top_k = gr.Slider(1, 20, value=5, step=1, label="Show top-K results (final)")
    run_btn = gr.Button("Rank passages", variant="primary")
    results = gr.Dataframe(
        headers=["rank", "passage", "embed_score", "rerank_score"],
        datatype=["number", "str", "number", "number"],
        label="Ranked results",
        wrap=False,
        interactive=False,
        column_widths=["8%", "65%", "13.5%", "13.5%"],
    )

    def update_reranker(model: str) -> dict[str, Any]:
        """Return a checkbox update that follows the selected embedding model.

        When the model id contains "multilingual" the checkbox is unchecked and
        locked; otherwise it is checked and editable. Presumably this is because
        the reranker named in the checkbox info is English-only.
        """
        is_multilingual = "multilingual" in model
        return gr.update(value=not is_multilingual, interactive=not is_multilingual)

    # Keep the checkbox in sync both when the model changes and on initial page load.
    model_name.change(
        fn=update_reranker,
        inputs=[model_name],
        outputs=[use_reranker],
    )
    demo.load(
        fn=update_reranker,
        inputs=[model_name],
        outputs=[use_reranker],
    )
    # Argument order must match the rank_passages signature.
    run_btn.click(
        fn=rank_passages,
        inputs=[query, passages, top_k, retrieve_n, use_reranker, model_name],
        outputs=[results],
    )
# Start the web server only when this file is executed as a script, so that
# importing the module (for tests or tooling) does not launch the app.
if __name__ == "__main__":
    demo.launch()