Maliketh28 commited on
Commit
d744da6
·
verified ·
1 Parent(s): e58fd60

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +12 -9
handler.py CHANGED
@@ -1,15 +1,18 @@
1
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
  import torch
 
3
 
4
  class EndpointHandler:
5
- def __init__(self):
6
- self.model = AutoModelForSequenceClassification.from_pretrained(".")
7
- self.tokenizer = AutoTokenizer.from_pretrained(".")
 
 
8
  self.labels = ["presentation","projects","skills","education","contact","fallback"]
 
 
9
 
10
  def inference(self, inputs):
11
- encoded = self.tokenizer(inputs, return_tensors="pt")
12
- with torch.no_grad():
13
- logits = self.model(**encoded).logits
14
- predicted_class = torch.argmax(logits, dim=1).item()
15
- return {"label": self.labels[predicted_class], "score": 1.0}
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
2
  import torch
3
+ import os
4
 
5
  class EndpointHandler:
6
+ def __init__(self, model_dir):
7
+ # Hugging Face passe le répertoire du modèle ici
8
+ self.model_dir = model_dir
9
+ self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)
10
+ self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
11
  self.labels = ["presentation","projects","skills","education","contact","fallback"]
12
+ # Optionnel : créer un pipeline pour simplifier l'inférence
13
+ self.classifier = pipeline("text-classification", model=self.model, tokenizer=self.tokenizer)
14
 
15
  def inference(self, inputs):
16
+ # inputs est une chaîne de texte
17
+ outputs = self.classifier(inputs)
18
+ return {"label": outputs[0]["label"], "score": float(outputs[0]["score"])}