kdevoe commited on
Commit
2f31341
·
verified ·
1 Parent(s): e08c5ac

Add a model selector for the small, medium, and large DialoGPT models

Browse files
Files changed (1) hide show
  1. app.py +22 -10
app.py CHANGED
@@ -6,12 +6,15 @@ from langchain.memory import ConversationBufferMemory
6
  # Move model to device (GPU if available)
7
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
8
 
9
- # Load the tokenizer (you can use the pre-trained tokenizer for GPT-2 family)
10
- tokenizer = GPT2Tokenizer.from_pretrained("microsoft/DialoGPT-medium")
11
- model = GPT2LMHeadModel.from_pretrained("microsoft/DialoGPT-medium")
 
 
 
12
 
13
- # Move model to the device (GPU or CPU)
14
- model.to(device)
15
 
16
  # Set up conversational memory using LangChain's ConversationBufferMemory
17
  memory = ConversationBufferMemory()
@@ -28,7 +31,7 @@ def truncate_history_to_100_tokens(history, tokenizer, max_tokens=100):
28
  return tokenized_history
29
 
30
  # Define the chatbot function with memory and additional parameters
31
- def chat_with_dialogpt(input_text, temperature, top_p, top_k):
32
  # Retrieve conversation history
33
  conversation_history = memory.load_memory_variables({})['history']
34
 
@@ -45,6 +48,9 @@ def chat_with_dialogpt(input_text, temperature, top_p, top_k):
45
  # Concatenate truncated history with the new input
46
  final_input_ids = torch.cat((truncated_input_ids_tensor, input_ids), dim=1)
47
 
 
 
 
48
  # Generate the response using the model with adjusted parameters
49
  outputs = model.generate(
50
  final_input_ids,
@@ -83,6 +89,9 @@ with gr.Blocks() as interface:
83
 
84
  # Add the instruction message above the input box
85
  gr.Markdown("**Instructions:** Press `Shift + Enter` to submit, and `Enter` for a new line.")
 
 
 
86
 
87
  # Add a clear history button
88
  clear_button = gr.Button("Clear History", scale=0)
@@ -97,17 +106,20 @@ with gr.Blocks() as interface:
97
  top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k", scale=0)
98
 
99
  # Define the function to update the chat
100
- def update_chat(input_text, chat_history, temperature, top_p, top_k):
101
- updated_history = chat_with_dialogpt(input_text, temperature, top_p, top_k)
102
  return updated_history, ""
103
 
104
  # Submit when pressing Shift + Enter
105
  user_input.submit(update_chat,
106
- inputs=[user_input, chatbot_output, temperature_slider, top_p_slider, top_k_slider],
107
  outputs=[chatbot_output, user_input])
108
 
109
  # Layout for sliders and chatbot UI
110
  gr.Row([temperature_slider, top_p_slider, top_k_slider])
111
-
 
 
 
112
  # Launch the Gradio app
113
  interface.launch()
 
6
  # Move model to device (GPU if available)
7
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
8
 
9
# Load all three DialoGPT variants up front so the UI can switch between
# them without reloading; each model is moved to the selected device.
models = {
    size: GPT2LMHeadModel.from_pretrained(f"microsoft/DialoGPT-{size}").to(device)
    for size in ("small", "medium", "large")
}

# All DialoGPT sizes share one vocabulary, so a single tokenizer suffices.
tokenizer = GPT2Tokenizer.from_pretrained("microsoft/DialoGPT-medium")
18
 
19
  # Set up conversational memory using LangChain's ConversationBufferMemory
20
  memory = ConversationBufferMemory()
 
31
  return tokenized_history
32
 
33
  # Define the chatbot function with memory and additional parameters
34
+ def chat_with_dialogpt(input_text, temperature, top_p, top_k, model_size):
35
  # Retrieve conversation history
36
  conversation_history = memory.load_memory_variables({})['history']
37
 
 
48
  # Concatenate truncated history with the new input
49
  final_input_ids = torch.cat((truncated_input_ids_tensor, input_ids), dim=1)
50
 
51
+ # Get the model corresponding to the selected size
52
+ model = models[model_size]
53
+
54
  # Generate the response using the model with adjusted parameters
55
  outputs = model.generate(
56
  final_input_ids,
 
89
 
90
  # Add the instruction message above the input box
91
  gr.Markdown("**Instructions:** Press `Shift + Enter` to submit, and `Enter` for a new line.")
92
+
93
+ # Add a dropdown for selecting the model size (small, medium, large)
94
+ model_selector = gr.Dropdown(choices=["small", "medium", "large"], value="medium", label="Select Model Size")
95
 
96
  # Add a clear history button
97
  clear_button = gr.Button("Clear History", scale=0)
 
106
  top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k", scale=0)
107
 
108
# Bridge between the Gradio submit event and the chatbot backend.
def update_chat(input_text, chat_history, temperature, top_p, top_k, model_size):
    """Run one chat turn and clear the input box.

    Returns a (history, "") pair: the refreshed conversation history for
    the chatbot component and an empty string to reset the textbox.
    """
    refreshed = chat_with_dialogpt(input_text, temperature, top_p, top_k, model_size)
    return refreshed, ""
112
 
113
  # Submit when pressing Shift + Enter
114
  user_input.submit(update_chat,
115
+ inputs=[user_input, chatbot_output, temperature_slider, top_p_slider, top_k_slider, model_selector],
116
  outputs=[chatbot_output, user_input])
117
 
118
  # Layout for sliders and chatbot UI
119
  gr.Row([temperature_slider, top_p_slider, top_k_slider])
120
+
121
+ # Layout for model selector and clear button in a row
122
+ gr.Row([model_selector, clear_button])
123
+
124
  # Launch the Gradio app
125
  interface.launch()