BlueDice committed
Commit c3eb599 · Parent(s): f26a708

Update code/inference.py

Files changed (1):
  code/inference.py (+16 -6)
code/inference.py CHANGED
@@ -3,25 +3,33 @@ import re
 import torch
 
 def model_fn(model_dir):
+
+    # Load the tokenizer, the model, and the default template
     tokenizer = AutoTokenizer.from_pretrained(model_dir)
     model = torch.load(f"{model_dir}/torch_model.pt")
-    return model, tokenizer
+    template = open(f"{model_dir}/default_template.txt", "r").read()
+    return model, tokenizer, template
 
 def predict_fn(data, load_list):
-    model, tokenizer = load_list
+
+    # Get the model, tokenizer, and template from model_fn
+    model, tokenizer, template = load_list
+
+    # Parse the input request into the format used to build the model input
     request_inputs = data.pop("inputs", data)
-    template = request_inputs["template"]
     messages = request_inputs["messages"]
     char_name = request_inputs["char_name"]
     user_name = request_inputs["user_name"]
-    template = open(f"{template}.txt", "r").read()
-    user_input = "\n".join([
+    user_input = [
         "{name}: {message}".format(
             name = char_name if (id["role"] == "AI") else user_name,
             message = id["message"].strip()
         ) for id in messages
-    ])
+    ]
+    user_input = "\n".join(user_input)
     prompt = template.format(char_name = char_name, user_name = user_name, user_input = user_input)
+
+    # Tokenize the model input, then generate and decode the output
     input_ids = tokenizer(prompt + f"\n{char_name}:", return_tensors = "pt").to("cuda")
     encoded_output = model.generate(
         input_ids["input_ids"],
@@ -34,6 +42,8 @@ def predict_fn(data, load_list):
         num_return_sequences = 1
     )
     decoded_output = tokenizer.decode(encoded_output[0], skip_special_tokens=True).replace(prompt, "")
+
+    # Parse the decoded output into the expected response
     decoded_output = decoded_output.split(f"{char_name}:", 1)[1].split(f"{user_name}:", 1)[0].strip()
     parsed_result = re.sub(r'\*.*?\*', '', decoded_output).strip()
     if len(parsed_result) != 0: decoded_output = parsed_result
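
Note: the new default_template.txt file is not part of this diff. From the template.format(char_name = ..., user_name = ..., user_input = ...) call, the only constraint the code imposes is that the file is plain text containing the {char_name}, {user_name} and {user_input} placeholders; the handler appends "\n{char_name}:" itself before tokenizing. A hypothetical template, for illustration only:

    You are {char_name}, a character in a role-play chat with {user_name}.
    Stay in character and reply only as {char_name}.

    {user_input}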
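
For reference, model_fn and predict_fn look like the handler hooks of the SageMaker inference toolkit, which calls model_fn once at startup and passes its return value to predict_fn on each request. A minimal sketch of a payload the updated predict_fn would accept, with hypothetical names and messages; any role other than "AI" is attributed to user_name, and *action* text such as *smiles* is stripped from the reply by the re.sub call:

    payload = {
        "inputs": {
            # Conversation history; only the role "AI" maps to char_name
            "messages": [
                {"role": "Human", "message": "Hi, how are you?"},
                {"role": "AI", "message": "*smiles* Doing well, and you?"}
            ],
            "char_name": "Aria",
            "user_name": "Sam"
        }
    }

    # Local smoke test, assuming the model artifacts exist under model_dir
    # (the diff is truncated, so predict_fn's return statement is not shown):
    # load_list = model_fn("/opt/ml/model")
    # predict_fn(payload, load_list)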