# handler.py: custom Inference Endpoints handler for Qwen3-Reranker-4B
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


class EndpointHandler:
    def __init__(self, path=""):
        # Left padding so that position -1 of every padded sequence is the
        # final token of the prompt (logits are read at position -1 below)
        self.tokenizer = AutoTokenizer.from_pretrained(path, padding_side="left")
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        self.model.eval()
        # Qwen3-Reranker derives the relevance score from the logits of its
        # "yes" and "no" answer tokens
        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
    def format_input(self, query, document, instruction=None):
        if instruction is None:
            instruction = "Given a web search query, retrieve relevant passages that answer the query"
        prompt = (
            "<|im_start|>system\nJudge whether the Document meets the requirements "
            "based on the Query and the Instruct. Note only output a single token "
            "from [yes, no] after thinking.\n<|im_end|>\n<|im_start|>user\n"
            f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}"
            "\n<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n"
        )
        return prompt
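
    # For reference, format_input("what is the capital of France",
    # "Paris is the capital of France.") renders the default template as:
    #
    #   <|im_start|>system
    #   Judge whether the Document meets the requirements based on the Query and the Instruct. Note only output a single token from [yes, no] after thinking.
    #   <|im_end|>
    #   <|im_start|>user
    #   <Instruct>: Given a web search query, retrieve relevant passages that answer the query
    #   <Query>: what is the capital of France
    #   <Document>: Paris is the capital of France.
    #   <|im_end|>
    #   <|im_start|>assistant
    #   <think>
    #
    #   </think>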
    def __call__(self, data: dict) -> dict:
        """
        Expected input format:
        {
            "query": "what is the capital of France",
            "documents": ["Paris is the capital...", "London is the capital..."],
            "instruction": "optional custom instruction"
        }
        """
        payload = data.get("inputs", data)  # unwrap the Inference Endpoints {"inputs": ...} wrapper
        query = payload.get("query", "")
        documents = payload.get("documents", [])
        instruction = payload.get("instruction", None)
        if not query or not documents:
            return {"error": "Must provide 'query' and 'documents'"}
        prompts = [self.format_input(query, doc, instruction) for doc in documents]
        inputs = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=4096
        ).to(self.model.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Logits at the final position, where the model would emit its yes/no verdict
        logits = outputs.logits[:, -1, :]
        # Score each document as the softmax probability of "yes" over {"no", "yes"}
        true_logits = logits[:, self.token_true_id]
        false_logits = logits[:, self.token_false_id]
        scores = torch.softmax(
            torch.stack([false_logits, true_logits], dim=1), dim=1
        )[:, 1].tolist()
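        # Worked example: a document with logit("yes") = 2.0 and logit("no") = -1.0
        # scores e^2.0 / (e^2.0 + e^-1.0) ≈ 0.95, i.e. confidently relevant.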
        return {
            "scores": scores,
            # Document indices ordered from most to least relevant
            "ranking": sorted(
                range(len(documents)),
                key=lambda i: scores[i],
                reverse=True
            )
        }
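

# Minimal local smoke test, a sketch only: "./Qwen3-Reranker-4B" is an assumed
# local checkpoint path, and the printed values below are illustrative.
if __name__ == "__main__":
    handler = EndpointHandler(path="./Qwen3-Reranker-4B")
    result = handler({
        "inputs": {
            "query": "what is the capital of France",
            "documents": [
                "Paris is the capital and most populous city of France.",
                "London is the capital of the United Kingdom.",
            ],
        }
    })
    # Expected shape: {"scores": [<p_yes per document>], "ranking": [best_idx, ...]}
    print(result)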