# monot5-base-msmarco / handler.py
# Uploaded by "pashaa" via huggingface_hub (commit 0518d49, verified).
"""
Custom handler for MonoT5 reranking on HuggingFace Inference Endpoints.
Returns relevance probability scores for query-document pairs.
"""
import math
from typing import Any, Dict, List
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
class EndpointHandler:
    """Handler for MonoT5 relevance scoring.

    MonoT5 casts reranking as sequence generation: for an input of the form
    "Query: ... Document: ... Relevant:", the probability mass the model
    assigns to the token "true" (versus "false") at the first decoding step
    is used as the relevance score.
    """

    def __init__(self, path: str = ""):
        """Load model and tokenizer from *path*; move the model to GPU if available."""
        self.tokenizer = T5Tokenizer.from_pretrained(path)
        self.model = T5ForConditionalGeneration.from_pretrained(path)
        self.model.eval()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        # First sub-token ids of "true"/"false" — the two labels MonoT5 was
        # fine-tuned to emit as its first generated token.
        self.true_id = self.tokenizer.encode("true", add_special_tokens=False)[0]
        self.false_id = self.tokenizer.encode("false", add_special_tokens=False)[0]
        print(f"MonoT5 loaded on {self.device}")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Process inference requests.

        Accepts either:
          - {"inputs": "Query: ... Document: ... Relevant:"}         single input
          - {"inputs": ["Query: ... Document: ... Relevant:", ...]}  batch
          - {"query": "...", "documents": ["...", ...]}              structured input

        The structured form takes precedence when both are present.

        Returns:
            List of {"score": float, "label": "true"/"false"} dicts, one per
            input, in input order (empty list for an empty batch).

        Raises:
            ValueError: if the payload matches none of the accepted shapes.
        """
        if "query" in data and "documents" in data:
            query = data["query"]
            inputs: List[str] = [
                f"Query: {query} Document: {doc} Relevant:"
                for doc in data["documents"]
            ]
        else:
            inputs = data.get("inputs", data)
            if isinstance(inputs, str):
                inputs = [inputs]
            if not isinstance(inputs, list):
                # Previously a malformed payload (e.g. a bare dict) was
                # silently iterated key-by-key; fail loudly instead.
                raise ValueError(
                    "Expected 'inputs' (str or list of str) or "
                    "'query'/'documents' in request payload"
                )

        return [
            {"score": score, "label": "true" if score > 0.5 else "false"}
            for score in map(self._score_single, inputs)
        ]

    def _score_single(self, input_text: str) -> float:
        """Return P(relevant) in [0, 1] for one formatted query-document pair."""
        encoded = self.tokenizer(
            input_text,
            return_tensors="pt",
            max_length=512,  # T5 context limit; longer inputs are truncated
            truncation=True,
        ).to(self.device)

        with torch.no_grad():
            # Seed the decoder with the pad token (T5's decoder start token)
            # and read the logits at the first generated position.
            decoder_input_ids = torch.tensor(
                [[self.tokenizer.pad_token_id]], device=self.device
            )
            outputs = self.model(**encoded, decoder_input_ids=decoder_input_ids)

        logits = outputs.logits[0, -1, :]
        # Softmax restricted to the two label tokens; torch.softmax subtracts
        # the max internally, so this is numerically stable.
        pair = torch.stack([logits[self.true_id], logits[self.false_id]])
        return torch.softmax(pair, dim=0)[0].item()