ChevalierJoseph committed
Commit d93c503 · verified · 1 Parent(s): 15b7923

Update handle.py

Files changed (1)
  1. handle.py +18 -11
handle.py CHANGED
@@ -1,10 +1,10 @@
 from typing import Dict, Any
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 
 class EndpointHandler:
     def __init__(self, model_dir: str, **kwargs):
-        # Load tokenizer and model
+        # Load the tokenizer and model
         self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
         self.model = AutoModelForCausalLM.from_pretrained(
             model_dir,
@@ -13,10 +13,7 @@ class EndpointHandler:
             trust_remote_code=True
         )
 
-        # Chat template setup
-        self.tokenizer.chat_template = self.tokenizer.chat_template or "{% for message in messages %}{{ message['role'] }}: {{ message['content'] }}\n{% endfor %}assistant:"
-
-        # Generation config (can be tuned)
+        # Generation parameters
         self.generation_config = {
             "max_new_tokens": 512,
             "do_sample": True,
@@ -32,17 +29,27 @@
         if not inputs:
             return {"error": "No 'inputs' provided in request."}
 
-        # Wrap input in a basic chat message structure
+        # ChatML-style message format
         messages = [{"role": "user", "content": inputs}]
 
-        # Apply chat template
-        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-
+        # Apply the chat template if available
+        if hasattr(self.tokenizer, "apply_chat_template"):
+            prompt = self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True
+            )
+        else:
+            # Simple fallback if there is no template
+            prompt = "user: " + inputs + "\nassistant:"
+
+        # Tokenize and generate
         input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.model.device)
 
         with torch.no_grad():
             output_ids = self.model.generate(input_ids, **self.generation_config)
 
-        # Decode output
+        # Decode the output after the prompt
         response = self.tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
+
         return {"generated_text": response}
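One note on the new hasattr guard: it only protects against transformers versions that predate apply_chat_template. On recent releases the method exists on every tokenizer but can still raise a ValueError when the tokenizer ships without a chat template, so a try/except variant would cover both failure modes. A minimal sketch of that alternative, as an assumption rather than part of this commit (build_prompt is a hypothetical helper):

# Hypothetical hardening of the prompt-building step (not in the commit).
def build_prompt(tokenizer, messages, user_text):
    try:
        # Works whenever the tokenizer defines a usable chat template.
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except (AttributeError, ValueError):
        # AttributeError: very old transformers without the method.
        # ValueError: newer transformers with no chat template set.
        return "user: " + user_text + "\nassistant:"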
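For a quick local check of the updated handler, the class can be driven the same way an Inference Endpoint would drive it: instantiate EndpointHandler with the model directory, then call it with a payload carrying the "inputs" key the handler checks. A minimal sketch, assuming handle.py is importable and ./model holds the model files (both paths are placeholders):

# Local smoke test for the handler above (paths are assumptions).
from handle import EndpointHandler

handler = EndpointHandler(model_dir="./model")

# Payload shape matches the "inputs" key validated in __call__.
result = handler({"inputs": "What is the capital of France?"})
print(result.get("generated_text", result))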