# Hugging Face Spaces status: Running on Zero
| """Interactive demo for IBM Granite Embedding R2 models.""" | |
| from typing import Any | |
| import gradio as gr | |
| import numpy as np | |
| import spaces | |
| from sentence_transformers import CrossEncoder, SentenceTransformer | |
# Short model name -> Hugging Face Hub repository ID for each embedding model.
MODELS: dict[str, str] = {
    "granite-embedding-small-english-r2": "ibm-granite/granite-embedding-small-english-r2",
    "granite-embedding-97m-multilingual-r2": "ibm-granite/granite-embedding-97m-multilingual-r2",
}
# Repository IDs in declaration order; these are the values the model dropdown offers.
MODEL_IDS: list[str] = [*MODELS.values()]
# Cross-encoder used for the optional second-stage reranking.
RERANK_MODEL_ID: str = "ibm-granite/granite-embedding-reranker-english-r2"

# Lazily populated model caches; filled in by get_embedder() / get_reranker().
embedder = None
current_embed_model = None  # ID of the model currently held in `embedder`
reranker = None
def get_embedder(model_id: str) -> Any:  # noqa: ANN401
    """Return a SentenceTransformer for ``model_id``, reusing the cached one when possible.

    Only a single embedding model is kept at a time: asking for a different
    ``model_id`` than the one currently cached replaces the cached instance.

    Args:
        model_id: Identifier passed straight to ``SentenceTransformer``.

    Returns:
        The (possibly freshly loaded) embedding model.
    """
    global embedder, current_embed_model

    # Fast path: the requested model is the one already held in the cache.
    if embedder is not None and current_embed_model == model_id:
        return embedder

    embedder = SentenceTransformer(model_id)
    current_embed_model = model_id
    return embedder
def get_reranker() -> CrossEncoder:
    """Return the shared cross-encoder reranker, creating it on first use.

    Returns:
        The ``CrossEncoder`` built from ``RERANK_MODEL_ID``; the same instance
        is handed back on every later call.
    """
    global reranker

    if reranker is not None:
        return reranker

    reranker = CrossEncoder(RERANK_MODEL_ID)
    return reranker
def cos_sim_matrix(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Compute the cosine similarity matrix between two sets of embeddings.

    Rows are scaled to unit length before the dot product, so the result is a
    true cosine similarity even when the inputs are not already normalized.
    For inputs that are already unit-norm the result equals ``a @ b.T`` (up to
    floating-point rounding).

    Args:
        a: Embeddings with one vector per row.
        b: Embeddings with one vector per row, same vector length as ``a``.

    Returns:
        Matrix whose entry ``[i, j]`` is the cosine similarity between row
        ``i`` of ``a`` and row ``j`` of ``b``. All-zero rows yield 0.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    a_norm = np.linalg.norm(a, axis=-1, keepdims=True)
    b_norm = np.linalg.norm(b, axis=-1, keepdims=True)
    # Substitute 1 for zero norms so all-zero rows give similarity 0 rather
    # than a division-by-zero NaN.
    a_unit = a / np.where(a_norm == 0, 1.0, a_norm)
    b_unit = b / np.where(b_norm == 0, 1.0, b_norm)
    return a_unit @ b_unit.T
def rank_passages(
    query: str,
    passages_text: str,
    top_k: int,
    retrieve_n: int,
    use_reranker: bool,
    model_name: str,
) -> list[list[Any]]:
    """Rank passages against a query with embeddings and, optionally, a reranker.

    Args:
        query: Search query; surrounding whitespace is ignored.
        passages_text: Candidate passages, one per line; blank lines are dropped.
        top_k: Number of final rows to return (clamped to at least 1 and at
            most the number of candidates).
        retrieve_n: Number of embedding-ranked candidates to keep (clamped to
            at least 1 and at most the number of passages).
        use_reranker: When true, reorder the candidates by cross-encoder score.
        model_name: Embedding model identifier handed to ``get_embedder``.

    Returns:
        Rows of ``[rank, passage, embed_score, rerank_score]`` in final order.
        ``rerank_score`` is ``None`` when reranking is off. An empty list is
        returned when the query or the passage list is empty.
    """
    cleaned_query = (query or "").strip()
    stripped_lines = (line.strip() for line in (passages_text or "").splitlines())
    docs = [line for line in stripped_lines if line]
    if not (cleaned_query and docs):
        return []

    # Stage 1: embedding retrieval over every passage.
    model = get_embedder(model_name)
    query_vecs = model.encode([cleaned_query], normalize_embeddings=True)
    doc_vecs = model.encode(docs, normalize_embeddings=True)
    embed_scores = cos_sim_matrix(query_vecs, doc_vecs)[0]

    n_candidates = int(max(1, min(retrieve_n, len(docs))))
    candidates = np.argsort(-embed_scores)[:n_candidates].tolist()

    # Stage 2 (optional): reorder the candidates by cross-encoder score.
    if use_reranker:
        cross_scores = get_reranker().predict([(cleaned_query, docs[doc_id]) for doc_id in candidates])
        reordered = []
        rerank_by_doc = {}
        for pos in np.argsort(-cross_scores):
            doc_id = candidates[pos]
            reordered.append(doc_id)
            rerank_by_doc[doc_id] = float(cross_scores[pos])
        candidates = reordered
    else:
        rerank_by_doc = dict.fromkeys(candidates, None)

    n_shown = int(max(1, min(top_k, len(candidates))))
    return [
        [position, docs[doc_id], float(embed_scores[doc_id]), rerank_by_doc.get(doc_id)]
        for position, doc_id in enumerate(candidates[:n_shown], start=1)
    ]
def update_reranker(model: str) -> dict[str, Any]:
    """Tick and enable the reranker checkbox, or untick and disable it for multilingual models.

    Args:
        model: Currently selected embedding model identifier.

    Returns:
        A ``gr.update`` payload for the checkbox. Both ``value`` and
        ``interactive`` are false when ``model`` contains "multilingual"
        (presumably because the reranker is English-only) and true otherwise.
    """
    reranking_allowed = "multilingual" not in model
    return gr.update(value=reranking_allowed, interactive=reranking_allowed)


# Default contents of the passages box: one passage per line.
example_passages = (
    "Radioactivity was discovered in 1896 by Becquerel and independently by Marie Curie, "
    "while working with phosphorescent materials.\n"
    "Albert Einstein was a theoretical physicist known for his theory of relativity.\n"
    "Marie Curie discovered the elements polonium and radium through her research.\n"
    "The sky appears blue due to Rayleigh scattering of sunlight.\n"
    "Isaac Newton formulated the laws of motion and gravitation.\n"
    "A mitochondrion is an organelle responsible for energy production in cells.\n"
    "Python is a popular programming language for data science.\n"
    "La teoría de la relatividad cambió nuestra comprensión del universo.\n"
    "Marie Curie fue la primera mujer en ganar un Premio Nobel.\n"
    "El agua es esencial para la vida en la Tierra."
)

theme = gr.themes.Base(primary_hue=gr.themes.colors.blue)

# NOTE(review): the nesting of the layout containers below was reconstructed
# from a copy of this file that had lost its indentation. In particular, the
# extent of gr.Group() is an assumption -- confirm against the rendered layout.
with gr.Blocks(title="Granite Embedding Demo", theme=theme) as demo:
    # Page title.
    gr.HTML(
        "<div style='text-align: center; margin-top: 5px; margin-bottom: 0;'><h1 style='font-size: 48px; margin: 0;'>Granite Embedding Demo</h1></div>"  # noqa: E501
    )
    # Link bar: model collection, source repository, paper.
    gr.HTML(
        "<div style='display: flex; justify-content: center; gap: 16px; margin: -5px 0 15px 0;'>"
        "<a href='https://huggingface.co/collections/ibm-granite/granite-embedding' target='_blank' style='padding: 4px 8px; border: 1px solid #e0e0e0; border-radius: 4px; text-decoration: none; color: inherit; font-size: 14px;' onmouseover=\"this.style.backgroundColor='#f0f0f0'\" onmouseout=\"this.style.backgroundColor='transparent'\">📚 Model Collection</a>"  # noqa: E501
        "<a href='https://github.com/ibm-granite/granite-embedding-models' target='_blank' style='padding: 4px 8px; border: 1px solid #e0e0e0; border-radius: 4px; text-decoration: none; color: inherit; font-size: 14px;' onmouseover=\"this.style.backgroundColor='#f0f0f0'\" onmouseout=\"this.style.backgroundColor='transparent'\">🔗 Repository</a>"  # noqa: E501
        "<a href='https://arxiv.org/abs/2508.21085' target='_blank' style='padding: 4px 8px; border: 1px solid #e0e0e0; border-radius: 4px; text-decoration: none; color: inherit; font-size: 14px;' onmouseover=\"this.style.backgroundColor='#f0f0f0'\" onmouseout=\"this.style.backgroundColor='transparent'\">📄 Paper</a>"  # noqa: E501
        "</div>"
    )

    # Collapsed-by-default explainer for the concepts used in the demo.
    with gr.Accordion("ℹ️ Learn more", open=False):  # noqa: RUF001
        gr.Markdown(
            "**What are embedding models?**\n\n"
            "Embedding models convert words and sentences into numeric "
            "representations, letting us measure meaning rather than just matching text. Search engines use "
            "embeddings to show you results about the same topic, even when the exact words differ.\n\n"
            "**What are reranker models?**\n\n"
            "Rerankers apply a second, more fine-tuned scoring step to re-rank "
            "the top results from embeddings. They don't mean embeddings are inaccurate—they just add another layer "
            "of precision to surface the best answers.\n\n"
            "**What are top-N candidates?**\n\n"
            '"Top-N candidates" are the initial set of results retrieved using '
            "the embedding model. Out of a large collection, we quickly pick the N most similar items to your "
            "query as likely matches. Choosing a larger N means we consider more possibilities, which can improve "
            "accuracy (especially when using a re-ranker), but it also makes things slower. A smaller N is faster, "
            "but risks missing some relevant results.\n\n"
            "**What are top-K results?**\n\n"
            '"Top-K results" are the final results we return to the user. After '
            "optionally re-ranking the top-N candidates, we select the best K items to display. Choosing a smaller "
            "K keeps results focused and easy to explore, while a larger K shows more options but may include less "
            "relevant items. In short: N controls how much we consider, K controls how much you see."
        )

    # Model selection and reranker toggle.
    with gr.Group():
        with gr.Row():
            model_name = gr.Dropdown(
                choices=MODEL_IDS,
                value="ibm-granite/granite-embedding-small-english-r2",
                label="Embedding Model",
                scale=2,
            )
            use_reranker = gr.Checkbox(
                value=True, label="Enable reranking", info="Powered by granite-embedding-reranker-english-r2", scale=1
            )

    # Query on the left, passages (twice as wide) on the right.
    with gr.Row():
        query = gr.Textbox(label="Query", value="Who discovered radioactive elements?", lines=10, scale=1)
        passages = gr.Textbox(label="Passages (one per line)", lines=10, value=example_passages, scale=2)

    # N = candidates kept after the embedding stage, K = rows finally shown.
    with gr.Row():
        retrieve_n = gr.Slider(
            minimum=1, maximum=100, value=20, step=1, label="Retrieve top-N candidates (embeddings)"
        )
        top_k = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Show top-K results (final)")

    run_btn = gr.Button("Rank passages", variant="primary")

    results = gr.Dataframe(
        headers=["rank", "passage", "embed_score", "rerank_score"],
        datatype=["number", "str", "number", "number"],
        label="Ranked results",
        wrap=False,
        interactive=False,
        column_widths=["8%", "65%", "13.5%", "13.5%"],
    )

    # Keep the reranker checkbox in sync with the selected model, both when the
    # selection changes and once when the page first loads.
    for register in (model_name.change, demo.load):
        register(fn=update_reranker, inputs=[model_name], outputs=[use_reranker])

    # Input order must match the parameter order of rank_passages.
    run_btn.click(
        fn=rank_passages,
        inputs=[query, passages, top_k, retrieve_n, use_reranker, model_name],
        outputs=[results],
    )

demo.launch()