File size: 882 Bytes
d744da6
598bf7e
e58fd60
d744da6
 
 
 
598bf7e
d744da6
598bf7e
0362e65
 
 
 
 
 
 
d744da6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

class EndpointHandler:
    """Custom inference handler for a Hugging Face Inference Endpoint.

    Loads a sequence-classification model and its tokenizer from ``model_dir``
    and serves classification requests through a ``text-classification``
    pipeline, returning the top label and its score.
    """

    def __init__(self, model_dir: str):
        """Load the model, tokenizer and classification pipeline.

        Args:
            model_dir: Location passed to ``from_pretrained`` for both the
                model and the tokenizer.
        """
        self.model_dir = model_dir
        self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        # NOTE(review): nothing in this class reads this list; the label returned
        # by __call__ comes from the pipeline/model config. If the config only
        # carries generic names (e.g. LABEL_0), confirm whether these names were
        # meant to be mapped onto the output.
        self.labels = ["presentation", "projects", "skills", "education", "contact", "fallback"]
        self.classifier = pipeline("text-classification", model=self.model, tokenizer=self.tokenizer)

    def __call__(self, request: dict) -> dict:
        """Classify the text found in ``request["inputs"]``.

        The Hugging Face default container expects ``__call__`` as the entry
        point; ``request`` is the JSON payload received by the endpoint.

        Args:
            request: Payload dict whose ``"inputs"`` entry holds the text to
                classify. A missing or null ``"inputs"`` is treated as ``""``.

        Returns:
            ``{"label": <str>, "score": <float>}`` built from the first
            element of the pipeline output.
        """
        text = request.get("inputs")
        if text is None:
            # Treat an explicit JSON null the same as a missing key rather
            # than handing None to the pipeline.
            text = ""
        # truncation=True: inputs longer than the model's maximum sequence
        # length would otherwise fail in the model instead of being classified.
        outputs = self.classifier(text, truncation=True)
        # NOTE(review): if "inputs" is a list, the pipeline presumably returns
        # one prediction per item and only the first is kept here — confirm
        # whether batch requests need to be supported.
        top = outputs[0]
        # float() guarantees a plain Python float in the JSON response.
        return {"label": top["label"], "score": float(top["score"])}