"""
Custom handler for MonoT5 reranking on HuggingFace Inference Endpoints.

Returns relevance probability scores for query-document pairs.
"""

import math
from typing import Any, Dict, List

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer


class EndpointHandler:
    """Handler for MonoT5 relevance scoring.

    MonoT5 frames reranking as sequence-to-sequence classification: given a
    prompt of the form ``"Query: ... Document: ... Relevant:"`` it emits
    "true" or "false" as the first decoded token.  The relevance score is the
    softmax probability of "true" restricted to the {true, false} token pair.
    """

    def __init__(self, path: str = ""):
        """Load the model and tokenizer from *path* and move to GPU if available."""
        self.tokenizer = T5Tokenizer.from_pretrained(path)
        self.model = T5ForConditionalGeneration.from_pretrained(path)
        self.model.eval()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

        # First sub-token ids of "true"/"false"; MonoT5 expresses its
        # relevance judgement as one of these two tokens at decoder step 0.
        self.true_id = self.tokenizer.encode("true", add_special_tokens=False)[0]
        self.false_id = self.tokenizer.encode("false", add_special_tokens=False)[0]

        print(f"MonoT5 loaded on {self.device}")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process inference requests.

        Accepts either:
        - {"inputs": "Query: ... Document: ... Relevant:"} - single input
        - {"inputs": ["Query: ... Document: ... Relevant:", ...]} - batch
        - {"query": "...", "documents": ["...", ...]} - structured input

        Returns:
        - List of {"score": float, "label": "true"/"false"} dicts

        Raises:
        - ValueError: if the payload contains neither "inputs" nor both
          "query" and "documents".
        """
        # Structured form takes precedence: build one MonoT5 prompt per document.
        if "query" in data and "documents" in data:
            query = data["query"]
            inputs = [
                f"Query: {query} Document: {doc} Relevant:"
                for doc in data["documents"]
            ]
        elif "inputs" in data:
            inputs = data["inputs"]
        else:
            # BUG FIX: previously this fell back to `inputs = data`, so a
            # malformed payload iterated the dict's *keys* and scored them as
            # documents, returning garbage instead of a clear error.
            raise ValueError(
                'Request must contain "inputs" or both "query" and "documents".'
            )

        # Normalize a single string to a one-element batch.
        if isinstance(inputs, str):
            inputs = [inputs]

        results = []
        for input_text in inputs:
            score = self._score_single(input_text)
            results.append({
                "score": score,
                "label": "true" if score > 0.5 else "false"
            })

        return results

    def _score_single(self, input_text: str) -> float:
        """Return P(relevant) for one query-document prompt."""
        encoded = self.tokenizer(
            input_text,
            return_tensors="pt",
            max_length=512,  # T5 encoder limit used by MonoT5; longer inputs truncate
            truncation=True,
            padding=True
        ).to(self.device)

        with torch.no_grad():
            # T5 decoding starts from the pad token; a single decoder step is
            # enough to read the "true"/"false" logits.
            decoder_input_ids = torch.tensor(
                [[self.tokenizer.pad_token_id]],
                device=self.device
            )
            outputs = self.model(
                **encoded,
                decoder_input_ids=decoder_input_ids
            )
            logits = outputs.logits[0, -1, :]

        true_logit = logits[self.true_id].item()
        false_logit = logits[self.false_id].item()

        # Numerically stable two-way softmax over the true/false logits.
        max_logit = max(true_logit, false_logit)
        true_prob = math.exp(true_logit - max_logit)
        false_prob = math.exp(false_logit - max_logit)

        return true_prob / (true_prob + false_prob)
| |
|