kdevoe committed on
Commit
935dbf9
·
verified ·
1 Parent(s): d0fb9fc

Only loading one model at a time, adding Large and XL models

Browse files
Files changed (1) hide show
  1. app.py +27 -14
app.py CHANGED
@@ -1,20 +1,32 @@
1
  import gradio as gr
2
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
 
4
- # Load the shared tokenizer (you can use the tokenizer from any Flan-T5 model)
5
  tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
6
 
7
- # Define the model names
8
  model_names = {
9
  "Flan-T5-small": "google/flan-t5-small",
10
- "Flan-T5-base": "google/flan-t5-base"
 
 
11
  }
12
 
13
- # Pre-load the models
14
- loaded_models = {
15
- model_name: AutoModelForSeq2SeqLM.from_pretrained(model_path)
16
- for model_name, model_path in model_names.items()
17
- }
 
 
 
 
 
 
 
 
 
 
18
 
19
  def respond(
20
  message,
@@ -24,8 +36,8 @@ def respond(
24
  temperature,
25
  top_p,
26
  ):
27
- # Select the pre-loaded model based on user's choice
28
- model = loaded_models[model_choice]
29
 
30
  # Prepare the input by concatenating the history into a dialogue format
31
  input_text = ""
@@ -53,7 +65,11 @@ def respond(
53
  demo = gr.ChatInterface(
54
  respond,
55
  additional_inputs=[
56
- gr.Dropdown(choices=["Flan-T5-small", "Flan-T5-base"], value="Flan-T5-base", label="Model"),
 
 
 
 
57
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
58
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
59
  gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
@@ -62,6 +78,3 @@ demo = gr.ChatInterface(
62
 
63
  if __name__ == "__main__":
64
  demo.launch()
65
-
66
-
67
-
 
1
  import gradio as gr
2
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
 
4
+ # Load the shared tokenizer (can be reused across all models)
5
  tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
6
 
7
+ # Define the available model names and paths
8
  model_names = {
9
  "Flan-T5-small": "google/flan-t5-small",
10
+ "Flan-T5-base": "google/flan-t5-base",
11
+ "Flan-T5-large": "google/flan-t5-large",
12
+ "Flan-T5-XL": "google/flan-t5-xl"
13
  }
14
 
15
+ # Initialize variables to manage loaded model
16
+ current_model = None
17
+ current_model_name = None
18
+
19
+ def load_model(model_name):
20
+ """Load the model if not already loaded or if switching models."""
21
+ global current_model, current_model_name
22
+
23
+ # Load the model only if it hasn't been loaded or if a different one is selected
24
+ if model_name != current_model_name:
25
+ print(f"Loading {model_name}...")
26
+ current_model = AutoModelForSeq2SeqLM.from_pretrained(model_names[model_name])
27
+ current_model_name = model_name
28
+
29
+ return current_model
30
 
31
  def respond(
32
  message,
 
36
  temperature,
37
  top_p,
38
  ):
39
+ # Load the selected model (or switch models if needed)
40
+ model = load_model(model_choice)
41
 
42
  # Prepare the input by concatenating the history into a dialogue format
43
  input_text = ""
 
65
  demo = gr.ChatInterface(
66
  respond,
67
  additional_inputs=[
68
+ gr.Dropdown(
69
+ choices=["Flan-T5-small", "Flan-T5-base", "Flan-T5-large", "Flan-T5-XL"],
70
+ value="Flan-T5-base", # Default selection
71
+ label="Model"
72
+ ),
73
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
74
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
75
  gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
 
78
 
79
  if __name__ == "__main__":
80
  demo.launch()