"""Interactive demo for IBM Granite Embedding R2 models."""

from functools import lru_cache
from typing import Any

import gradio as gr
import numpy as np
import spaces
from sentence_transformers import CrossEncoder, SentenceTransformer

# Display name -> Hugging Face model id for the selectable embedding models.
MODELS = {
    "granite-embedding-small-english-r2": "ibm-granite/granite-embedding-small-english-r2",
    "granite-embedding-97m-multilingual-r2": "ibm-granite/granite-embedding-97m-multilingual-r2",
}
MODEL_IDS = list(MODELS.values())
RERANK_MODEL_ID = "ibm-granite/granite-embedding-reranker-english-r2"


@lru_cache(maxsize=len(MODEL_IDS))
def get_embedder(model_id: str) -> Any:  # noqa: ANN401
    """Load (once) and return the SentenceTransformer for ``model_id``.

    Every allowed model stays cached after its first load, so switching models
    in the UI no longer reloads weights each time.

    Raises:
        ValueError: If ``model_id`` is not one of ``MODEL_IDS``. The function
            backs a public Gradio endpoint, so arbitrary ids must not trigger
            arbitrary model downloads.
    """
    if model_id not in MODEL_IDS:
        msg = f"Unknown embedding model: {model_id!r}"
        raise ValueError(msg)
    return SentenceTransformer(model_id)


@lru_cache(maxsize=1)
def get_reranker() -> CrossEncoder:
    """Load (once) and return the shared cross-encoder reranker."""
    return CrossEncoder(RERANK_MODEL_ID)


def _l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale each row of ``x`` to unit L2 norm; all-zero rows are left as zeros."""
    x = np.asarray(x)
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    # Dividing zero rows by 1 keeps them at 0 instead of producing NaN.
    return x / np.where(norms == 0, 1, norms)


def cos_sim_matrix(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Compute the cosine-similarity matrix between the rows of ``a`` and ``b``.

    Inputs may be raw or already L2-normalized embeddings; rows are normalized
    here, so the result is a true cosine similarity either way.

    Returns:
        Array of shape ``(len(a), len(b))`` for 2-D inputs.
    """
    return _l2_normalize(a) @ _l2_normalize(b).T


@spaces.GPU
def rank_passages(
    query: str,
    passages_text: str,
    top_k: int,
    retrieve_n: int,
    use_reranker: bool,
    model_name: str,
) -> list[list[Any]]:
    """Rank passages against a query with embeddings, optionally reranking the shortlist.

    Args:
        query: Free-text query; surrounding whitespace is ignored.
        passages_text: Candidate passages, one per line; blank lines are skipped.
        top_k: Number of rows to return, clamped to ``[1, number of candidates]``.
        retrieve_n: Number of embedding-ranked candidates kept for the final
            stage, clamped to ``[1, number of passages]``.
        use_reranker: When True, reorder the candidates with the cross-encoder.
        model_name: Hugging Face id of the embedding model (one of ``MODEL_IDS``).

    Returns:
        Rows of ``[rank, passage, embed_score, rerank_score]``. ``rerank_score``
        is ``None`` when reranking is disabled. The result is empty when the
        query or the passage list is empty.
    """
    query = (query or "").strip()
    passages = [p.strip() for p in (passages_text or "").splitlines() if p.strip()]
    if not query or not passages:
        return []

    embedder_model = get_embedder(model_name)
    q_emb = embedder_model.encode([query], normalize_embeddings=True)
    p_emb = embedder_model.encode(passages, normalize_embeddings=True)
    sims = cos_sim_matrix(q_emb, p_emb)[0]

    # Stage 1: embedding retrieval. A stable sort keeps tied passages in input order.
    order = np.argsort(-sims, kind="stable")
    retrieve_n = int(max(1, min(retrieve_n, len(passages))))
    cand_idx = order[:retrieve_n].tolist()

    # Stage 2 (optional): cross-encoder rerank of the shortlist.
    rerank_scores: dict[int, float] = {}
    if use_reranker:
        pairs = [(query, passages[i]) for i in cand_idx]
        rr_scores = np.asarray(get_reranker().predict(pairs), dtype=float)
        rerank_scores = {idx: float(rr_scores[j]) for j, idx in enumerate(cand_idx)}
        cand_idx = [cand_idx[j] for j in np.argsort(-rr_scores, kind="stable")]

    top_k = int(max(1, min(top_k, len(cand_idx))))
    return [
        [rank, passages[idx], float(sims[idx]), rerank_scores.get(idx)]
        for rank, idx in enumerate(cand_idx[:top_k], start=1)
    ]


theme = gr.themes.Base(primary_hue=gr.themes.colors.blue)

with gr.Blocks(title="Granite Embedding Demo", theme=theme) as demo:
    gr.HTML(
        "<div style='text-align: center;'><h1>Granite Embedding Demo</h1></div>",
    )
    # NOTE(review): confirm these link targets against the IBM Granite model cards.
    gr.HTML(
        "<div style='display: flex; justify-content: center; gap: 1.5rem;'>"
        "<a href='https://huggingface.co/collections/ibm-granite/granite-embedding-models-6750b30c802c1926a35550bb' target='_blank'>📚 Model Collection</a>"  # noqa: E501
        "<a href='https://github.com/ibm-granite/granite-embedding-models' target='_blank'>🔗 Repository</a>"  # noqa: E501
        "<a href='https://arxiv.org/abs/2508.21085' target='_blank'>📄 Paper</a>"  # noqa: E501
        "</div>",
    )
    with gr.Accordion("ℹ️ Learn more", open=False):  # noqa: RUF001
        gr.Markdown(
            "**What are embedding models?**\n\n"
            "Embedding models convert words and sentences into numeric "
            "representations, letting us measure meaning rather than just matching text. Search engines use "
            "embeddings to show you results about the same topic, even when the exact words differ.\n\n"
            "**What are reranker models?**\n\n"
            "Rerankers apply a second, more fine-tuned scoring step to re-rank "
            "the top results from embeddings. They don't mean embeddings are inaccurate—they just add another layer "
            "of precision to surface the best answers.\n\n"
            "**What are top-N candidates?**\n\n"
            '"Top-N candidates" are the initial set of results retrieved using '
            "the embedding model. Out of a large collection, we quickly pick the N most similar items to your "
            "query as likely matches. Choosing a larger N means we consider more possibilities, which can improve "
            "accuracy (especially when using a re-ranker), but it also makes things slower. A smaller N is faster, "
            "but risks missing some relevant results.\n\n"
            "**What are top-K results?**\n\n"
            '"Top-K results" are the final results we return to the user. After '
            "optionally re-ranking the top-N candidates, we select the best K items to display. Choosing a smaller "
            "K keeps results focused and easy to explore, while a larger K shows more options but may include less "
            "relevant items. In short: N controls how much we consider, K controls how much you see."
        )

    with gr.Group():
        with gr.Row():
            model_name = gr.Dropdown(
                choices=MODEL_IDS,
                value="ibm-granite/granite-embedding-small-english-r2",
                label="Embedding Model",
                scale=2,
            )
            use_reranker = gr.Checkbox(
                value=True, label="Enable reranking", info="Powered by granite-embedding-reranker-english-r2", scale=1
            )

    example_passages = (
        "Radioactivity was discovered in 1896 by Becquerel and independently by Marie Curie, "
        "while working with phosphorescent materials.\n"
        "Albert Einstein was a theoretical physicist known for his theory of relativity.\n"
        "Marie Curie discovered the elements polonium and radium through her research.\n"
        "The sky appears blue due to Rayleigh scattering of sunlight.\n"
        "Isaac Newton formulated the laws of motion and gravitation.\n"
        "A mitochondrion is an organelle responsible for energy production in cells.\n"
        "Python is a popular programming language for data science.\n"
        "La teoría de la relatividad cambió nuestra comprensión del universo.\n"
        "Marie Curie fue la primera mujer en ganar un Premio Nobel.\n"
        "El agua es esencial para la vida en la Tierra."
    )

    with gr.Row():
        query = gr.Textbox(label="Query", value="Who discovered radioactive elements?", lines=10, scale=1)
        passages = gr.Textbox(label="Passages (one per line)", lines=10, value=example_passages, scale=2)

    with gr.Row():
        retrieve_n = gr.Slider(1, 100, value=20, step=1, label="Retrieve top-N candidates (embeddings)")
        top_k = gr.Slider(1, 20, value=5, step=1, label="Show top-K results (final)")

    run_btn = gr.Button("Rank passages", variant="primary")

    results = gr.Dataframe(
        headers=["rank", "passage", "embed_score", "rerank_score"],
        datatype=["number", "str", "number", "number"],
        label="Ranked results",
        wrap=False,
        interactive=False,
        column_widths=["8%", "65%", "13.5%", "13.5%"],
    )

    def update_reranker(model: str) -> dict[str, Any]:
        """Enable the reranker checkbox only for non-multilingual embedding models.

        The reranker is ``RERANK_MODEL_ID`` (an English model), so for model ids
        containing "multilingual" the checkbox is unchecked and locked.
        """
        is_multilingual = "multilingual" in model
        return gr.update(value=not is_multilingual, interactive=not is_multilingual)

    # Keep the checkbox in sync on model change and on initial page load.
    model_name.change(
        fn=update_reranker,
        inputs=[model_name],
        outputs=[use_reranker],
    )
    demo.load(
        fn=update_reranker,
        inputs=[model_name],
        outputs=[use_reranker],
    )

    run_btn.click(
        fn=rank_passages,
        inputs=[query, passages, top_k, retrieve_n, use_reranker, model_name],
        outputs=[results],
    )

if __name__ == "__main__":
    demo.launch()