ChevalierJoseph commited on
Commit
e8fdd65
·
verified ·
1 Parent(s): d93c503

Update handle.py

Browse files
Files changed (1) hide show
  1. handle.py +26 -52
handle.py CHANGED
@@ -1,55 +1,29 @@
1
- from typing import Dict, Any
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
- import torch
4
 
5
  class EndpointHandler:
6
- def __init__(self, model_dir: str, **kwargs):
7
- # Charger le tokenizer et le modèle
8
- self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
9
- self.model = AutoModelForCausalLM.from_pretrained(
10
- model_dir,
11
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
12
- device_map="auto",
13
- trust_remote_code=True
14
- )
15
-
16
- # Paramètres de génération
17
- self.generation_config = {
18
- "max_new_tokens": 512,
19
- "do_sample": True,
20
- "temperature": 0.7,
21
- "top_p": 0.9,
22
- "top_k": 50,
23
- "pad_token_id": self.tokenizer.eos_token_id,
24
- "eos_token_id": self.tokenizer.eos_token_id,
25
- }
26
-
27
- def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
28
- inputs = data.get("inputs") or data.get("text")
29
- if not inputs:
30
- return {"error": "No 'inputs' provided in request."}
31
-
32
- # Message format type ChatML
33
- messages = [{"role": "user", "content": inputs}]
34
-
35
- # Appliquer le template si possible
36
- if hasattr(self.tokenizer, "apply_chat_template"):
37
- prompt = self.tokenizer.apply_chat_template(
38
- messages,
39
- tokenize=False,
40
- add_generation_prompt=True
41
- )
42
- else:
43
- # Fallback simple si pas de template
44
- prompt = "user: " + inputs + "\nassistant:"
45
-
46
- # Tokeniser et générer
47
- input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.model.device)
48
-
49
- with torch.no_grad():
50
- output_ids = self.model.generate(input_ids, **self.generation_config)
51
-
52
- # Décoder la sortie après le prompt
53
- response = self.tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
54
-
55
- return {"generated_text": response}
 
1
+ from typing import Dict, List, Any
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
3
 
4
  class EndpointHandler:
5
+ def __init__(self, path: str):
6
+ # Charger le modèle et le tokenizer
7
+ self.tokenizer = AutoTokenizer.from_pretrained(path)
8
+ self.model = AutoModelForCausalLM.from_pretrained(path)
9
+
10
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
11
+ """
12
+ Cette méthode est appelée à chaque requête.
13
+ :param data: un dictionnaire contenant les données d'entrée.
14
+ :return: un dictionnaire contenant la prédiction.
15
+ """
16
+ # Extraire les entrées du dictionnaire de données
17
+ inputs = data.pop("inputs", data)
18
+
19
+ # Tokenize les entrées
20
+ input_ids = self.tokenizer.encode(inputs, return_tensors="pt")
21
+
22
+ # Générer du texte
23
+ output_ids = self.model.generate(input_ids, max_length=100)
24
+
25
+ # Décoder les IDs de sortie en texte
26
+ generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
27
+
28
+ # Retourner le texte généré
29
+ return {"generated_text": generated_text}