File size: 2,977 Bytes
651552b
ee87074
00de14a
651552b
ee87074
 
651552b
00de14a
ee87074
00de14a
ee87074
2f31341
651552b
00de14a
 
e11bd6e
f41d690
6f4032a
 
 
00de14a
e11bd6e
 
 
 
 
 
 
 
 
 
 
651552b
ee87074
 
00de14a
ee87074
 
 
 
 
e11bd6e
 
651552b
ee87074
 
00de14a
 
ee87074
9266488
ee87074
3f9b161
9266488
ee87074
 
3f9b161
ee87074
d51dabc
 
ee87074
4a454e5
651552b
ee87074
 
 
 
 
 
 
6848d2f
ee87074
f41d690
7a04fc2
ee87074
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the shared tokenizer once at import time; both the stock and the
# fine-tuned model reuse the DialoGPT-medium vocabulary, so one tokenizer
# serves every model choice.
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")

# Registry mapping UI labels to checkpoints: a local state-dict file (.bin)
# for the fine-tuned model, a Hugging Face hub id for the stock model.
model_names = {
    "DialoGPT-med-FT": "DialoGPT-med-FT.bin",
    "DialoGPT-medium": "microsoft/DialoGPT-medium"
}

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned model as the default: instantiate the base
# DialoGPT-medium architecture, then overlay the locally saved weights.
# NOTE(review): torch.load on a .bin file uses pickle — only safe because
# the file ships with this app; do not point it at untrusted files.
current_model_name = "DialoGPT-med-FT"
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
model.load_state_dict(torch.load(model_names[current_model_name], map_location=device))
model.to(device)

def load_model(model_name):
    """Swap the global ``model`` to the requested checkpoint.

    No-op when ``model_name`` already matches ``current_model_name``.
    Dispatches on the registry value: a path ending in ``.bin`` is a local
    state dict overlaid on the base DialoGPT-medium architecture; anything
    else is treated as a Hugging Face hub id.

    Args:
        model_name: A key of ``model_names`` ("DialoGPT-med-FT" or
            "DialoGPT-medium").

    Raises:
        ValueError: If ``model_name`` is not a known model key.
            (Previously an unknown name silently updated
            ``current_model_name`` without swapping ``model``,
            desynchronizing the two globals.)
    """
    global model, current_model_name
    if model_name == current_model_name:
        return
    if model_name not in model_names:
        raise ValueError(f"Unknown model choice: {model_name!r}")

    checkpoint = model_names[model_name]
    if checkpoint.endswith(".bin"):
        # Local fine-tuned weights: rebuild the base model, then load the
        # saved state dict onto it.
        model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
        model.load_state_dict(torch.load(checkpoint, map_location=device))
        model.to(device)
    else:
        # Hub checkpoint: load directly by model id.
        model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
    current_model_name = model_name

def respond(
    message,
    history: list[dict],
    model_choice,
    max_tokens,
    temperature,
    top_p,
):
    """Generate one assistant reply for the Gradio ChatInterface.

    Args:
        message: The user's newest message text.
        history: Prior turns in Gradio 'messages' format —
            dicts with lowercase 'role' and 'content' keys.
        model_choice: UI label of the model to use (key of ``model_names``).
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature for ``generate``.
        top_p: Nucleus-sampling cutoff for ``generate``.

    Yields:
        The decoded assistant response (single yield; ChatInterface
        accepts a generator).
    """
    # Swap in the requested model if it differs from the current one.
    load_model(model_choice)

    # Flatten the conversation into a "Role: text" transcript. Gradio stores
    # roles lowercase ('user'/'assistant'); title-case them so history lines
    # match the "User:"/"Assistant:" markers used for the new turn (the
    # original emitted lowercase roles for history, inconsistent with the
    # prompt tail below).
    lines = [f"{turn['role'].title()}: {turn['content']}" for turn in history]
    lines.append(f"User: {message}")
    input_text = "\n".join(lines) + "\nAssistant:"

    # Tokenize and move the whole encoding (ids + attention mask) to device.
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True).to(device)
    prompt_len = inputs["input_ids"].shape[-1]

    # Generate the continuation. Pass the attention mask explicitly and set
    # pad_token_id to EOS — DialoGPT defines no pad token, so omitting this
    # triggers warnings and can mis-handle padding. (The redundant
    # .to(device) on input_ids is dropped; the encoding is already there.)
    output_tokens = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=prompt_len + max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens (everything past the prompt).
    response = tokenizer.decode(output_tokens[0][prompt_len:], skip_special_tokens=True)
    yield response

# Define the Gradio interface: a chat UI in 'messages' format (history is a
# list of {'role', 'content'} dicts) with extra controls appended after the
# chat box. The additional_inputs order must match respond()'s parameters
# after (message, history): model_choice, max_tokens, temperature, top_p.
demo = gr.ChatInterface(
    respond,
    type='messages',
    additional_inputs=[
        gr.Dropdown(choices=["DialoGPT-med-FT", "DialoGPT-medium"], value="DialoGPT-med-FT", label="Model"),
        gr.Slider(minimum=1, maximum=100, value=15, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

# Launch the web server only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()