kdevoe committed on
Commit
e11bd6e
·
verified ·
1 Parent(s): 3f9b161

Loading only one model at a time to conserve memory

Browse files
Files changed (1) hide show
  1. app.py +19 -10
app.py CHANGED
@@ -13,14 +13,23 @@ model_names = {
13
 
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
- # Pre-load the models
17
- loaded_models = {
18
- "DialoGPT-med-FT": AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
19
- }
20
- loaded_models["DialoGPT-med-FT"].load_state_dict(torch.load(model_names["DialoGPT-med-FT"], map_location=device))
21
- loaded_models["DialoGPT-med-FT"].to(device)
22
 
23
- loaded_models["DialoGPT-medium"] = AutoModelForCausalLM.from_pretrained(model_names["DialoGPT-medium"]).to(device)
 
 
 
 
 
 
 
 
 
 
24
 
25
  def respond(
26
  message,
@@ -30,8 +39,8 @@ def respond(
30
  temperature,
31
  top_p,
32
  ):
33
- # Select the pre-loaded model based on user's choice
34
- model = loaded_models[model_choice]
35
 
36
  # Prepare the input by concatenating the history into a dialogue format
37
  input_text = ""
@@ -60,7 +69,7 @@ demo = gr.ChatInterface(
60
  respond,
61
  type='messages',
62
  additional_inputs=[
63
- gr.Dropdown(choices=["DialoGPT-med-FT", "DialoGPT-medium"], value="DialoGPT-med-FT", label="Model"),
64
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
65
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
66
  gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
 
13
 
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
+ # Load the default model initially
17
+ current_model_name = "DialoGPT-medium"
18
+ model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
19
+ model.load_state_dict(torch.load(model_names[current_model_name], map_location=device))
20
+ model.to(device)
 
21
 
22
+ def load_model(model_name):
23
+ global model, current_model_name
24
+ if model_name != current_model_name:
25
+ # Load the new model and update the current model reference
26
+ if model_name == "DialoGPT-medium":
27
+ model = AutoModelForCausalLM.from_pretrained(model_names[model_name]).to(device)
28
+ elif model_name == "DialoGPT-med-FT":
29
+ model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
30
+ model.load_state_dict(torch.load(model_names[model_name], map_location=device))
31
+ model.to(device)
32
+ current_model_name = model_name
33
 
34
  def respond(
35
  message,
 
39
  temperature,
40
  top_p,
41
  ):
42
+ # Load the selected model if it's different from the current one
43
+ load_model(model_choice)
44
 
45
  # Prepare the input by concatenating the history into a dialogue format
46
  input_text = ""
 
69
  respond,
70
  type='messages',
71
  additional_inputs=[
72
+ gr.Dropdown(choices=["DialoGPT-med-FT", "DialoGPT-medium"], value="DialoGPT-medium", label="Model"),
73
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
74
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
75
  gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),