File size: 882 Bytes
d744da6 598bf7e e58fd60 d744da6 598bf7e d744da6 598bf7e 0362e65 d744da6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
class EndpointHandler:
    """Custom inference handler that classifies a text with a fine-tuned model.

    Loads a sequence-classification model and its tokenizer from ``model_dir``
    and wraps them in a ``transformers`` text-classification pipeline.
    """

    def __init__(self, model_dir):
        """Load the model, the tokenizer and the classification pipeline.

        Args:
            model_dir: Location passed to ``from_pretrained`` for both the
                model and the tokenizer.
        """
        self.model_dir = model_dir
        self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        # NOTE(review): this list is not read anywhere in this handler; the
        # returned label is whatever the pipeline emits (presumably taken from
        # the model config). Confirm it matches the model's label order, or
        # use it to map generic "LABEL_k" names.
        self.labels = ["presentation", "projects", "skills", "education", "contact", "fallback"]
        self.classifier = pipeline("text-classification", model=self.model, tokenizer=self.tokenizer)

    def __call__(self, request):
        """Classify ``request["inputs"]`` and return the top prediction.

        The Hugging Face default container expects ``__call__`` as the entry
        point; ``request`` is the JSON payload received by the endpoint.

        Args:
            request: Mapping with an optional ``"inputs"`` key holding the text.
                A missing or ``null`` ``"inputs"`` is treated as ``""``.

        Returns:
            ``{"label": ..., "score": float}`` built from the first prediction
            returned by the pipeline.
        """
        text = request.get("inputs")
        if text is None:
            # A JSON null would otherwise be handed to the tokenizer as None;
            # treat it the same as a missing key.
            text = ""
        # Truncate to the model's maximum length so long inputs are scored
        # instead of failing inside the model.
        outputs = self.classifier(text, truncation=True)
        # Only the first result is returned; if "inputs" is a list, the
        # predictions for the remaining items are dropped.
        top = outputs[0]
        return {"label": top["label"], "score": float(top["score"])}
|