from typing import Any, Dict, List, Union

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TextClassificationPipeline,
)


class EndpointHandler:
    """Custom handler for Hugging Face Inference Endpoints.

    Loads a fine-tuned text-classification model and exposes a callable
    that the endpoint runtime will invoke. The runtime will instantiate
    this class once at startup, passing the model directory path.
    """

    def __init__(self, path: str = "", **kwargs):
        # `path` is the directory where the model artefacts are stored.
        # Fallback to current directory if not provided (local testing).
        model_dir = path or "."

        # Load tokenizer & model
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)

        # Decide how to map logits to probabilities: multi-label models need a
        # sigmoid over the logits, single-label models a softmax.
        function_to_apply = (
            "sigmoid"
            if getattr(self.model.config, "problem_type", None)
            == "multi_label_classification"
            else "softmax"
        )

        # Build a text-classification pipeline; run on GPU if one is available.
        # By default the pipeline returns only the top label per input.
        self.pipeline = TextClassificationPipeline(
            model=self.model,
            tokenizer=self.tokenizer,
            device=0 if torch.cuda.is_available() else -1,
            function_to_apply=function_to_apply,
        )

    def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """Run inference on the incoming request.

        Expected input format from the Inference Endpoint runtime:
            {
              "inputs": "some text" | ["text 1", "text 2", ...],
              "parameters": { ... }  # optional pipeline kwargs (e.g., top_k)
            }

        Returns the top prediction for each text: a single
        {"label": ..., "score": ...} dict when a single string is supplied,
        or a list of such dicts for a batch of texts.
        """
        # Extract the text(s); accept a single string or a list of strings
        raw_inputs = data.get("inputs", data)
        single_input = isinstance(raw_inputs, str)
        if single_input:
            raw_inputs = [raw_inputs]

        # Additional pipeline parameters (optional)
        parameters = data.get("parameters", {})

        # Execute the pipeline
        outputs = self.pipeline(raw_inputs, **parameters)

        # If a single string was provided, return a single dict for convenience
        if single_input:
            return outputs[0]
        return outputs
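

# Minimal local smoke test (a sketch, not part of the endpoint runtime).
# It assumes the model artefacts live in the current directory; adjust the
# path to wherever your fine-tuned model is saved. The payloads mirror the
# request format documented in EndpointHandler.__call__; the example texts
# are illustrative only.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")

    # Single string input -> one {"label": ..., "score": ...} dict
    print(handler({"inputs": "I really enjoyed this product."}))

    # Batch input with an optional tokenizer kwarg forwarded to the pipeline
    print(
        handler(
            {
                "inputs": ["Great service.", "Never ordering again."],
                "parameters": {"truncation": True},
            }
        )
    )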