Machlovi committed
Commit 70e8fa3 · verified · 1 parent: ef1cb6b

Update handler.py

Files changed (1): handler.py (+35, -19)
handler.py CHANGED
@@ -6,6 +6,8 @@ import re
 from transformers import AutoTokenizer, TextStreamer
 from unsloth import FastLanguageModel
 from peft import PeftModel
+from unsloth.chat_templates import get_chat_template
+
 
 class EndpointHandler:
     def __init__(self, model_dir):
@@ -67,13 +69,15 @@ class EndpointHandler:
         self.policy_prompt = self._get_policy_prompt()
 
 
-        try:
-            template_tokenizer = self.tokenizer
-            if hasattr(template_tokenizer, "chat_template") and template_tokenizer.chat_template:
-                self.tokenizer.chat_template = template_tokenizer.chat_template
-                print(f"Successfully imported chat template from {self.chat_template_id}")
-        except Exception as e:
-            print(f"Failed to import chat template: {e}")
+
+
+        # try:
+        #     template_tokenizer = self.tokenizer
+        #     if hasattr(template_tokenizer, "chat_template") and template_tokenizer.chat_template:
+        #         self.tokenizer.chat_template = template_tokenizer.chat_template
+        #         print(f"Successfully imported chat template from {self.chat_template_id}")
+        # except Exception as e:
+        #     print(f"Failed to import chat template: {e}")
 
 
     def _get_policy_prompt(self):
@@ -229,27 +233,39 @@ class EndpointHandler:
         # Format input with the conversation template based on model type
         formatted_input = f"Please assess the following conversation: {input_text}"
         conversation = self._format_conversations(formatted_input)
-        self.tokenizer.chat_template = self.chat_template
+
 
-        # Apply the chat template to prepare for the model
-        if hasattr(self.tokenizer, "apply_chat_template"):
-            prompt = self.tokenizer.apply_chat_template(conversation["conversations"], tokenize=False)
-        else:
-            # Fallback if apply_chat_template is not available
-            prompt = f"System: {self.policy_prompt}\nUser: {formatted_input}"
+        self.tokenizer = get_chat_template(
+            self.tokenizer,
+            chat_template=self.chat_template,
+        )
+
+
+        prompt = self.tokenizer.apply_chat_template(conversation["conversations"], tokenize=False)
+
 
         # Tokenize input and move to the same device as the model
         inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
 
-        # Generate response
+        # # Generate response
+        # with torch.no_grad():
+        #     text_streamer = TextStreamer(self.tokenizer, skip_prompt=False)
+        #     output = self.model.generate(
+        #         **inputs,
+        #         streamer=text_streamer,
+        #         max_new_tokens=512
+        #     )
+
         with torch.no_grad():
-            text_streamer = TextStreamer(self.tokenizer, skip_prompt=False)
             output = self.model.generate(
-                **inputs,
-                streamer=text_streamer,
-                max_new_tokens=512
+                **inputs,
+                max_new_tokens=512,
+                do_sample=False,
+                temperature=0.2,
             )
 
+
+
         # Decode the output
         decoded_output = self.tokenizer.decode(output[0], skip_special_tokens=True)
 
 
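In plain Python, the inference path this commit lands on looks roughly like the sketch below. The base model name, the "llama-3" template string, and the example message are illustrative assumptions, not part of the commit; only the get_chat_template / apply_chat_template / generate calls mirror the committed handler code.

# Minimal sketch of the new inference path (assumed model name and
# template string; mirrors the calls introduced in this commit).
import torch
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template

# Hypothetical base model; the real handler loads its own fine-tuned
# weights from model_dir instead.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/llama-3-8b-Instruct-bnb-4bit",
    max_seq_length=2048,
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model)  # enable Unsloth's inference mode

# Attach a known chat template to the tokenizer, as the commit does with
# self.chat_template ("llama-3" is an assumed value here).
tokenizer = get_chat_template(tokenizer, chat_template="llama-3")

# Build the prompt the same way the handler does: a role/content
# conversation rendered with the chat template, not tokenized yet.
conversation = [
    {"role": "user", "content": "Please assess the following conversation: ..."},
]
prompt = tokenizer.apply_chat_template(conversation, tokenize=False)

# Tokenize and move to the model's device, then generate greedily.
inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=512, do_sample=False)

print(tokenizer.decode(output[0], skip_special_tokens=True))

One caveat on the committed generate() call: with do_sample=False, transformers performs greedy decoding and ignores temperature (it emits a warning to that effect), so the temperature=0.2 added here has no effect on the output; the sketch therefore omits it. Dropping the TextStreamer is also sensible for an endpoint, since streaming to the server's stdout never reaches the client.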