from transformers import AutoTokenizer
import re
import torch
def model_fn(model_dir):
    """Load the tokenizer, model weights, and default prompt template.

    Args:
        model_dir: directory containing the HF tokenizer files,
            ``torch_model.pt``, and ``default_template.txt``.

    Returns:
        (model, tokenizer, template) tuple consumed by ``predict_fn``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = torch.load(f"{model_dir}/torch_model.pt")
    # Context manager closes the file promptly; the original
    # open(...).read() leaked the file handle.
    with open(f"{model_dir}/default_template.txt", "r") as f:
        template = f.read()
    return model, tokenizer, template
def predict_fn(data, load_list):
    """Generate one chat reply for the incoming request.

    Args:
        data: request payload; the actual request dict is under "inputs"
            (falls back to ``data`` itself when that key is absent).
        load_list: the (model, tokenizer, template) tuple returned by
            ``model_fn``.

    Returns:
        dict with "role" (always "AI"), the generated "message", and the
        possibly-incremented "chats_curled" counter.
    """
    # Get model, tokenizer and template from model_fn.
    model, tokenizer, template = load_list

    # Parse the input request into the format needed to build the prompt.
    request_inputs = data.pop("inputs", data)
    messages = request_inputs["messages"]
    char_name = request_inputs["char_name"]
    user_name = request_inputs["user_name"]
    chats_curled = request_inputs["chats_curled"]

    # Render each chat turn as "<speaker>: <text>".
    # (loop var renamed from `id`, which shadowed the builtin)
    user_input = [
        "{name}: {message}".format(
            name=char_name if (msg["role"] == "AI") else user_name,
            message=msg["message"].strip(),
        )
        for msg in messages
    ]

    # Build the prompt; while it exceeds the 2048-token context window,
    # drop the oldest turns and retry.
    while True:
        # BUG FIX: join the list itself — the original "\n".join([user_input])
        # passed a list-of-list and raised TypeError on every request.
        prompt = template.format(
            char_name=char_name,
            user_name=user_name,
            user_input="\n".join(user_input),
        )
        # NOTE(review): device is hard-coded to "cuda" — assumes a GPU host.
        input_ids = tokenizer(prompt + f"\n{char_name}:", return_tensors="pt").to("cuda")
        if input_ids.input_ids.size(1) > 2048:
            chats_curled += 1
            user_input = user_input[chats_curled * 2:]
        else:
            break

    # Sample a continuation.
    encoded_output = model.generate(
        input_ids["input_ids"],
        max_new_tokens=50,
        temperature=0.5,
        top_p=0.9,
        top_k=0,
        repetition_penalty=1.1,
        pad_token_id=50256,
        num_return_sequences=1,
    )
    # Strip the prompt from the decoded text, then keep only the character's
    # first reply, stopping at the next user turn.
    decoded_output = tokenizer.decode(
        encoded_output[0], skip_special_tokens=True
    ).replace(prompt, "")
    decoded_output = decoded_output.split(f"{char_name}:", 1)[1].split(f"{user_name}:", 1)[0].strip()

    # Drop *action* asides, unless that would leave an empty message.
    parsed_result = re.sub(r'\*.*?\*', '', decoded_output).strip()
    if len(parsed_result) != 0:
        decoded_output = parsed_result
    decoded_output = " ".join(decoded_output.replace("*", "").split())

    # Truncate at the last sentence-ending punctuation, if there is any.
    try:
        parsed_result = decoded_output[:[m.start() for m in re.finditer(r'[.!?]', decoded_output)][-1] + 1]
        if len(parsed_result) != 0:
            decoded_output = parsed_result
    except Exception:
        pass

    return {
        "role": "AI",
        "message": decoded_output,
        "chats_curled": chats_curled,
    }