File size: 4,667 Bytes
753d5b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""
Custom handler for HuggingFace Inference Endpoints.

Accepts a context string and a list of candidate sentences,
tokenizes them in batches, scores each sentence, and returns
the scores.

Expected input JSON:
{
    "inputs": {
        "context": "The Crash at Crush was a publicity stunt in Texas in 1896.",
        "sentences": [
            "An estimated 40,000 people attended the event.",
            "The event was held on September 15.",
            "Two people were killed by flying debris."
        ]
    }
}

Response JSON:
[
    {"sentence": "An estimated 40,000 people attended the event.", "score": 1.234},
    {"sentence": "The event was held on September 15.", "score": 0.456},
    {"sentence": "Two people were killed by flying debris.", "score": 1.789}
]
"""

from typing import Any, Dict, List, Union
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

MAX_LENGTH = 384
BATCH_SIZE = 32


class EndpointHandler:
    """Custom handler for sentence interestingness scoring.

    Loads a sequence-classification model and tokenizer once at startup,
    then scores (context, sentence) pairs on demand, returning results
    sorted by score descending.
    """

    def __init__(self, path: str = ""):
        """Load the model and tokenizer from the given path.

        Args:
            path: Path to the model directory (provided by the Inference Endpoint).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        self.model.eval()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Union[List[Dict[str, Any]], Dict[str, str]]:
        """Score a list of sentences given a context.

        Args:
            data: Request payload. Expected shape:
                {
                    "inputs": {
                        "context": str,
                        "sentences": list[str]
                    }
                }
                OR (flat form):
                {
                    "inputs": str  # treated as context, sentences split by NLTK
                }

        Returns:
            List of dicts with "sentence" and "score" keys sorted by score
            descending, or a single {"error": ...} dict on invalid input.
        """
        # Use pop like HF's example handlers to be resilient to wrapper layers.
        # BUGFIX: the original also did `data.pop("parameters", {})` and never
        # used the result; in the fallback case (`inputs is data`) that pop
        # mutated the very dict we are about to read. Removed.
        inputs = data.pop("inputs", data)

        # Support both structured and simple string input.
        if isinstance(inputs, str):
            # Simple mode: treat input as context, split into sentences.
            try:
                import nltk

                nltk.download("punkt_tab", quiet=True)
                context = inputs
                sentences = nltk.sent_tokenize(inputs)
            except (ImportError, LookupError):
                # ImportError: nltk is not installed.
                # LookupError: the punkt_tab tokenizer data is unavailable
                # (e.g. the quiet download failed on an offline endpoint) —
                # previously this crashed instead of returning an error.
                return {"error": "Structured input required: provide 'context' and 'sentences' fields."}
        elif isinstance(inputs, dict):
            context = inputs.get("context", "")
            sentences = inputs.get("sentences", [])
        else:
            return {"error": f"Unexpected input type: {type(inputs).__name__}"}

        if not context:
            return {"error": "No context provided."}
        # Reject non-list payloads explicitly: a bare string here would
        # otherwise be iterated character by character further down.
        if not isinstance(sentences, (list, tuple)):
            return {"error": "'sentences' must be a list of strings."}
        if not sentences:
            return {"error": "No sentences provided."}

        # Score sentences in batches to bound peak memory.
        all_scores: List[float] = []
        for batch_start in range(0, len(sentences), BATCH_SIZE):
            batch_sentences = sentences[batch_start : batch_start + BATCH_SIZE]
            all_scores.extend(self._score_batch(context, batch_sentences))

        # Build results sorted by score (highest first).
        results = [
            {"sentence": sent, "score": round(score, 4)}
            for sent, score in zip(sentences, all_scores)
        ]
        results.sort(key=lambda item: item["score"], reverse=True)

        return results

    def _score_batch(self, context: str, batch_sentences: List[str]) -> List[float]:
        """Tokenize (context, sentence) pairs and return one raw score per sentence.

        Args:
            context: Shared context string paired with every sentence.
            batch_sentences: Sentences for this batch (at most BATCH_SIZE).

        Returns:
            Raw model scores as plain floats, in input order.
        """
        encoded = self.tokenizer(
            [context] * len(batch_sentences),
            batch_sentences,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=MAX_LENGTH,
        )
        encoded = {k: v.to(self.device) for k, v in encoded.items()}

        with torch.no_grad():
            # Assumes a single-label (regression-style) head: logits are
            # (batch_size, 1) and squeeze(-1) yields (batch_size,).
            scores = self.model(**encoded).logits.squeeze(-1)

            # Defensive: if squeeze ever collapses to a scalar, restore 1-D
            # so .tolist() returns a list.
            if scores.dim() == 0:
                scores = scores.unsqueeze(0)

            return scores.cpu().tolist()