codys12 committed on
Commit
1deae37
·
1 Parent(s): 97a48a0

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +5 -0
handler.py CHANGED
@@ -16,6 +16,7 @@ class EndpointHandler():
16
  config = PeftConfig.from_pretrained(path)
17
  model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, load_in_8bit=True, trust_remote_code=True, device_map='auto')
18
  self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
 
19
  # Load the Lora model
20
  self.model = PeftModel.from_pretrained(model, path)
21
 
@@ -44,6 +45,8 @@ class EndpointHandler():
44
 
45
  # Call the generate function
46
  output = generate(
 
 
47
  message=message,
48
  chat_history=chat_history,
49
  system_prompt=system_prompt,
@@ -63,6 +66,8 @@ class EndpointHandler():
63
  return {"generated_text": prediction}
64
 
65
  def generate(
 
 
66
  message: str,
67
  chat_history: list[tuple[str, str]],
68
  system_prompt: str = None,
 
16
  config = PeftConfig.from_pretrained(path)
17
  model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, load_in_8bit=True, trust_remote_code=True, device_map='auto')
18
  self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
19
+ self.tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ bos_token + '<<SYS>>\\n' + message['content'] + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token }}{% endif %}{% endfor %}"
20
  # Load the Lora model
21
  self.model = PeftModel.from_pretrained(model, path)
22
 
 
45
 
46
  # Call the generate function
47
  output = generate(
48
+ tokenizer=tokenizer,
49
+ model=model,
50
  message=message,
51
  chat_history=chat_history,
52
  system_prompt=system_prompt,
 
66
  return {"generated_text": prediction}
67
 
68
  def generate(
69
+ tokenizer,
70
+ model,
71
  message: str,
72
  chat_history: list[tuple[str, str]],
73
  system_prompt: str = None,