kdevoe committed on
Commit
ee87074
·
verified ·
1 Parent(s): d51dabc

Updating to new Gradio chat interface

Browse files
Files changed (1) hide show
  1. app.py +51 -110
app.py CHANGED
@@ -1,123 +1,64 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
- import torch
4
- from langchain.memory import ConversationBufferMemory
5
 
6
- # Move model to device (GPU if available)
7
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
 
9
- # Load all three DialoGPT models (small, medium, large)
10
- models = {
11
- "small": AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small").to(device),
12
- "medium": AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium").to(device),
13
- "large": AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large").to(device)
14
  }
15
 
16
- # Load the tokenizer (same tokenizer for all models)
17
- tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
 
 
 
18
 
19
- # Initialize conversation history
20
- conversation_history = []
 
 
 
 
 
 
 
 
21
 
22
- # Function to clear the chat history
23
- def clear_history():
24
- global conversation_history
25
- conversation_history = []
26
- return ""
27
 
28
- # Define the chatbot function with memory and additional parameters
29
- def chat_with_dialogpt(input_text, temperature, top_p, top_k, model_size):
30
- global conversation_history
31
 
32
- # Encode the user input and append the end-of-text token
33
- new_user_input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors='pt').to(device)
34
-
35
- # Append the user input to the conversation history
36
- conversation_history.append(new_user_input_ids)
37
-
38
- # Concatenate conversation history
39
- bot_input_ids = torch.cat(conversation_history, dim=-1)
40
-
41
- # Truncate input_ids to the last 100 tokens if necessary
42
- max_length = 100
43
- if bot_input_ids.size(-1) > max_length:
44
- bot_input_ids = bot_input_ids[:, -max_length:]
45
-
46
- # Get the model corresponding to the selected size
47
- model = models[model_size]
48
-
49
- # Generate a response
50
- response_ids = model.generate(
51
- bot_input_ids,
52
- max_length=bot_input_ids.shape[-1] + 50,
53
- pad_token_id=tokenizer.eos_token_id,
54
- eos_token_id=tokenizer.eos_token_id,
55
  temperature=temperature,
56
  top_p=top_p,
57
- top_k=top_k,
58
- no_repeat_ngram_size=3,
59
- repetition_penalty=1.2,
60
- early_stopping=True,
61
  )
62
-
63
- # Extract only the new tokens generated
64
- new_response_ids = response_ids[:, bot_input_ids.shape[-1]:]
65
-
66
- # Decode the response
67
- response = tokenizer.decode(new_response_ids[0], skip_special_tokens=True)
68
-
69
- # Append the bot response to the conversation history
70
- conversation_history.append(new_response_ids)
71
-
72
- # Format the chat history for display
73
- # For display purposes, reconstruct the conversation
74
- display_conversation = ""
75
- for i in range(0, len(conversation_history), 2):
76
- user_input = tokenizer.decode(conversation_history[i], skip_special_tokens=True)
77
- display_conversation += f"You: {user_input}\n"
78
- if i+1 < len(conversation_history):
79
- bot_response = tokenizer.decode(conversation_history[i+1], skip_special_tokens=True)
80
- display_conversation += f"Bot: {bot_response}\n"
81
-
82
- return display_conversation
83
 
84
- # Set up the Gradio interface
85
- with gr.Blocks() as interface:
86
- chatbot_output = gr.Textbox(label="Conversation", lines=15, placeholder="Chat history will appear here...", interactive=False)
87
-
88
- # Add the instruction message above the input box
89
- gr.Markdown("**Instructions:** Press `Shift + Enter` to submit, and `Enter` for a new line.")
90
-
91
- # Add a dropdown for selecting the model size (small, medium, large)
92
- model_selector = gr.Dropdown(choices=["small", "medium", "large"], value="medium", label="Select Model Size")
93
-
94
- # Add a clear history button
95
- clear_button = gr.Button("Clear History")
96
- clear_button.click(fn=clear_history, outputs=[chatbot_output])
97
-
98
- # Input box for the user
99
- user_input = gr.Textbox(label="Your Input", placeholder="Type your message here...", lines=2, show_label=True)
100
-
101
- # Sliders for temperature, top_p, and top_k
102
- temperature_slider = gr.Slider(0.1, 1.0, step=0.1, value=1.0, label="Temperature")
103
- top_p_slider = gr.Slider(0.0, 1.0, step=0.1, value=1.0, label="Top-p")
104
- top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
105
-
106
- # Define the function to update the chat
107
- def update_chat(input_text, temperature, top_p, top_k, model_size):
108
- updated_history = chat_with_dialogpt(input_text, temperature, top_p, top_k, model_size)
109
- return updated_history, ""
110
-
111
- # Submit when pressing Shift + Enter
112
- user_input.submit(update_chat,
113
- inputs=[user_input, temperature_slider, top_p_slider, top_k_slider, model_selector],
114
- outputs=[chatbot_output, user_input])
115
-
116
- # Layout for sliders and chatbot UI
117
- gr.Row([temperature_slider, top_p_slider, top_k_slider])
118
-
119
- # Layout for model selector and clear button in a row
120
- gr.Row([model_selector, clear_button])
121
-
122
- # Launch the Gradio app
123
- interface.launch()
 
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# One tokenizer is shared across DialoGPT variants (they use the same vocab).
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")

# UI display name -> Hugging Face Hub model id for each selectable size.
model_names = {
    "DialoGPT-small": "microsoft/DialoGPT-small",
    "DialoGPT-medium": "microsoft/DialoGPT-medium",
}

# Load every model up front so switching sizes in the UI is instant.
loaded_models = {
    model_name: AutoModelForCausalLM.from_pretrained(model_path)
    for model_name, model_path in model_names.items()
}
18
 
19
def respond(
    message,
    history: list[tuple[str, str]],
    model_choice,
    max_tokens,
    temperature,
    top_p,
):
    """Generate one assistant reply with the selected DialoGPT model.

    Parameters
    ----------
    message : str
        The newest user message.
    history : list[tuple[str, str]]
        Prior (user, assistant) turns, as supplied by ``gr.ChatInterface``.
    model_choice : str
        Key into the module-level ``loaded_models`` dict.
    max_tokens : int
        Maximum number of NEW tokens to generate.
    temperature, top_p : float
        Sampling hyper-parameters forwarded to ``model.generate``.

    Yields
    ------
    str
        The decoded assistant response (single yield; ChatInterface accepts
        a generator callback).
    """
    # Select the pre-loaded model matching the user's dropdown choice.
    model = loaded_models[model_choice]

    # Flatten the chat history into a plain-text transcript prompt,
    # ending with an open "Assistant:" turn for the model to complete.
    input_text = ""
    for user_msg, bot_msg in history:
        input_text += f"User: {user_msg}\nAssistant: {bot_msg}\n"
    input_text += f"User: {message}\nAssistant:"

    # Tokenize; truncation caps the prompt at the model's maximum context.
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True)

    # FIX: pass attention_mask and an explicit pad_token_id — DialoGPT has
    # no pad token, and omitting them triggers generation warnings and can
    # skew sampling. max_new_tokens replaces the fragile
    # `max_length = prompt_len + max_tokens` arithmetic (same budget).
    output_tokens = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_tokens,
        pad_token_id=tokenizer.eos_token_id,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    # Decode only the newly generated tokens (everything after the prompt).
    response = tokenizer.decode(
        output_tokens[0][inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True,
    )
    yield response
51
+
52
# Build the chat UI: ChatInterface wires `respond` to a chatbot widget and
# renders the extra controls alongside the conversation. Components are
# constructed up front and handed over as additional inputs.
_model_dropdown = gr.Dropdown(
    choices=["DialoGPT-small", "DialoGPT-medium"],
    value="DialoGPT-small",
    label="Model",
)
_max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
_temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
_top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")

demo = gr.ChatInterface(
    respond,
    additional_inputs=[_model_dropdown, _max_tokens, _temperature, _top_p],
)

if __name__ == "__main__":
    demo.launch()