"""
Custom handler for MonoT5 reranking on HuggingFace Inference Endpoints.
Returns relevance probability scores for query-document pairs.
"""
import math
from typing import Any, Dict, List
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
class EndpointHandler:
    """Handler for MonoT5 relevance scoring.

    Scores "Query: ... Document: ... Relevant:" prompts by comparing the
    model's first-decoder-step logits for the tokens "true" and "false".
    """

    def __init__(self, path: str = ""):
        """Load the MonoT5 model/tokenizer from *path*; move to GPU if available."""
        self.tokenizer = T5Tokenizer.from_pretrained(path)
        self.model = T5ForConditionalGeneration.from_pretrained(path)
        self.model.eval()
        # Move to GPU if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        # Single-token vocabulary ids for the "true" / "false" judgement words.
        self.true_id = self.tokenizer.encode("true", add_special_tokens=False)[0]
        self.false_id = self.tokenizer.encode("false", add_special_tokens=False)[0]
        print(f"MonoT5 loaded on {self.device}")

    @staticmethod
    def _build_prompts(query: str, documents: List[str]) -> List[str]:
        """Format MonoT5 prompts for one query against each document."""
        return [f"Query: {query} Document: {doc} Relevant:" for doc in documents]

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process inference requests.
        Accepts either:
        - {"inputs": "Query: ... Document: ... Relevant:"} - single input
        - {"inputs": ["Query: ... Document: ... Relevant:", ...]} - batch
        - {"query": "...", "documents": ["...", ...]} - structured input
        Returns:
        - List of {"score": float, "label": "true"/"false"} dicts
        Raises:
        - ValueError: if the payload matches none of the accepted shapes.
        """
        # Structured form takes precedence. Guard with isinstance so a raw
        # string payload does not trip a substring `in` check (the original
        # assumed a dict and crashed with AttributeError on `.get`).
        if isinstance(data, dict) and "query" in data and "documents" in data:
            inputs = self._build_prompts(data["query"], data["documents"])
        else:
            inputs = data.get("inputs", data) if isinstance(data, dict) else data
            if isinstance(inputs, str):
                inputs = [inputs]
        if not isinstance(inputs, list):
            # The original silently iterated arbitrary payloads here (e.g. the
            # keys of an unrecognized dict) and scored them; fail loudly instead.
            raise ValueError(
                "Payload must provide 'inputs' (str or list of str) "
                "or both 'query' and 'documents'."
            )
        # Score every prompt independently.
        return [
            {"score": score, "label": "true" if score > 0.5 else "false"}
            for score in (self._score_single(text) for text in inputs)
        ]

    def _score_single(self, input_text: str) -> float:
        """Return P(relevant) for a single formatted query-document prompt."""
        # Tokenize (truncated to the T5 512-token context window).
        encoded = self.tokenizer(
            input_text,
            return_tensors="pt",
            max_length=512,
            truncation=True,
            padding=True,
        ).to(self.device)
        # Get logits for the first generated token; T5 decoding starts from
        # the pad token, so a single pad id is the whole decoder input.
        with torch.no_grad():
            decoder_input_ids = torch.tensor(
                [[self.tokenizer.pad_token_id]],
                device=self.device,
            )
            outputs = self.model(**encoded, decoder_input_ids=decoder_input_ids)
        logits = outputs.logits[0, -1, :]
        true_logit = logits[self.true_id].item()
        false_logit = logits[self.false_id].item()
        # Numerically stable two-way softmax over the true/false logits.
        max_logit = max(true_logit, false_logit)
        true_prob = math.exp(true_logit - max_logit)
        false_prob = math.exp(false_logit - max_logit)
        return true_prob / (true_prob + false_prob)