Maliketh28 commited on
Commit
0362e65
·
verified ·
1 Parent(s): 8349928

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +7 -7
handler.py CHANGED
@@ -1,18 +1,18 @@
1
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
2
- import torch
3
- import os
4
 
5
  class EndpointHandler:
6
  def __init__(self, model_dir):
7
- # Hugging Face passe le répertoire du modèle ici
8
  self.model_dir = model_dir
9
  self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)
10
  self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
11
  self.labels = ["presentation","projects","skills","education","contact","fallback"]
12
- # Optionnel : créer un pipeline pour simplifier l'inférence
13
  self.classifier = pipeline("text-classification", model=self.model, tokenizer=self.tokenizer)
14
 
15
- def inference(self, inputs):
16
- # inputs est une chaîne de texte
17
- outputs = self.classifier(inputs)
 
 
 
 
18
  return {"label": outputs[0]["label"], "score": float(outputs[0]["score"])}
 
1
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
 
 
2
 
3
  class EndpointHandler:
4
  def __init__(self, model_dir):
 
5
  self.model_dir = model_dir
6
  self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)
7
  self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
8
  self.labels = ["presentation","projects","skills","education","contact","fallback"]
 
9
  self.classifier = pipeline("text-classification", model=self.model, tokenizer=self.tokenizer)
10
 
11
+ def __call__(self, request):
12
+ """
13
+ Hugging Face Default container attend __call__ comme point d'entrée.
14
+ `request` est le payload JSON reçu par l'endpoint.
15
+ """
16
+ text = request.get("inputs", "")
17
+ outputs = self.classifier(text)
18
  return {"label": outputs[0]["label"], "score": float(outputs[0]["score"])}