aadya1762 committed
Commit 8a9dfc3 · 1 Parent(s): 87122c2

bug fixes

Files changed (1):
  1. gemmademo/_chat.py +110 -158
gemmademo/_chat.py CHANGED
@@ -101,176 +101,128 @@ class GradioChat:
             return gr.Dataset(samples=[[example] for example in examples])
 
         with gr.Blocks() as demo:
-            with gr.Tab("Model Playground"):
-                with gr.Row():
-                    with gr.Column(scale=3):  # Sidebar column
-                        with gr.Accordion(
-                            "Basic Settings ⚙️", open=False
-                        ):  # Make the sidebar foldable
-                            gr.Markdown(
-                                "## Google Gemma Models: lightweight, state-of-the-art open models from Google"
-                            )
-                            task_dropdown = gr.Dropdown(
-                                choices=self.task_options,
-                                value=self.current_task_name,
-                                label="Select Task",
-                            )
-                            model_dropdown = gr.Dropdown(
-                                choices=self.model_options,
-                                value=self.current_model_name,
-                                label="Select Gemma Model",
-                            )
-                        chat_interface = gr.ChatInterface(
-                            chat_fn,
-                            additional_inputs=[model_dropdown, task_dropdown],
-                            textbox=gr.Textbox(
-                                placeholder="Ask me something...", container=False
-                            ),
-                        )
-
-                    with gr.Column(scale=1):
+            with gr.Row():
+                with gr.Column(scale=3):  # Sidebar column
+                    with gr.Accordion(
+                        "Basic Settings ⚙️", open=False
+                    ):  # Make the sidebar foldable
                         gr.Markdown(
-                            """
-                            ## Tips
-
-                            - First response after model change will be slower (model loading lazily).
-                            - Switching models clears chat history.
-                            - Larger models need more memory but give better results.
-                            """
+                            "## Google Gemma Models: lightweight, state-of-the-art open models from Google"
                         )
-                        examples_list = gr.Examples(
-                            examples=[
-                                [example]
-                                for example in _get_examples(self.current_task_name)
-                            ],
-                            inputs=chat_interface.textbox,
+                        task_dropdown = gr.Dropdown(
+                            choices=self.task_options,
+                            value=self.current_task_name,
+                            label="Select Task",
                         )
-                        task_dropdown.change(
-                            _update_examples, task_dropdown, examples_list.dataset
+                        model_dropdown = gr.Dropdown(
+                            choices=self.model_options,
+                            value=self.current_model_name,
+                            label="Select Gemma Model",
                         )
-                        with gr.Accordion("Model Configuration ⚙️", open=False):
-                            temperature_slider = gr.Slider(
-                                minimum=0.1,
-                                maximum=2,
-                                value=self.model.temperature,
-                                label="Temperature",
-                            )
-                            gr.Markdown(
-                                "**Temperature:** Lower values make the output more deterministic."
-                            )
-                            temperature_slider.change(
-                                fn=lambda temp: setattr(
-                                    self.model, "temperature", temp
-                                ),
-                                inputs=temperature_slider,
-                            )
-
-                            top_p_slider = gr.Slider(
-                                minimum=0.1,
-                                maximum=1.0,
-                                value=self.model.top_p,
-                                label="Top P",
-                            )
-                            gr.Markdown(
-                                "**Top P:** Lower values make the output more focused."
-                            )
-                            top_p_slider.change(
-                                fn=lambda top_p: setattr(self.model, "top_p", top_p),
-                                inputs=top_p_slider,
-                            )
-
-                            top_k_slider = gr.Slider(
-                                minimum=1,
-                                maximum=100,
-                                value=self.model.top_k,
-                                label="Top K",
-                            )
-                            gr.Markdown(
-                                "**Top K:** Lower values make the output more focused."
-                            )
-                            top_k_slider.change(
-                                fn=lambda top_k: setattr(self.model, "top_k", top_k),
-                                inputs=top_k_slider,
-                            )
-
-                            repetition_penalty_slider = gr.Slider(
-                                minimum=1.0,
-                                maximum=2.0,
-                                value=self.model.repeat_penalty,
-                                label="Repetition Penalty",
-                            )
-                            gr.Markdown(
-                                "**Repetition Penalty:** Penalizes repeated tokens to reduce repetition in the output."
-                            )
-                            repetition_penalty_slider.change(
-                                fn=lambda penalty: setattr(
-                                    self.model, "repeat_penalty", penalty
-                                ),
-                                inputs=repetition_penalty_slider,
-                            )
-
-                            max_tokens_slider = gr.Slider(
-                                minimum=512,
-                                maximum=2048,
-                                value=self.model.max_tokens,
-                                label="Max Tokens",
-                            )
-                            gr.Markdown(
-                                "**Max Tokens:** Sets the maximum number of tokens the model can generate in one response."
-                            )
-                            max_tokens_slider.change(
-                                fn=lambda max_tokens: setattr(
-                                    self.model, "max_tokens", max_tokens
-                                ),
-                                inputs=max_tokens_slider,
-                            )
-
-            with gr.Tab("Model Comparision"):
-                with gr.Row():
-                    # Input for user query
-                    user_input = gr.Textbox(
-                        placeholder="Enter your query here...", label="User Input"
+                    chat_interface = gr.ChatInterface(
+                        chat_fn,
+                        additional_inputs=[model_dropdown, task_dropdown],
+                        textbox=gr.Textbox(
+                            placeholder="Ask me something...", container=False
+                        ),
                     )
 
-                    # Dropdown for model selection
-                    model_comparison_dropdown = gr.Dropdown(
-                        choices=self.model_options,
-                        label="Select Models",
-                        multiselect=True,  # Allow multiple selections
-                        value=[self.current_model_name],  # Default to current model
+                with gr.Column(scale=1):
+                    gr.Markdown(
+                        """
+                        ## Tips
+
+                        - First response after model change will be slower (model loading lazily).
+                        - Switching models clears chat history.
+                        - Larger models need more memory but give better results.
+                        """
                     )
-
-                    # Create output textboxes for each model
-                    output_textboxes = {}
-                    for model_name in self.model_options:
-                        output_textboxes[model_name] = gr.Textbox(
-                            label=model_name, interactive=False
+                    examples_list = gr.Examples(
+                        examples=[
+                            [example]
+                            for example in _get_examples(self.current_task_name)
+                        ],
+                        inputs=chat_interface.textbox,
+                    )
+                    task_dropdown.change(
+                        _update_examples, task_dropdown, examples_list.dataset
                     )
+                    with gr.Accordion("Model Configuration ⚙️", open=False):
+                        temperature_slider = gr.Slider(
+                            minimum=0.1,
+                            maximum=2,
+                            value=self.model.temperature,
+                            label="Temperature",
+                        )
+                        gr.Markdown(
+                            "**Temperature:** Lower values make the output more deterministic."
+                        )
+                        temperature_slider.change(
+                            fn=lambda temp: setattr(
+                                self.model, "temperature", temp
+                            ),
+                            inputs=temperature_slider,
+                        )
 
-                    # Button to trigger comparison
-                    compare_button = gr.Button("Compare Models")
+                        top_p_slider = gr.Slider(
+                            minimum=0.1,
+                            maximum=1.0,
+                            value=self.model.top_p,
+                            label="Top P",
+                        )
+                        gr.Markdown(
+                            "**Top P:** Lower values make the output more focused."
+                        )
+                        top_p_slider.change(
+                            fn=lambda top_p: setattr(self.model, "top_p", top_p),
+                            inputs=top_p_slider,
+                        )
 
-                def compare_models(user_input, selected_models):
-                    responses = []
-                    for model_name in selected_models:
-                        model = self._load_model(model_name)  # Load each selected model
-                        prompt = self.prompt_manager.get_prompt(user_input=user_input)
-                        response = model.generate_response(prompt)
-                        responses.append(response)  # Store response
-                    return responses  # Return list of responses
+                        top_k_slider = gr.Slider(
+                            minimum=1,
+                            maximum=100,
+                            value=self.model.top_k,
+                            label="Top K",
+                        )
+                        gr.Markdown(
+                            "**Top K:** Lower values make the output more focused."
+                        )
+                        top_k_slider.change(
+                            fn=lambda top_k: setattr(self.model, "top_k", top_k),
+                            inputs=top_k_slider,
+                        )
 
-                compare_button.click(
-                    fn=compare_models,
-                    inputs=[user_input, model_comparison_dropdown],
-                    outputs=list(output_textboxes.values()),  # Output to textboxes
-                )
+                        repetition_penalty_slider = gr.Slider(
+                            minimum=1.0,
+                            maximum=2.0,
+                            value=self.model.repeat_penalty,
+                            label="Repetition Penalty",
+                        )
+                        gr.Markdown(
+                            "**Repetition Penalty:** Penalizes repeated tokens to reduce repetition in the output."
+                        )
+                        repetition_penalty_slider.change(
+                            fn=lambda penalty: setattr(
+                                self.model, "repeat_penalty", penalty
+                            ),
+                            inputs=repetition_penalty_slider,
+                        )
 
-                # Display responses for each model
-                with gr.Row():
-                    for model_name, output_box in output_textboxes.items():
-                        with gr.Column():
-                            gr.Markdown(f"### Output from {model_name}:")
-                            output_box  # Add the output textbox to the layout
+                        max_tokens_slider = gr.Slider(
+                            minimum=512,
+                            maximum=2048,
+                            value=self.model.max_tokens,
+                            label="Max Tokens",
+                        )
+                        gr.Markdown(
+                            "**Max Tokens:** Sets the maximum number of tokens the model can generate in one response."
+                        )
+                        max_tokens_slider.change(
+                            fn=lambda max_tokens: setattr(
+                                self.model, "max_tokens", max_tokens
+                            ),
+                            inputs=max_tokens_slider,
+                        )
 
         demo.launch()
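The generation-parameter sliders kept in this revision all follow one pattern: each gr.Slider registers a .change handler whose lambda writes the new value back onto the model object with setattr, so the next generation call picks the setting up automatically. Below is a minimal, self-contained sketch of that wiring; the SamplingConfig class and config object are hypothetical stand-ins for the demo's self.model.

import gradio as gr

class SamplingConfig:
    """Hypothetical stand-in for the demo's model wrapper (self.model)."""
    def __init__(self):
        self.temperature = 0.7

config = SamplingConfig()

with gr.Blocks() as demo:
    temperature_slider = gr.Slider(
        minimum=0.1, maximum=2.0, value=config.temperature, label="Temperature"
    )
    # .change fires on each new slider value; the lambda mutates the shared
    # config in place, so the event needs no outputs component.
    temperature_slider.change(
        fn=lambda temp: setattr(config, "temperature", temp),
        inputs=temperature_slider,
    )

demo.launch()

Because the handler mutates shared state rather than returning updated components, the UI wiring stays short, at the cost of hiding the state change from Gradio's input/output dependency graph.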