"""Custom handler for MonoT5 reranking on HuggingFace Inference Endpoints.

Returns relevance probability scores for query-document pairs.
"""

import math
from typing import Any, Dict, List

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer


class EndpointHandler:
    """Handler for MonoT5 relevance scoring.

    MonoT5 scores a "Query: ... Document: ... Relevant:" prompt by comparing
    the first generated token's logits for the words "true" and "false" and
    normalizing them into a probability.
    """

    def __init__(self, path: str = ""):
        """Load model and tokenizer from *path*; move the model to GPU if available."""
        self.tokenizer = T5Tokenizer.from_pretrained(path)
        self.model = T5ForConditionalGeneration.from_pretrained(path)
        self.model.eval()

        # Move to GPU if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

        # First sub-token of "true"/"false" — the two classes MonoT5 decodes.
        self.true_id = self.tokenizer.encode("true", add_special_tokens=False)[0]
        self.false_id = self.tokenizer.encode("false", add_special_tokens=False)[0]

        print(f"MonoT5 loaded on {self.device}")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Process inference requests.

        Accepts either:
        - {"inputs": "Query: ... Document: ... Relevant:"}        - single input
        - {"inputs": ["Query: ... Document: ... Relevant:", ...]} - batch
        - {"query": "...", "documents": ["...", ...]}             - structured input

        Returns:
            List of {"score": float, "label": "true"/"false"} dicts, one per input.

        Raises:
            ValueError: if the payload matches none of the accepted formats.
        """
        # Structured format takes precedence over a raw "inputs" key.
        if "query" in data and "documents" in data:
            query = data["query"]
            inputs = [
                f"Query: {query} Document: {doc} Relevant:"
                for doc in data["documents"]
            ]
        else:
            inputs = data.get("inputs")
            # Ensure inputs is a list
            if isinstance(inputs, str):
                inputs = [inputs]
            # BUGFIX: previously an unrecognised payload fell back to the raw
            # dict, and the scoring loop silently scored its *keys* as
            # documents. Fail loudly instead.
            if not isinstance(inputs, list):
                raise ValueError(
                    'Expected "inputs" (str or list) or "query"+"documents" in payload'
                )

        # Score all inputs
        results = []
        for input_text in inputs:
            score = self._score_single(input_text)
            results.append({
                "score": score,
                "label": "true" if score > 0.5 else "false",
            })
        return results

    def _score_single(self, input_text: str) -> float:
        """Return P(relevant) in [0, 1] for one already-formatted prompt."""
        # Tokenize (renamed from `inputs` to avoid confusion with the
        # request-level inputs handled in __call__).
        encoded = self.tokenizer(
            input_text,
            return_tensors="pt",
            max_length=512,
            truncation=True,
            padding=True,
        ).to(self.device)

        # T5 decoding starts from the pad token; we only need the logits of
        # the first generated position.
        with torch.no_grad():
            decoder_input_ids = torch.tensor(
                [[self.tokenizer.pad_token_id]], device=self.device
            )
            outputs = self.model(
                **encoded,
                decoder_input_ids=decoder_input_ids,
            )
        logits = outputs.logits[0, -1, :]

        true_logit = logits[self.true_id].item()
        false_logit = logits[self.false_id].item()

        # Numerically stable two-way softmax over the true/false logits.
        max_logit = max(true_logit, false_logit)
        true_prob = math.exp(true_logit - max_logit)
        false_prob = math.exp(false_logit - max_logit)
        return true_prob / (true_prob + false_prob)