Add a model selector for the small, medium, and large DialoGPT models
Browse files
app.py
CHANGED
|
@@ -6,12 +6,15 @@ from langchain.memory import ConversationBufferMemory
|
|
| 6 |
# Move model to device (GPU if available)
|
| 7 |
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 8 |
|
| 9 |
-
# Load
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
-
#
|
| 14 |
-
|
| 15 |
|
| 16 |
# Set up conversational memory using LangChain's ConversationBufferMemory
|
| 17 |
memory = ConversationBufferMemory()
|
|
@@ -28,7 +31,7 @@ def truncate_history_to_100_tokens(history, tokenizer, max_tokens=100):
|
|
| 28 |
return tokenized_history
|
| 29 |
|
| 30 |
# Define the chatbot function with memory and additional parameters
|
| 31 |
-
def chat_with_dialogpt(input_text, temperature, top_p, top_k):
|
| 32 |
# Retrieve conversation history
|
| 33 |
conversation_history = memory.load_memory_variables({})['history']
|
| 34 |
|
|
@@ -45,6 +48,9 @@ def chat_with_dialogpt(input_text, temperature, top_p, top_k):
|
|
| 45 |
# Concatenate truncated history with the new input
|
| 46 |
final_input_ids = torch.cat((truncated_input_ids_tensor, input_ids), dim=1)
|
| 47 |
|
|
|
|
|
|
|
|
|
|
| 48 |
# Generate the response using the model with adjusted parameters
|
| 49 |
outputs = model.generate(
|
| 50 |
final_input_ids,
|
|
@@ -83,6 +89,9 @@ with gr.Blocks() as interface:
|
|
| 83 |
|
| 84 |
# Add the instruction message above the input box
|
| 85 |
gr.Markdown("**Instructions:** Press `Shift + Enter` to submit, and `Enter` for a new line.")
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
# Add a clear history button
|
| 88 |
clear_button = gr.Button("Clear History", scale=0)
|
|
@@ -97,17 +106,20 @@ with gr.Blocks() as interface:
|
|
| 97 |
top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k", scale=0)
|
| 98 |
|
| 99 |
# Define the function to update the chat
|
| 100 |
-
def update_chat(input_text, chat_history, temperature, top_p, top_k):
|
| 101 |
-
updated_history = chat_with_dialogpt(input_text, temperature, top_p, top_k)
|
| 102 |
return updated_history, ""
|
| 103 |
|
| 104 |
# Submit when pressing Shift + Enter
|
| 105 |
user_input.submit(update_chat,
|
| 106 |
-
inputs=[user_input, chatbot_output, temperature_slider, top_p_slider, top_k_slider],
|
| 107 |
outputs=[chatbot_output, user_input])
|
| 108 |
|
| 109 |
# Layout for sliders and chatbot UI
|
| 110 |
gr.Row([temperature_slider, top_p_slider, top_k_slider])
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
| 112 |
# Launch the Gradio app
|
| 113 |
interface.launch()
|
|
|
|
| 6 |
# Move model to device (GPU if available)
|
| 7 |
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 8 |
|
| 9 |
+
# Load all three DialoGPT models (small, medium, large)
|
| 10 |
+
models = {
|
| 11 |
+
"small": GPT2LMHeadModel.from_pretrained("microsoft/DialoGPT-small").to(device),
|
| 12 |
+
"medium": GPT2LMHeadModel.from_pretrained("microsoft/DialoGPT-medium").to(device),
|
| 13 |
+
"large": GPT2LMHeadModel.from_pretrained("microsoft/DialoGPT-large").to(device)
|
| 14 |
+
}
|
| 15 |
|
| 16 |
+
# Load the tokenizer (same tokenizer for all models)
|
| 17 |
+
tokenizer = GPT2Tokenizer.from_pretrained("microsoft/DialoGPT-medium")
|
| 18 |
|
| 19 |
# Set up conversational memory using LangChain's ConversationBufferMemory
|
| 20 |
memory = ConversationBufferMemory()
|
|
|
|
| 31 |
return tokenized_history
|
| 32 |
|
| 33 |
# Define the chatbot function with memory and additional parameters
|
| 34 |
+
def chat_with_dialogpt(input_text, temperature, top_p, top_k, model_size):
|
| 35 |
# Retrieve conversation history
|
| 36 |
conversation_history = memory.load_memory_variables({})['history']
|
| 37 |
|
|
|
|
| 48 |
# Concatenate truncated history with the new input
|
| 49 |
final_input_ids = torch.cat((truncated_input_ids_tensor, input_ids), dim=1)
|
| 50 |
|
| 51 |
+
# Get the model corresponding to the selected size
|
| 52 |
+
model = models[model_size]
|
| 53 |
+
|
| 54 |
# Generate the response using the model with adjusted parameters
|
| 55 |
outputs = model.generate(
|
| 56 |
final_input_ids,
|
|
|
|
| 89 |
|
| 90 |
# Add the instruction message above the input box
|
| 91 |
gr.Markdown("**Instructions:** Press `Shift + Enter` to submit, and `Enter` for a new line.")
|
| 92 |
+
|
| 93 |
+
# Add a dropdown for selecting the model size (small, medium, large)
|
| 94 |
+
model_selector = gr.Dropdown(choices=["small", "medium", "large"], value="medium", label="Select Model Size")
|
| 95 |
|
| 96 |
# Add a clear history button
|
| 97 |
clear_button = gr.Button("Clear History", scale=0)
|
|
|
|
| 106 |
top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k", scale=0)
|
| 107 |
|
| 108 |
# Define the function to update the chat
|
| 109 |
+
def update_chat(input_text, chat_history, temperature, top_p, top_k, model_size):
|
| 110 |
+
updated_history = chat_with_dialogpt(input_text, temperature, top_p, top_k, model_size)
|
| 111 |
return updated_history, ""
|
| 112 |
|
| 113 |
# Submit when pressing Shift + Enter
|
| 114 |
user_input.submit(update_chat,
|
| 115 |
+
inputs=[user_input, chatbot_output, temperature_slider, top_p_slider, top_k_slider, model_selector],
|
| 116 |
outputs=[chatbot_output, user_input])
|
| 117 |
|
| 118 |
# Layout for sliders and chatbot UI
|
| 119 |
gr.Row([temperature_slider, top_p_slider, top_k_slider])
|
| 120 |
+
|
| 121 |
+
# Layout for model selector and clear button in a row
|
| 122 |
+
gr.Row([model_selector, clear_button])
|
| 123 |
+
|
| 124 |
# Launch the Gradio app
|
| 125 |
interface.launch()
|