# Hugging Face Space: Qwen3-Reranker-8B text reranker demo.
| import json | |
| import time | |
| from typing import List | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Hugging Face model id of the reranker checkpoint.
MODEL_NAME = "Qwen/Qwen3-Reranker-8B"
# Task instruction embedded in every <Instruct> field of the scoring prompt.
INSTRUCTION = "Given a web search query, retrieve relevant passages that answer the query"

# Lazy load - model loaded on first GPU call to avoid CPU OOM
_model = None      # populated by _get_model() on first use
_tokenizer = None  # populated by _get_model() on first use
def _get_model():
    """Return the (model, tokenizer) pair, loading them on first call.

    The 8B checkpoint is only materialized when a request actually needs
    it; subsequent calls return the cached module-level objects.
    """
    global _model, _tokenizer
    if _model is not None:
        return _model, _tokenizer

    _tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME, padding_side="left", trust_remote_code=True
    )
    loaded = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    loaded.eval()
    _model = loaded
    return _model, _tokenizer
# Fixed chat-template wrapper placed around each formatted query/document
# pair; the system prompt constrains the model to answer "yes" or "no".
prefix = (
    "<|im_start|>system\n"
    "Judge whether the Document meets the requirements based on the Query and the Instruct provided. "
    'Note that the answer can only be "yes" or "no".'
    "<|im_end|>\n<|im_start|>user\n"
)
# Assistant-turn opener with an empty <think> block (non-thinking mode),
# so the next generated token is directly the yes/no judgment.
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
def _format_pair(query: str, doc: str) -> str:
    """Build the <Instruct>/<Query>/<Document> prompt body for one pair."""
    fields = (
        f"<Instruct>: {INSTRUCTION}",
        f"<Query>: {query}",
        f"<Document>: {doc}",
    )
    return "\n".join(fields)
@spaces.GPU  # ZeroGPU: allocate a GPU for the duration of this call (also triggers the lazy model load on-device).
def rerank(query: str, passages_text: str, top_k: int):
    """Score each passage against *query* with Qwen3-Reranker-8B.

    Args:
        query: Search query text.
        passages_text: Candidate passages, one per line; blank lines ignored.
        top_k: Number of top results to keep (clamped to 1..50 and to the
            number of passages supplied).

    Returns:
        Tuple of (markdown results table, JSON string of ranked results,
        human-readable timing summary).

    Raises:
        gr.Error: If the query is empty or no non-blank passages are given.
    """
    start = time.perf_counter()

    # Validate inputs BEFORE touching the model: rejecting an empty request
    # must not pay the cost of loading an 8B checkpoint.
    query = (query or "").strip()
    passages = [ln.strip() for ln in (passages_text or "").splitlines() if ln.strip()]
    if not query:
        raise gr.Error("Please provide a query.")
    if not passages:
        raise gr.Error("Please provide at least one passage (one per line).")
    top_k = max(1, min(int(top_k), 50, len(passages)))

    model, tokenizer = _get_model()
    # Relevance is read off the next-token logits for "yes" vs "no".
    token_true_id = tokenizer.convert_tokens_to_ids("yes")
    token_false_id = tokenizer.convert_tokens_to_ids("no")
    prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
    suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)
    max_length = 8192

    pairs = [_format_pair(query, p) for p in passages]
    # Tokenize pair bodies only, reserving room for the fixed chat wrapper.
    inputs = tokenizer(
        pairs,
        padding=False,
        truncation="longest_first",
        return_attention_mask=False,
        max_length=max_length - len(prefix_tokens) - len(suffix_tokens),
    )
    for i, ids in enumerate(inputs["input_ids"]):
        inputs["input_ids"][i] = prefix_tokens + ids + suffix_tokens
    inputs = tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=max_length)
    for key in inputs:
        inputs[key] = inputs[key].to(model.device)

    with torch.no_grad():
        # Only the final position matters: the model's next token is the
        # yes/no judgment.
        logits = model(**inputs).logits[:, -1, :]
        true_vec = logits[:, token_true_id]
        false_vec = logits[:, token_false_id]
        # Two-way softmax over [no, yes]; score = P(yes).
        scores_2way = torch.stack([false_vec, true_vec], dim=1)
        scores_2way = torch.nn.functional.log_softmax(scores_2way, dim=1)
        scores = scores_2way[:, 1].exp().tolist()

    results = [
        {"index": idx, "score": round(float(score), 6), "text": passage[:500]}
        for idx, (passage, score) in enumerate(zip(passages, scores))
    ]
    results.sort(key=lambda x: x["score"], reverse=True)
    results = results[:top_k]

    elapsed = time.perf_counter() - start
    lines = ["| Rank | Score | Passage |", "|---:|---:|---|"]
    for rank, r in enumerate(results, 1):
        # Escape pipes and newlines so passage text cannot break the table.
        safe = r["text"][:120].replace("|", "\\|").replace("\n", " ")
        lines.append(f"| {rank} | {r['score']:.4f} | {safe} |")
    json_out = json.dumps(results, ensure_ascii=False, indent=2)
    summary = f"Reranked {len(passages)} passages in {elapsed:.2f}s. Model: Qwen3-Reranker-8B"
    return "\n".join(lines), json_out, summary
# --- Gradio UI: wire the rerank() function to a simple Blocks layout. ---
with gr.Blocks(theme=gr.themes.Soft(), title="Text Reranker") as demo:
    gr.Markdown("# π Qwen3-Reranker-8B Text Reranker")
    gr.Markdown(
        "Rerank text passages by relevance to your query. "
        "Uses the yes/no logit scoring approach from Qwen3-Reranker-8B (MTEB-R English: 69.02)."
    )
    # Inputs: free-text query, newline-separated passages, top-k slider.
    query = gr.Textbox(label="Query", placeholder="Enter your search query...", lines=2)
    passages = gr.Textbox(
        label="Passages (one per line)",
        placeholder="Paste passages here, one per line...",
        lines=10,
    )
    # Slider bounds mirror the clamp applied inside rerank().
    top_k = gr.Slider(minimum=1, maximum=50, value=5, step=1, label="Top-K")
    run_btn = gr.Button("Rerank", variant="primary")
    # Outputs: markdown table, raw JSON, and a timing/info line.
    output_md = gr.Markdown(label="Ranked Results")
    output_json = gr.Code(label="JSON Output", language="json")
    timing = gr.Textbox(label="Info", interactive=False)
    # api_name exposes this handler as a named endpoint on the Gradio API.
    run_btn.click(
        fn=rerank,
        inputs=[query, passages, top_k],
        outputs=[output_md, output_json, timing],
        api_name="rerank_api",
    )
    gr.Markdown(
        "Built by [Xavier Fuentes](https://huggingface.co/xavier-fuentes) @ "
        "[AI Enablement Academy](https://enablement.academy) | "
        "[Buy me a coffee β](https://ko-fi.com/xavierfuentes)"
    )
# queue() serializes GPU requests; required for well-behaved ZeroGPU Spaces.
if __name__ == "__main__":
    demo.queue().launch()