kdevoe committed on
Commit
f41d690
·
verified ·
1 Parent(s): e11bd6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -5
app.py CHANGED
@@ -14,10 +14,8 @@ model_names = {
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
  # Load the default model initially
17
- current_model_name = "DialoGPT-medium"
18
- model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
19
- model.load_state_dict(torch.load(model_names[current_model_name], map_location=device))
20
- model.to(device)
21
 
22
  def load_model(model_name):
23
  global model, current_model_name
@@ -69,7 +67,7 @@ demo = gr.ChatInterface(
69
  respond,
70
  type='messages',
71
  additional_inputs=[
72
- gr.Dropdown(choices=["DialoGPT-med-FT", "DialoGPT-medium"], value="DialoGPT-medium", label="Model"),
73
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
74
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
75
  gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
 
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
  # Load the default model initially
17
+ current_model_name = "DialoGPT-med-FT"
18
+ model = AutoModelForCausalLM.from_pretrained(model_names[current_model_name]).to(device)
 
 
19
 
20
  def load_model(model_name):
21
  global model, current_model_name
 
67
  respond,
68
  type='messages',
69
  additional_inputs=[
70
+ gr.Dropdown(choices=["DialoGPT-med-FT", "DialoGPT-medium"], value="DialoGPT-med-FT", label="Model"),
71
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
72
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
73
  gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),