# NOTE(review): removed web-scrape artifacts that preceded the code
# (a "File size" header, a commit hash, and a line-number gutter) —
# they were not part of the source and made the file invalid Python.
from typing import Dict, List, Any
from transformers import pipeline, AutoTokenizer

class EndpointHandler:
    """Custom Inference Endpoints handler for a text-classification model."""

    def __init__(self, path=""):
        """
        Load the tokenizer and build the text-classification pipeline.

        Args:
            path (str): Local path or model id of the (optimized) model.
        """
        tokenizer = AutoTokenizer.from_pretrained(path)
        # Create inference pipeline for text classification
        self.pipeline = pipeline("text-classification", model=path, tokenizer=tokenizer)

    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
        """
        Run inference on a request payload.

        Args:
            data: Either the raw text to classify, or a dict payload of the
                form {"inputs": <text or list of texts>} as sent by the
                Hugging Face inference toolkit.

        Returns:
            The pipeline prediction, e.g. [[{"label": "LABEL", "score": 0.99}]].
        """
        # BUG FIX: the original annotated `data` as `str` but then called
        # `data.pop("inputs", data)` — a dict method — so the documented
        # string input raised AttributeError. Accept both payload shapes
        # explicitly, and use .get() so the caller's dict is not mutated.
        if isinstance(data, dict):
            inputs = data.get("inputs", data)
        else:
            inputs = data

        # Return the prediction result
        return self.pipeline(inputs)