# reranker / app.py
# Hugging Face Space by xavier-fuentes (uploaded via huggingface_hub, revision 9b27614)
import json
import time
from typing import List
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hugging Face Hub id of the reranker checkpoint.
MODEL_NAME = "Qwen/Qwen3-Reranker-8B"
# Task instruction embedded into every query/document pair (see _format_pair).
INSTRUCTION = "Given a web search query, retrieve relevant passages that answer the query"
# Lazy load - model loaded on first GPU call to avoid CPU OOM
# Module-level singletons populated by _get_model() on first use.
_model = None
_tokenizer = None
def _get_model():
    """Return the shared (model, tokenizer) pair, loading them on first call.

    Loading is deferred until the first request so the process does not pull
    the 8B checkpoint into memory at import time (see module note above).
    """
    global _model, _tokenizer
    if _model is not None:
        return _model, _tokenizer
    _tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME, padding_side="left", trust_remote_code=True
    )
    loaded = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    loaded.eval()
    _model = loaded
    return _model, _tokenizer
# Chat-template wrapper for each scored pair. The system turn constrains the
# model so its next token after the prompt is either "yes" or "no".
prefix = (
    "<|im_start|>system\n"
    "Judge whether the Document meets the requirements based on the Query and the Instruct provided. "
    'Note that the answer can only be "yes" or "no".'
    "<|im_end|>\n<|im_start|>user\n"
)
# Closes the user turn and opens an empty assistant <think> block, so the very
# next generated position carries the yes/no judgement logits.
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
def _format_pair(query: str, doc: str) -> str:
    """Render one (query, document) pair in the Qwen3-Reranker input format."""
    sections = (
        f"<Instruct>: {INSTRUCTION}",
        f"<Query>: {query}",
        f"<Document>: {doc}",
    )
    return "\n".join(sections)
@spaces.GPU
def rerank(query: str, passages_text: str, top_k: int):
    """Score and rank candidate passages by relevance to *query*.

    Uses the Qwen3-Reranker yes/no logit scheme: each pair is wrapped in the
    chat prefix/suffix, and the probability of "yes" at the final position
    (two-way softmax over the "yes"/"no" logits) is the relevance score.

    Args:
        query: The search query.
        passages_text: Candidate passages, one per line; blank lines ignored.
        top_k: Number of top results to keep (clamped to 1..min(50, n)).

    Returns:
        Tuple of (markdown results table, JSON string of results, summary line).

    Raises:
        gr.Error: If the query is empty or no passages were provided.
    """
    start = time.perf_counter()
    # Validate inputs BEFORE touching the model so a bad request cannot
    # trigger the expensive first-time 8B checkpoint load.
    query = (query or "").strip()
    passages = [l.strip() for l in (passages_text or "").splitlines() if l.strip()]
    if not query:
        raise gr.Error("Please provide a query.")
    if not passages:
        raise gr.Error("Please provide at least one passage (one per line).")
    top_k = max(1, min(int(top_k), 50, len(passages)))

    model, tokenizer = _get_model()
    # Token ids of the two admissible answers under the prompt's constraint.
    token_true_id = tokenizer.convert_tokens_to_ids("yes")
    token_false_id = tokenizer.convert_tokens_to_ids("no")
    prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
    suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)
    max_length = 8192

    pairs = [_format_pair(query, p) for p in passages]
    # Tokenize only the pair bodies, reserving room for the fixed prefix/suffix
    # so the final concatenated sequence never exceeds max_length.
    inputs = tokenizer(
        pairs,
        padding=False,
        truncation="longest_first",
        return_attention_mask=False,
        max_length=max_length - len(prefix_tokens) - len(suffix_tokens),
    )
    for i, ids in enumerate(inputs["input_ids"]):
        inputs["input_ids"][i] = prefix_tokens + ids + suffix_tokens
    inputs = tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=max_length)
    for key in inputs:
        inputs[key] = inputs[key].to(model.device)
    with torch.no_grad():
        # Only the last position matters: its next-token distribution is the
        # model's yes/no judgement.
        logits = model(**inputs).logits[:, -1, :]
        true_vec = logits[:, token_true_id]
        false_vec = logits[:, token_false_id]
        scores_2way = torch.stack([false_vec, true_vec], dim=1)
        scores_2way = torch.nn.functional.log_softmax(scores_2way, dim=1)
        # P("yes") under the two-way softmax -> relevance score in (0, 1).
        scores = scores_2way[:, 1].exp().tolist()

    results = []
    for idx, (passage, score) in enumerate(zip(passages, scores)):
        results.append({"index": idx, "score": round(float(score), 6), "text": passage[:500]})
    results.sort(key=lambda x: x["score"], reverse=True)
    results = results[:top_k]
    elapsed = time.perf_counter() - start

    # Markdown table; escape pipes and flatten newlines so cells stay intact.
    lines = ["| Rank | Score | Passage |", "|---:|---:|---|"]
    for i, r in enumerate(results, 1):
        safe = r["text"][:120].replace("|", "\\|").replace("\n", " ")
        lines.append(f"| {i} | {r['score']:.4f} | {safe} |")
    json_out = json.dumps(results, ensure_ascii=False, indent=2)
    summary = f"Reranked {len(passages)} passages in {elapsed:.2f}s. Model: Qwen3-Reranker-8B"
    return "\n".join(lines), json_out, summary
# --- Gradio UI -----------------------------------------------------------
# Components render in definition order inside the Blocks context.
with gr.Blocks(theme=gr.themes.Soft(), title="Text Reranker") as demo:
    gr.Markdown("# πŸ“„ Qwen3-Reranker-8B Text Reranker")
    gr.Markdown(
        "Rerank text passages by relevance to your query. "
        "Uses the yes/no logit scoring approach from Qwen3-Reranker-8B (MTEB-R English: 69.02)."
    )
    # Inputs: query, newline-separated passages, and the top-k slider.
    query = gr.Textbox(label="Query", placeholder="Enter your search query...", lines=2)
    passages = gr.Textbox(
        label="Passages (one per line)",
        placeholder="Paste passages here, one per line...",
        lines=10,
    )
    top_k = gr.Slider(minimum=1, maximum=50, value=5, step=1, label="Top-K")
    run_btn = gr.Button("Rerank", variant="primary")
    # Outputs: rendered markdown table, raw JSON, and a timing summary.
    output_md = gr.Markdown(label="Ranked Results")
    output_json = gr.Code(label="JSON Output", language="json")
    timing = gr.Textbox(label="Info", interactive=False)
    # api_name exposes this handler as a named endpoint for gradio_client calls.
    run_btn.click(
        fn=rerank,
        inputs=[query, passages, top_k],
        outputs=[output_md, output_json, timing],
        api_name="rerank_api",
    )
    gr.Markdown(
        "Built by [Xavier Fuentes](https://huggingface.co/xavier-fuentes) @ "
        "[AI Enablement Academy](https://enablement.academy) | "
        "[Buy me a coffee β˜•](https://ko-fi.com/xavierfuentes)"
    )
if __name__ == "__main__":
    # queue() enables request queuing, required for @spaces.GPU workloads.
    demo.queue().launch()