kdevoe commited on
Commit
6f4032a
·
verified ·
1 Parent(s): f41d690

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -15,7 +15,9 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
  # Load the default model initially
17
  current_model_name = "DialoGPT-med-FT"
18
- model = AutoModelForCausalLM.from_pretrained(model_names[current_model_name]).to(device)
 
 
19
 
20
  def load_model(model_name):
21
  global model, current_model_name
 
15
 
16
  # Load the default model initially
17
  current_model_name = "DialoGPT-med-FT"
18
+ model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
19
+ model.load_state_dict(torch.load(model_names[current_model_name], map_location=device))
20
+ model.to(device)
21
 
22
  def load_model(model_name):
23
  global model, current_model_name