"""Gradio demo: rerank text passages with Qwen/Qwen3-Reranker-8B.

Scores each (query, passage) pair by the probability that the model answers
"yes" to a relevance-judgment prompt, following the official Qwen3-Reranker
usage recipe (yes/no logit scoring at the final token position).
"""

import json
import time

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "Qwen/Qwen3-Reranker-8B"
INSTRUCTION = "Given a web search query, retrieve relevant passages that answer the query"

# Lazy load - model loaded on first GPU call to avoid CPU OOM
_model = None
_tokenizer = None


def _get_model():
    """Load and cache the reranker model and tokenizer on first use.

    Returns:
        Tuple of (model, tokenizer), loaded once and reused thereafter.
    """
    global _model, _tokenizer
    if _model is None:
        # Left padding so the last ("assistant answer") token of every
        # sequence lands at the same position for batched logit extraction.
        _tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME, padding_side="left", trust_remote_code=True
        )
        _model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        _model.eval()
    return _model, _tokenizer


# Chat-template scaffolding from the official Qwen3-Reranker usage example.
prefix = (
    "<|im_start|>system\n"
    "Judge whether the Document meets the requirements based on the Query and the Instruct provided. "
    'Note that the answer can only be "yes" or "no".'
    "<|im_end|>\n<|im_start|>user\n"
)
# FIX: restore the empty <think>...</think> block required by the Qwen3 chat
# template; it had been stripped, leaving only the residual bare newlines.
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"


def _format_pair(query: str, doc: str) -> str:
    """Format one (query, document) pair in the official reranker layout.

    FIX: the ``<Instruct>:`` / ``<Query>:`` / ``<Document>:`` field labels
    had been stripped (leaving bare ``": "`` prefixes), which breaks the
    prompt format the reranker was trained on and degrades its scores.
    """
    return f"<Instruct>: {INSTRUCTION}\n<Query>: {query}\n<Document>: {doc}"


@spaces.GPU
def rerank(query: str, passages_text: str, top_k: int):
    """Rerank newline-separated passages by relevance to ``query``.

    Args:
        query: Search query text.
        passages_text: Candidate passages, one per line; blank lines ignored.
        top_k: Number of top results to return (clamped to [1, 50] and to
            the number of passages supplied).

    Returns:
        Tuple of (markdown_table, json_results, summary) strings for the UI.

    Raises:
        gr.Error: If the query is empty or no passages were provided.
    """
    start = time.perf_counter()
    model, tokenizer = _get_model()
    token_true_id = tokenizer.convert_tokens_to_ids("yes")
    token_false_id = tokenizer.convert_tokens_to_ids("no")
    prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
    suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)
    max_length = 8192

    query = (query or "").strip()
    passages = [line.strip() for line in (passages_text or "").splitlines() if line.strip()]
    if not query:
        raise gr.Error("Please provide a query.")
    if not passages:
        raise gr.Error("Please provide at least one passage (one per line).")
    top_k = max(1, min(int(top_k), 50, len(passages)))

    # Tokenize the pair bodies first, then wrap with the pre-tokenized
    # prefix/suffix so truncation can never eat the template scaffolding.
    pairs = [_format_pair(query, p) for p in passages]
    inputs = tokenizer(
        pairs,
        padding=False,
        truncation="longest_first",
        return_attention_mask=False,
        max_length=max_length - len(prefix_tokens) - len(suffix_tokens),
    )
    for i, ids in enumerate(inputs["input_ids"]):
        inputs["input_ids"][i] = prefix_tokens + ids + suffix_tokens
    inputs = tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=max_length)
    for key in inputs:
        inputs[key] = inputs[key].to(model.device)

    # Relevance = P("yes") from a 2-way softmax over the yes/no logits at the
    # final token position.
    with torch.no_grad():
        logits = model(**inputs).logits[:, -1, :]
        true_vec = logits[:, token_true_id]
        false_vec = logits[:, token_false_id]
        scores_2way = torch.stack([false_vec, true_vec], dim=1)
        scores_2way = torch.nn.functional.log_softmax(scores_2way, dim=1)
        scores = scores_2way[:, 1].exp().tolist()

    results = []
    for idx, (passage, score) in enumerate(zip(passages, scores)):
        results.append({"index": idx, "score": round(float(score), 6), "text": passage[:500]})
    results.sort(key=lambda x: x["score"], reverse=True)
    results = results[:top_k]

    elapsed = time.perf_counter() - start
    lines = ["| Rank | Score | Passage |", "|---:|---:|---|"]
    for i, r in enumerate(results, 1):
        # Escape pipes and flatten newlines so the passage fits a table cell.
        safe = r["text"][:120].replace("|", "\\|").replace("\n", " ")
        lines.append(f"| {i} | {r['score']:.4f} | {safe} |")
    json_out = json.dumps(results, ensure_ascii=False, indent=2)
    summary = f"Reranked {len(passages)} passages in {elapsed:.2f}s. Model: Qwen3-Reranker-8B"
    return "\n".join(lines), json_out, summary


with gr.Blocks(theme=gr.themes.Soft(), title="Text Reranker") as demo:
    gr.Markdown("# 📄 Qwen3-Reranker-8B Text Reranker")
    gr.Markdown(
        "Rerank text passages by relevance to your query. "
        "Uses the yes/no logit scoring approach from Qwen3-Reranker-8B (MTEB-R English: 69.02)."
    )
    query = gr.Textbox(label="Query", placeholder="Enter your search query...", lines=2)
    passages = gr.Textbox(
        label="Passages (one per line)",
        placeholder="Paste passages here, one per line...",
        lines=10,
    )
    top_k = gr.Slider(minimum=1, maximum=50, value=5, step=1, label="Top-K")
    run_btn = gr.Button("Rerank", variant="primary")
    output_md = gr.Markdown(label="Ranked Results")
    output_json = gr.Code(label="JSON Output", language="json")
    timing = gr.Textbox(label="Info", interactive=False)
    run_btn.click(
        fn=rerank,
        inputs=[query, passages, top_k],
        outputs=[output_md, output_json, timing],
        api_name="rerank_api",
    )
    gr.Markdown(
        "Built by [Xavier Fuentes](https://huggingface.co/xavier-fuentes) @ "
        "[AI Enablement Academy](https://enablement.academy) | "
        "[Buy me a coffee ☕](https://ko-fi.com/xavierfuentes)"
    )

if __name__ == "__main__":
    demo.queue().launch()