joel-woodfield commited on
Commit
09ce41d
·
1 Parent(s): 7430bec

Add plot options

Browse files
Files changed (1) hide show
  1. gp_visualizer.py +57 -9
gp_visualizer.py CHANGED
@@ -88,6 +88,13 @@ class GPVisualizer:
88
 
89
  self.kernel = eval_kernel(self.DEFAULT_KERNEL)
90
 
 
 
 
 
 
 
 
91
  self.css = """
92
  .hidden-button {
93
  display: none;
@@ -127,16 +134,21 @@ class GPVisualizer:
127
 
128
  R2 = gpr.score(X_tr, y_tr)
129
 
130
- if len(X_tr) > 1:
131
- plt.scatter(X_tr.flatten(), y_tr, label='training data (R2=%.2f)' % (R2))
132
- else:
133
- plt.scatter(X_tr.flatten(), y_tr, label='training data')
 
 
 
 
134
 
135
- plt.plot(X_ts.flatten(), np.sin(2*np.pi*X_ts.flatten()), color='red', label='true function')
136
- plt.scatter(X_ts.flatten(), y_pred, marker='+', label='predictions')
137
 
138
- plt.fill_between(X_ts.flatten(), y_pred - 1.96*y_std, y_pred + 1.96*y_std, alpha=0.5,
139
- label='95% confidence interval')
 
140
 
141
  plt.legend()
142
 
@@ -190,6 +202,12 @@ class GPVisualizer:
190
  self.kernel = eval_kernel(kernel_spec)
191
  return self.plot()
192
 
 
 
 
 
 
 
193
  def launch(self):
194
  # build the Gradio interface
195
  with gr.Blocks(css=self.css) as demo:
@@ -220,6 +238,13 @@ class GPVisualizer:
220
  interactive=True,
221
  )
222
 
 
 
 
 
 
 
 
223
  #gr.Markdown(''.join(open('kernel_examples.md', 'r').readlines()))
224
 
225
  with gr.Tab("Export"):
@@ -250,10 +275,33 @@ class GPVisualizer:
250
  kernel_spec.submit(
251
  fn=self.update_kernel_spec,
252
  inputs=kernel_spec,
253
- outputs=(self.canvas)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  )
255
 
256
  demo.launch()
257
 
258
  visualizer = GPVisualizer(width=1200, height=900)
259
  visualizer.launch()
 
 
88
 
89
  self.kernel = eval_kernel(self.DEFAULT_KERNEL)
90
 
91
+ self.plot_options = {
92
+ "show_training_data": True,
93
+ "show_confidence_interval": True,
94
+ "show_true_function": True,
95
+ "show_predictions": True,
96
+ }
97
+
98
  self.css = """
99
  .hidden-button {
100
  display: none;
 
134
 
135
  R2 = gpr.score(X_tr, y_tr)
136
 
137
+ if self.plot_options["show_training_data"]:
138
+ if len(X_tr) > 1:
139
+ plt.scatter(X_tr.flatten(), y_tr, label='training data (R2=%.2f)' % (R2))
140
+ else:
141
+ plt.scatter(X_tr.flatten(), y_tr, label='training data')
142
+
143
+ if self.plot_options["show_true_function"]:
144
+ plt.plot(X_ts.flatten(), np.sin(2*np.pi*X_ts.flatten()), color='red', label='true function')
145
 
146
+ if self.plot_options["show_predictions"]:
147
+ plt.scatter(X_ts.flatten(), y_pred, marker='+', label='predictions')
148
 
149
+ if self.plot_options["show_confidence_interval"]:
150
+ plt.fill_between(X_ts.flatten(), y_pred - 1.96*y_std, y_pred + 1.96*y_std, alpha=0.5,
151
+ label='95% confidence interval')
152
 
153
  plt.legend()
154
 
 
202
  self.kernel = eval_kernel(kernel_spec)
203
  return self.plot()
204
 
205
+ def update_plot_options(self, **kwargs):
206
+ for key, value in kwargs.items():
207
+ if key in self.plot_options:
208
+ self.plot_options[key] = value
209
+ return self.plot()
210
+
211
  def launch(self):
212
  # build the Gradio interface
213
  with gr.Blocks(css=self.css) as demo:
 
238
  interactive=True,
239
  )
240
 
241
+ # plot show options
242
+ with gr.Group():
243
+ show_training_data = gr.Checkbox(label="Show training data", value=True)
244
+ show_confidence_interval = gr.Checkbox(label="Show confidence interval", value=True)
245
+ show_true_function = gr.Checkbox(label="Show true function", value=True)
246
+ show_predictions = gr.Checkbox(label="Show predictions", value=True)
247
+
248
  #gr.Markdown(''.join(open('kernel_examples.md', 'r').readlines()))
249
 
250
  with gr.Tab("Export"):
 
275
  kernel_spec.submit(
276
  fn=self.update_kernel_spec,
277
  inputs=kernel_spec,
278
+ outputs=[self.canvas],
279
+ )
280
+
281
+ # plot options
282
+ show_training_data.change(
283
+ fn=lambda show: self.update_plot_options(show_training_data=show),
284
+ inputs=show_training_data,
285
+ outputs=[self.canvas],
286
+ )
287
+ show_confidence_interval.change(
288
+ fn=lambda show: self.update_plot_options(show_confidence_interval=show),
289
+ inputs=show_confidence_interval,
290
+ outputs=[self.canvas],
291
+ )
292
+ show_true_function.change(
293
+ fn=lambda show: self.update_plot_options(show_true_function=show),
294
+ inputs=show_true_function,
295
+ outputs=[self.canvas],
296
+ )
297
+ show_predictions.change(
298
+ fn=lambda show: self.update_plot_options(show_predictions=show),
299
+ inputs=show_predictions,
300
+ outputs=[self.canvas],
301
  )
302
 
303
  demo.launch()
304
 
305
  visualizer = GPVisualizer(width=1200, height=900)
306
  visualizer.launch()
307
+