File size: 3,453 Bytes
0518d49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""
Custom handler for MonoT5 reranking on HuggingFace Inference Endpoints.

Returns relevance probability scores for query-document pairs.
"""

import math
from typing import Any, Dict, List

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer


class EndpointHandler:
    """Handler for MonoT5 relevance scoring.

    Each input is a MonoT5 prompt of the form
    ``"Query: <q> Document: <d> Relevant:"``. The relevance score is the
    probability of the token "true" versus "false" at the first decoder step.
    """

    # Number of prompts sent through the model per forward pass. Overridable
    # per instance via ``__init__``.
    batch_size: int = 16

    def __init__(self, path: str = "", batch_size: int = 16):
        """Load the tokenizer and model, and cache the true/false token ids.

        Args:
            path: Directory or hub id passed to ``from_pretrained``.
            batch_size: Maximum number of prompts scored per forward pass.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.batch_size = batch_size

        self.tokenizer = T5Tokenizer.from_pretrained(path)
        self.model = T5ForConditionalGeneration.from_pretrained(path)
        self.model.eval()

        # Move to GPU if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

        # Vocabulary ids of the answer tokens. Only the first sub-token of each
        # word is used.
        self.true_id = self.tokenizer.encode("true", add_special_tokens=False)[0]
        self.false_id = self.tokenizer.encode("false", add_special_tokens=False)[0]

        print(f"MonoT5 loaded on {self.device}")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process inference requests.

        Accepts either:
        - {"inputs": "Query: ... Document: ... Relevant:"} - single input
        - {"inputs": ["Query: ... Document: ... Relevant:", ...]} - batch
        - {"query": "...", "documents": ["...", ...]} - structured input
        - {"inputs": {"query": "...", "documents": [...]}} - wrapped structured input

        Returns:
        - List of {"score": float, "label": "true"/"false"} dicts, one per
          input and in input order. The label is "true" when score > 0.5.

        Raises:
        - ValueError: if the payload matches none of the accepted shapes.
        """
        texts = self._normalize_inputs(data)

        scores: List[float] = []
        for start in range(0, len(texts), self.batch_size):
            scores.extend(self._score_batch(texts[start:start + self.batch_size]))

        return [
            {"score": score, "label": "true" if score > 0.5 else "false"}
            for score in scores
        ]

    @staticmethod
    def _normalize_inputs(data: Dict[str, Any]) -> List[str]:
        """Convert any accepted request shape into a list of MonoT5 prompts.

        Raises:
            ValueError: If the payload has neither a usable ``"inputs"`` value
                nor both ``"query"`` and ``"documents"``.
        """
        # Top-level structured input wins, matching the original precedence.
        # Otherwise accept the same structure wrapped in "inputs".
        if "query" in data and "documents" in data:
            structured = data
        elif isinstance(data.get("inputs"), dict):
            structured = data["inputs"]
        else:
            structured = None

        if structured is not None:
            if "query" not in structured or "documents" not in structured:
                raise ValueError(
                    "Structured input requires both 'query' and 'documents'"
                )
            query = structured["query"]
            documents = structured["documents"]
            # A bare string is one document. Iterating it would score each
            # character separately.
            if isinstance(documents, str):
                documents = [documents]
            if not isinstance(documents, (list, tuple)):
                raise ValueError("'documents' must be a string or a list of strings")
            return [
                f"Query: {query} Document: {doc} Relevant:"
                for doc in documents
            ]

        if "inputs" not in data:
            # Previously the request dict itself was iterated here, which
            # silently scored its keys.
            raise ValueError(
                "Request must contain 'inputs' or both 'query' and 'documents'"
            )

        inputs = data["inputs"]
        if isinstance(inputs, str):
            return [inputs]
        if not isinstance(inputs, (list, tuple)):
            raise ValueError("'inputs' must be a string or a list of strings")
        for index, text in enumerate(inputs):
            if not isinstance(text, str):
                raise ValueError(
                    f"'inputs'[{index}] must be a string, got {type(text).__name__}"
                )
        return list(inputs)

    def _score_batch(self, input_texts: List[str]) -> List[float]:
        """Return P("true" | prompt) for each prompt, in order.

        The probability is a softmax over only the "true" and "false" logits
        at the first decoder step.
        """
        if not input_texts:
            return []

        # Pad to the longest prompt. The attention mask returned here keeps
        # padding out of encoder attention and decoder cross-attention.
        encoded = self.tokenizer(
            input_texts,
            return_tensors="pt",
            max_length=512,
            truncation=True,
            padding=True,
        ).to(self.device)

        with torch.no_grad():
            # T5 starts decoding from the pad token, so one pad id per row
            # yields the logits of the first generated token.
            decoder_input_ids = torch.full(
                (len(input_texts), 1),
                self.tokenizer.pad_token_id,
                dtype=torch.long,
                device=self.device,
            )
            outputs = self.model(**encoded, decoder_input_ids=decoder_input_ids)
            logits = outputs.logits[:, -1, :]

            # Softmax in float64 over the [true, false] pair only.
            pair = logits[:, [self.true_id, self.false_id]].to(torch.float64)
            true_probs = torch.softmax(pair, dim=-1)[:, 0]

        return true_probs.cpu().tolist()

    def _score_single(self, input_text: str) -> float:
        """Score a single query-document prompt (kept for backward compatibility)."""
        return self._score_batch([input_text])[0]