kdevoe commited on
Commit
65d226c
·
verified ·
1 Parent(s): 8a95d1b

Adding options for FlanT5 small, base and large

Browse files
Files changed (1) hide show
  1. app.py +22 -7
app.py CHANGED
@@ -1,29 +1,43 @@
1
  import gradio as gr
2
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
 
4
- # Load Flan-T5-base model and tokenizer from Hugging Face
5
- model_name = "google/flan-t5-base"
6
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
8
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  def respond(
11
  message,
12
  history: list[tuple[str, str]],
 
13
  max_tokens,
14
  temperature,
15
  top_p,
16
  ):
 
 
 
17
  # Prepare the input by concatenating the history into a dialogue format
18
  input_text = ""
19
  for user_msg, bot_msg in history:
20
  input_text += f"User: {user_msg} Assistant: {bot_msg} "
21
  input_text += f"User: {message}"
22
 
23
- # Tokenize the input text
24
  inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
25
 
26
- # Generate the response using Flan-T5-base
27
  output_tokens = model.generate(
28
  inputs["input_ids"],
29
  max_length=max_tokens,
@@ -36,11 +50,11 @@ def respond(
36
  response = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
37
  yield response
38
 
39
-
40
  # Define the Gradio interface
41
  demo = gr.ChatInterface(
42
  respond,
43
  additional_inputs=[
 
44
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
45
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
46
  gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
@@ -51,3 +65,4 @@ if __name__ == "__main__":
51
  demo.launch()
52
 
53
 
 
 
1
  import gradio as gr
2
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
 
4
+ # Load the shared tokenizer (you can use the tokenizer from any Flan-T5 model)
5
+ tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
 
 
6
 
7
+ # Define the model names
8
+ model_names = {
9
+ "Flan-T5-small": "google/flan-t5-small",
10
+ "Flan-T5-base": "google/flan-t5-base",
11
+ "Flan-T5-large": "google/flan-t5-large"
12
+ }
13
+
14
+ # Pre-load the models
15
+ loaded_models = {
16
+ model_name: AutoModelForSeq2SeqLM.from_pretrained(model_path)
17
+ for model_name, model_path in model_names.items()
18
+ }
19
 
20
  def respond(
21
  message,
22
  history: list[tuple[str, str]],
23
+ model_choice,
24
  max_tokens,
25
  temperature,
26
  top_p,
27
  ):
28
+ # Select the pre-loaded model based on user's choice
29
+ model = loaded_models[model_choice]
30
+
31
  # Prepare the input by concatenating the history into a dialogue format
32
  input_text = ""
33
  for user_msg, bot_msg in history:
34
  input_text += f"User: {user_msg} Assistant: {bot_msg} "
35
  input_text += f"User: {message}"
36
 
37
+ # Tokenize the input text using the shared tokenizer
38
  inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
39
 
40
+ # Generate the response using the selected Flan-T5 model
41
  output_tokens = model.generate(
42
  inputs["input_ids"],
43
  max_length=max_tokens,
 
50
  response = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
51
  yield response
52
 
 
53
  # Define the Gradio interface
54
  demo = gr.ChatInterface(
55
  respond,
56
  additional_inputs=[
57
+ gr.Dropdown(choices=["Flan-T5-small", "Flan-T5-base", "Flan-T5-large"], value="Flan-T5-base", label="Model"),
58
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
59
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
60
  gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
 
65
  demo.launch()
66
 
67
 
68
+