ChevalierJoseph commited on
Commit
b815173
·
verified ·
1 Parent(s): 8efe53b

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +12 -19
handler.py CHANGED
@@ -1,29 +1,22 @@
1
  from typing import Dict, List, Any
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
 
4
- class EndpointHandler:
5
- def __init__(self, path: str):
6
- # Charger le modèle et le tokenizer
7
- self.tokenizer = AutoTokenizer.from_pretrained(path)
8
  self.model = AutoModelForCausalLM.from_pretrained(path)
 
9
 
10
- def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
11
- """
12
- Cette méthode est appelée à chaque requête.
13
- :param data: un dictionnaire contenant les données d'entrée.
14
- :return: un dictionnaire contenant la prédiction.
15
- """
16
- # Extraire les entrées du dictionnaire de données
17
  inputs = data.pop("inputs", data)
18
 
19
- # Tokenize les entrées
20
  input_ids = self.tokenizer.encode(inputs, return_tensors="pt")
 
21
 
22
- # Générer du texte
23
- output_ids = self.model.generate(input_ids, max_length=100)
24
-
25
- # Décoder les IDs de sortie en texte
26
- generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
27
 
28
- # Retourner le texte généré
29
- return {"generated_text": generated_text}
 
1
  from typing import Dict, List, Any
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
 
4
+ class EndpointHandler():
5
+ def __init__(self, path=""):
6
+ # Load the model and tokenizer from the specified path
 
7
  self.model = AutoModelForCausalLM.from_pretrained(path)
8
+ self.tokenizer = AutoTokenizer.from_pretrained(path)
9
 
10
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
11
+ # Extract input text from the request
 
 
 
 
 
12
  inputs = data.pop("inputs", data)
13
 
14
+ # Tokenize input and generate text
15
  input_ids = self.tokenizer.encode(inputs, return_tensors="pt")
16
+ output_ids = self.model.generate(input_ids)
17
 
18
+ # Decode the generated output
19
+ output_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
 
 
20
 
21
+ # Return the generated text
22
+ return [{"generated_text": output_text}]