Update handler.py
handler.py (+35 -19)
```diff
@@ -6,6 +6,8 @@ import re
 from transformers import AutoTokenizer, TextStreamer
 from unsloth import FastLanguageModel
 from peft import PeftModel
+from unsloth.chat_templates import get_chat_template
+
 
 class EndpointHandler:
     def __init__(self, model_dir):
@@ -67,13 +69,15 @@ class EndpointHandler:
         self.policy_prompt = self._get_policy_prompt()
 
 
-
-
-
-
-
-
-
+
+
+        # try:
+        #     template_tokenizer = self.tokenizer
+        #     if hasattr(template_tokenizer, "chat_template") and template_tokenizer.chat_template:
+        #         self.tokenizer.chat_template = template_tokenizer.chat_template
+        #         print(f"Successfully imported chat template from {self.chat_template_id}")
+        # except Exception as e:
+        #     print(f"Failed to import chat template: {e}")
 
 
     def _get_policy_prompt(self):
@@ -229,27 +233,39 @@ class EndpointHandler:
         # Format input with the conversation template based on model type
         formatted_input = f"Please assess the following conversation: {input_text}"
         conversation = self._format_conversations(formatted_input)
-
+
 
-
-
-
-
-
-
+        self.tokenizer = get_chat_template(
+            self.tokenizer,
+            chat_template = chat_template,
+        )
+
+
+        prompt = self.tokenizer.apply_chat_template(conversation["conversations"], tokenize=False)
+
 
         # Tokenize input and move to the same device as the model
         inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
 
-        # Generate response
+        # # Generate response
+        # with torch.no_grad():
+        #     text_streamer = TextStreamer(self.tokenizer,skip_prompt=False)
+        #     output = self.model.generate(
+        #         **inputs,
+        #         streamer=text_streamer,
+        #         max_new_tokens=512
+        #     )
+
         with torch.no_grad():
-            text_streamer = TextStreamer(self.tokenizer,skip_prompt=False)
             output = self.model.generate(
-                **inputs,
-
-
+                **inputs,
+                max_new_tokens=512,
+                do_sample=False,
+                temperature=0.2,
             )
 
+
+
         # Decode the output
         decoded_output = self.tokenizer.decode(output[0], skip_special_tokens=True)
 
```
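A note on the commented-out block left in `__init__`: as written it assigns `template_tokenizer = self.tokenizer` and then copies the tokenizer's chat template onto itself, which is a no-op, so commenting it out loses nothing. If the intent was to borrow a template from the model named by `self.chat_template_id` (that attribute appears only in the dead code), a working version would load that reference tokenizer first. A minimal sketch, assuming `chat_template_id` is a Hugging Face repo id; the helper name is hypothetical:

```python
from transformers import AutoTokenizer

def copy_chat_template(tokenizer, chat_template_id):
    """Copy the Jinja chat template from a reference tokenizer, if it has one."""
    try:
        # Load the tokenizer whose template we want to borrow (hypothetical
        # fix; the commented-out code reused self.tokenizer here).
        template_tokenizer = AutoTokenizer.from_pretrained(chat_template_id)
        if getattr(template_tokenizer, "chat_template", None):
            tokenizer.chat_template = template_tokenizer.chat_template
            print(f"Successfully imported chat template from {chat_template_id}")
    except Exception as e:
        print(f"Failed to import chat template: {e}")
    return tokenizer
```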
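The new prompt path wraps the tokenizer with Unsloth's `get_chat_template` and renders the conversation with `apply_chat_template`. Two things the hunk does not show: the `chat_template` variable passed as the template name must be defined somewhere in the enclosing scope, and since the wrapping does not depend on the request, it could be done once in `__init__` rather than on every call. A minimal sketch of the flow, with an assumed base model id and the assumed template name `"llama-3"`; `add_generation_prompt=True` is not in the commit but is usually wanted at inference time, and ShareGPT-style from/value messages would additionally need `get_chat_template`'s `mapping` argument:

```python
from transformers import AutoTokenizer
from unsloth.chat_templates import get_chat_template

tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-3-8b-Instruct")  # assumed model id
tokenizer = get_chat_template(tokenizer, chat_template="llama-3")  # assumed template name

# Assumed shape of _format_conversations output: a dict whose "conversations"
# key holds a list of {"role": ..., "content": ...} messages.
conversation = {
    "conversations": [
        {"role": "user", "content": "Please assess the following conversation: ..."}
    ]
}

# tokenize=False returns the rendered prompt string rather than token ids;
# add_generation_prompt appends the assistant header so the model answers next.
prompt = tokenizer.apply_chat_template(
    conversation["conversations"],
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)
```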
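On the new `generate` call: `temperature=0.2` has no effect when `do_sample=False`, because greedy decoding ignores sampling parameters (recent `transformers` versions warn about exactly this combination). Either drop the temperature or enable sampling. A sketch of both consistent variants, with `model` and `inputs` standing in for `self.model` and the tokenized batch above:

```python
import torch

# Variant 1: deterministic greedy decoding; no sampling knobs needed.
with torch.no_grad():
    output = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=False,
    )

# Variant 2: low-temperature sampling, if mild randomness is wanted.
with torch.no_grad():
    output = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.2,
    )
```

Also note that `output[0]` still contains the prompt tokens, so `decoded_output` includes the rendered prompt; decoding `output[0][inputs["input_ids"].shape[1]:]` instead would isolate just the completion.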