Upload handler.py
handler.py CHANGED (+5 -0)
@@ -16,6 +16,7 @@ class EndpointHandler():
         config = PeftConfig.from_pretrained(path)
         model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, load_in_8bit=True, trust_remote_code=True, device_map='auto')
         self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
+        self.tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ bos_token + '<<SYS>>\\n' + message['content'] + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token }}{% endif %}{% endfor %}"
         # Load the Lora model
         self.model = PeftModel.from_pretrained(model, path)

@@ -44,6 +45,8 @@ class EndpointHandler():

         # Call the generate function
         output = generate(
+            tokenizer=tokenizer,
+            model=model,
             message=message,
             chat_history=chat_history,
             system_prompt=system_prompt,
@@ -63,6 +66,8 @@ class EndpointHandler():
         return {"generated_text": prediction}

 def generate(
+    tokenizer,
+    model,
     message: str,
     chat_history: list[tuple[str, str]],
     system_prompt: str = None,
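For reference, here is a minimal sketch (not part of the commit) of what the chat template added above produces once attached to the tokenizer. The checkpoint name is only a placeholder for a Llama-2-style tokenizer, not the base model this repo actually loads; any tokenizer whose `bos_token`/`eos_token` are `<s>`/`</s>` renders the same way.

```python
# Minimal sketch: render the Llama-2-style template added in this commit.
# "NousResearch/Llama-2-7b-chat-hf" is an assumed example checkpoint,
# not necessarily the base model referenced by this repo's PeftConfig.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
tokenizer.chat_template = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "{{ bos_token + '[INST] ' + message['content'] + ' [/INST]' }}"
    "{% elif message['role'] == 'system' %}"
    "{{ bos_token + '<<SYS>>\\n' + message['content'] + '\\n<</SYS>>\\n\\n' }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ ' ' + message['content'] + ' ' + eos_token }}"
    "{% endif %}{% endfor %}"
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
# tokenize=False returns the rendered prompt string instead of token ids.
print(tokenizer.apply_chat_template(messages, tokenize=False))
# <s><<SYS>>
# You are a helpful assistant.
# <</SYS>>
#
# <s>[INST] Hello! [/INST]
```

Setting `tokenizer.chat_template` this way lets `apply_chat_template` format `(role, content)` messages consistently, which is presumably why the commit also threads `tokenizer` and `model` through `generate()` as explicit parameters.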