Spaces:
Sleeping
Sleeping
Only loading one model at a time, adding Large and XL models
Browse files
app.py
CHANGED
|
@@ -1,20 +1,32 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
| 3 |
|
| 4 |
-
# Load the shared tokenizer (
|
| 5 |
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
|
| 6 |
|
| 7 |
-
# Define the model names
|
| 8 |
model_names = {
|
| 9 |
"Flan-T5-small": "google/flan-t5-small",
|
| 10 |
-
"Flan-T5-base": "google/flan-t5-base"
|
|
|
|
|
|
|
| 11 |
}
|
| 12 |
|
| 13 |
-
#
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
def respond(
|
| 20 |
message,
|
|
@@ -24,8 +36,8 @@ def respond(
|
|
| 24 |
temperature,
|
| 25 |
top_p,
|
| 26 |
):
|
| 27 |
-
#
|
| 28 |
-
model =
|
| 29 |
|
| 30 |
# Prepare the input by concatenating the history into a dialogue format
|
| 31 |
input_text = ""
|
|
@@ -53,7 +65,11 @@ def respond(
|
|
| 53 |
demo = gr.ChatInterface(
|
| 54 |
respond,
|
| 55 |
additional_inputs=[
|
| 56 |
-
gr.Dropdown(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
| 58 |
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
|
| 59 |
gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
|
|
@@ -62,6 +78,3 @@ demo = gr.ChatInterface(
|
|
| 62 |
|
| 63 |
if __name__ == "__main__":
|
| 64 |
demo.launch()
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
| 3 |
|
| 4 |
+
# Load the shared tokenizer once at startup. Per the original comment it
# "can be reused across all models" — i.e. every Flan-T5 variant listed
# below is assumed to share the flan-t5-base vocabulary (TODO: confirm).
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

# UI display name -> Hugging Face Hub model id for each selectable model.
model_names = {
    "Flan-T5-small": "google/flan-t5-small",
    "Flan-T5-base": "google/flan-t5-base",
    "Flan-T5-large": "google/flan-t5-large",
    "Flan-T5-XL": "google/flan-t5-xl"
}

# Lazy-loading state: only one model is kept in memory at a time.
# current_model is the loaded model instance (None until first use);
# current_model_name is the model_names key it was loaded from.
current_model = None
current_model_name = None
|
| 18 |
+
|
| 19 |
+
def load_model(model_name):
    """Load the model if not already loaded or if switching models.

    Keeps at most one model in memory: when a different model is
    requested, the new one is loaded and rebound over the old one
    (the previous instance is then eligible for garbage collection).

    Args:
        model_name: A key of ``model_names`` (e.g. "Flan-T5-base").
            An unknown key raises ``KeyError`` from the dict lookup.

    Returns:
        The currently loaded ``AutoModelForSeq2SeqLM`` instance.
    """
    # Rebinds the module-level cache, so the loaded model survives
    # across calls.
    global current_model, current_model_name

    # Load the model only if it hasn't been loaded or if a different one
    # is selected; repeated calls with the same name are a cheap no-op.
    if model_name != current_model_name:
        print(f"Loading {model_name}...")
        current_model = AutoModelForSeq2SeqLM.from_pretrained(model_names[model_name])
        current_model_name = model_name

    return current_model
|
| 30 |
|
| 31 |
def respond(
|
| 32 |
message,
|
|
|
|
| 36 |
temperature,
|
| 37 |
top_p,
|
| 38 |
):
|
| 39 |
+
# Load the selected model (or switch models if needed)
|
| 40 |
+
model = load_model(model_choice)
|
| 41 |
|
| 42 |
# Prepare the input by concatenating the history into a dialogue format
|
| 43 |
input_text = ""
|
|
|
|
| 65 |
demo = gr.ChatInterface(
|
| 66 |
respond,
|
| 67 |
additional_inputs=[
|
| 68 |
+
gr.Dropdown(
|
| 69 |
+
choices=["Flan-T5-small", "Flan-T5-base", "Flan-T5-large", "Flan-T5-XL"],
|
| 70 |
+
value="Flan-T5-base", # Default selection
|
| 71 |
+
label="Model"
|
| 72 |
+
),
|
| 73 |
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
| 74 |
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
|
| 75 |
gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
|
|
|
|
| 78 |
|
| 79 |
# Start the Gradio server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()
demo.launch()
|
|
|
|
|
|
|
|
|