File size: 2,451 Bytes
62a14f0
78bd8dc
d4c23b0
 
 
c3eb599
 
ec798e1
 
c3eb599
 
d4c23b0
91f9c84
c3eb599
 
 
 
 
b1c81ae
f149e87
 
 
b4b37be
c3eb599
f149e87
 
 
 
c3eb599
 
b4b37be
 
 
 
 
 
 
 
 
ec798e1
d4c23b0
 
 
 
 
 
 
 
 
78bd8dc
c3eb599
 
f149e87
ec798e1
78bd8dc
f149e87
ec798e1
78bd8dc
 
ec798e1
f149e87
 
b4b37be
 
f149e87
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from transformers import AutoTokenizer
import re
import torch

def model_fn(model_dir):
    """Load the tokenizer, model and default prompt template from ``model_dir``.

    Args:
        model_dir: Directory containing the tokenizer files, ``torch_model.pt``
            and ``default_template.txt``.

    Returns:
        A ``(model, tokenizer, template)`` tuple; ``predict_fn`` receives it as
        ``load_list`` and unpacks it in the same order.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    # NOTE(review): torch.load unpickles the file, so it must come from a
    # trusted source. Newer torch releases default to weights_only=True, which
    # cannot load a fully pickled model object -- confirm the torch version.
    model = torch.load(f"{model_dir}/torch_model.pt")
    # Context manager so the template file handle is closed deterministically.
    with open(f"{model_dir}/default_template.txt", "r", encoding="utf-8") as template_file:
        template = template_file.read()
    return model, tokenizer, template

def predict_fn(data, load_list):
    """Generate the character's next chat message for one inference request.

    Args:
        data: Request payload. Fields are read from ``data["inputs"]`` when that
            key exists (it is popped from ``data``), otherwise from ``data``
            itself: ``messages`` (list of dicts with ``"role"`` and
            ``"message"``), ``char_name``, ``user_name`` and ``chats_curled``
            (starting value of the drop counter; it is incremented once for
            every message pair dropped here to fit the token budget).
        load_list: The ``(model, tokenizer, template)`` tuple built by
            ``model_fn``.

    Returns:
        ``{"role": "AI", "message": <cleaned reply>, "chats_curled": <count>}``.

    Raises:
        ValueError: if the prompt is over the token budget even after the whole
            chat history has been dropped.
    """
    model, tokenizer, template = load_list

    request_inputs = data.pop("inputs", data)
    char_name = request_inputs["char_name"]
    user_name = request_inputs["user_name"]
    history = _format_messages(request_inputs["messages"], char_name, user_name)

    encoded, chats_curled = _fit_prompt(
        tokenizer, template, char_name, user_name, history,
        request_inputs["chats_curled"],
    )
    # Move to the GPU once, after the final prompt has been chosen.
    encoded = encoded.to("cuda")
    prompt_length = encoded["input_ids"].shape[1]

    # NOTE(review): transformers' generate() only samples when do_sample=True
    # (or the model's generation config enables it); otherwise decoding is
    # greedy and temperature/top_p/top_k have no effect. Confirm intent.
    # NOTE(review): the 2048-token budget covers the prompt only, so prompt +
    # max_new_tokens can reach 2098 tokens; confirm this fits the model.
    output_ids = model.generate(
        encoded["input_ids"],
        attention_mask=encoded.get("attention_mask"),
        max_new_tokens=50,
        temperature=0.5,
        top_p=0.9,
        top_k=0,
        repetition_penalty=1.1,
        pad_token_id=50256,
        num_return_sequences=1,
    )
    # Decode only the continuation. Decoding the full sequence and stripping
    # the prompt text breaks whenever the tokenizer round-trip does not
    # reproduce the prompt exactly (e.g. tokenization-space clean-up).
    generated = tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True)

    return {
        "role": "AI",
        "message": _clean_reply(generated, user_name),
        "chats_curled": chats_curled,
    }


def _format_messages(messages, char_name, user_name):
    """Render chat messages as ``"<speaker>: <text>"`` lines.

    Messages whose ``role`` is ``"AI"`` are attributed to ``char_name``; every
    other role is attributed to ``user_name``. Message text is stripped.
    """
    return [
        "{name}: {message}".format(
            name=char_name if msg["role"] == "AI" else user_name,
            message=msg["message"].strip(),
        )
        for msg in messages
    ]


def _fit_prompt(tokenizer, template, char_name, user_name, history, chats_curled, max_tokens=2048):
    """Render and tokenize the prompt, dropping old history until it fits.

    The tokenized text is the rendered template followed by a newline and
    ``"<char_name>:"``. While it is longer than ``max_tokens``, ``chats_curled``
    is incremented and the visible history becomes
    ``history[chats_curled * 2:]``. The slice is always taken from the full
    list, so the count of dropped lines stays tied to ``chats_curled``.

    Returns:
        ``(encoded, chats_curled)``: the tokenizer output (PyTorch tensors) and
        the updated counter.

    Raises:
        ValueError: if even an empty history does not fit in ``max_tokens``.
    """
    visible = history
    while True:
        prompt = template.format(
            char_name=char_name, user_name=user_name, user_input="\n".join(visible)
        )
        encoded = tokenizer(prompt + f"\n{char_name}:", return_tensors="pt")
        if encoded.input_ids.size(1) <= max_tokens:
            return encoded, chats_curled
        if not visible:
            # Nothing left to drop: the template alone is too long. The old
            # code looped forever here.
            raise ValueError(
                f"prompt exceeds {max_tokens} tokens even with no chat history"
            )
        chats_curled += 1
        visible = history[chats_curled * 2:]


def _clean_reply(generated, user_name):
    """Turn the raw generated continuation into the character's reply.

    Keeps only the text before the model starts writing the user's next turn,
    removes ``*action*`` spans (unless that would leave nothing), drops any
    remaining asterisks, collapses whitespace, and trims a trailing unfinished
    sentence after the last ``.``, ``!`` or ``?``.
    """
    reply = generated.split(f"{user_name}:", 1)[0].strip()
    without_actions = re.sub(r"\*.*?\*", "", reply).strip()
    if without_actions:
        reply = without_actions
    reply = " ".join(reply.replace("*", "").split())
    # With no sentence terminator at all, keep the reply unchanged.
    last_stop = max(reply.rfind(mark) for mark in ".!?")
    if last_stop != -1:
        reply = reply[: last_stop + 1]
    return reply