from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


class EndpointHandler:
    """Inference-endpoint handler that scores documents against a query.

    Each (query, document) pair is rendered into a chat-style prompt, run
    through a causal LM, and scored as the probability of the "yes" token
    versus the "no" token at the final position.
    """

    DEFAULT_INSTRUCTION = (
        "Given a web search query, retrieve relevant passages that answer the query"
    )

    # Upper bound, in tokens, for one fully assembled prompt.
    MAX_LENGTH = 4096

    # Fixed text placed before the per-pair body.
    _PROMPT_PREFIX = (
        "<|im_start|>system\n"
        "Judge whether the Document meets the requirements based on the Query "
        "and the Instruct. Note only output a single token from [yes, no] "
        "after thinking.\n<|im_end|>\n"
        "<|im_start|>user\n"
    )

    # Fixed text placed after the per-pair body. It ends the user turn and
    # opens the assistant turn with an empty think block, so the next token
    # is the yes/no answer.
    # NOTE(review): the field labels and the <think></think> markers were
    # absent from the previous prompt and are restored here from the published
    # Qwen3-Reranker template -- confirm against the deployed checkpoint.
    _PROMPT_SUFFIX = "\n<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"

    def __init__(self, path=""):
        """Load the tokenizer and model from ``path`` and cache token ids.

        Raises:
            ValueError: if "yes" or "no" is not a single known vocabulary
                token for the loaded tokenizer.
        """
        # Left padding keeps the last position of every row a real token, so
        # the logits at index -1 are the answer position for every row.
        self.tokenizer = AutoTokenizer.from_pretrained(path, padding_side="left")
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.model.eval()

        # Qwen3-Reranker uses a specific token to extract the relevance score
        self.token_true_id = self._require_token_id("yes")
        self.token_false_id = self._require_token_id("no")

        # Tokenize the fixed prompt parts once. They are attached after the
        # body has been truncated, so truncation can never remove them.
        self._prefix_ids = self.tokenizer.encode(
            self._PROMPT_PREFIX, add_special_tokens=False
        )
        self._suffix_ids = self.tokenizer.encode(
            self._PROMPT_SUFFIX, add_special_tokens=False
        )

    def _require_token_id(self, token):
        """Return the vocabulary id of ``token`` or raise if it is unknown."""
        token_id = self.tokenizer.convert_tokens_to_ids(token)
        if token_id is None or token_id == self.tokenizer.unk_token_id:
            raise ValueError(f"Tokenizer has no dedicated token for {token!r}")
        return token_id

    def _format_body(self, query, document, instruction=None):
        """Render the variable part of the prompt for one pair."""
        if instruction is None:
            instruction = self.DEFAULT_INSTRUCTION
        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}"

    def format_input(self, query, document, instruction=None):
        """Return the complete prompt text for one (query, document) pair.

        Args:
            query: the search query.
            document: the candidate passage to judge.
            instruction: task description; the default web-search instruction
                is used when this is None.
        """
        body = self._format_body(query, document, instruction)
        return f"{self._PROMPT_PREFIX}{body}{self._PROMPT_SUFFIX}"

    def _build_batch(self, bodies):
        """Tokenize prompt bodies and wrap each with the fixed prefix/suffix.

        Only the body is truncated. Truncating the whole prompt would cut off
        the assistant turn for long documents, and the final-position logits
        would then no longer be the yes/no answer.
        """
        body_budget = self.MAX_LENGTH - len(self._prefix_ids) - len(self._suffix_ids)
        encoded = self.tokenizer(
            bodies,
            padding=False,
            truncation=True,
            max_length=body_budget,
            return_attention_mask=False,
        )
        wrapped_ids = [
            self._prefix_ids + body_ids + self._suffix_ids
            for body_ids in encoded["input_ids"]
        ]
        # pad() builds the attention mask and pads on the side configured at
        # tokenizer load time.
        batch = self.tokenizer.pad(
            {"input_ids": wrapped_ids}, padding=True, return_tensors="pt"
        )
        return batch.to(self.model.device)

    def __call__(self, data: dict) -> dict:
        """Score and rank documents for a query.

        Expected input format (optionally nested under an "inputs" key):
        {
            "query": "what is the capital of France",
            "documents": ["Paris is the capital...", "London is the capital..."],
            "instruction": "optional custom instruction"
        }

        Returns:
            {"scores": [...], "ranking": [...]} where ``scores[i]`` is the
            "yes" probability for ``documents[i]`` and ``ranking`` lists
            document indices from most to least relevant, or
            {"error": "..."} when the query or documents are missing.
        """
        missing = {"error": "Must provide 'query' and 'documents'"}

        payload = data.get("inputs", data) if isinstance(data, dict) else None  # unwrap HF gateway nesting
        if not isinstance(payload, dict):
            return missing

        query = payload.get("query", "")
        documents = payload.get("documents", [])
        instruction = payload.get("instruction")

        # A bare string would otherwise be iterated character by character.
        if isinstance(documents, str):
            documents = [documents]

        if not query or not documents:
            return missing

        bodies = [self._format_body(query, doc, instruction) for doc in documents]
        batch = self._build_batch(bodies)

        with torch.no_grad():
            outputs = self.model(**batch)

        # Get logits for the final token position
        last_logits = outputs.logits[:, -1, :]

        # Score is the softmax probability of "yes" vs "no". The softmax runs
        # in float32 so it is not computed in the model's half precision.
        yes_no = torch.stack(
            [last_logits[:, self.token_false_id], last_logits[:, self.token_true_id]],
            dim=1,
        ).float()
        scores = torch.softmax(yes_no, dim=1)[:, 1].tolist()

        ranking = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
        return {"scores": scores, "ranking": ranking}