joel-woodfield commited on
Commit
eaae1a3
·
1 Parent(s): bd528c1

Use more informative labels

Browse files
Files changed (1) hide show
  1. gp_visualizer.py +27 -32
gp_visualizer.py CHANGED
@@ -34,14 +34,14 @@ class PlotOptions:
34
  show_training_data: bool = True
35
  show_true_function: bool = True
36
  show_mean_prediction: bool = True
37
- show_confidence_interval: bool = True
38
 
39
  def update(self, **kwargs):
40
  return PlotOptions(
41
  show_training_data=kwargs.get("show_training_data", self.show_training_data),
42
  show_true_function=kwargs.get("show_true_function", self.show_true_function),
43
  show_mean_prediction=kwargs.get("show_mean_prediction", self.show_mean_prediction),
44
- show_confidence_interval=kwargs.get("show_confidence_interval", self.show_confidence_interval),
45
  )
46
 
47
  def __hash__(self):
@@ -50,7 +50,7 @@ class PlotOptions:
50
  self.show_training_data,
51
  self.show_true_function,
52
  self.show_mean_prediction,
53
- self.show_confidence_interval,
54
  )
55
  )
56
 
@@ -90,13 +90,13 @@ def eval_kernel(kernel_str) -> Kernel:
90
  class ModelState:
91
  model: GaussianProcessRegressor
92
  kernel: str
93
- mode: str
94
 
95
  def __hash__(self):
96
  return hash(
97
  (
98
  self.kernel,
99
- self.mode,
100
  )
101
  )
102
 
@@ -181,17 +181,17 @@ class GpVisualizer:
181
  x_test.flatten(),
182
  y_pred,
183
  linestyle="--",
184
- label='prediction',
185
  color=self.plot_cmap(2),
186
  )
187
- if plot_options.show_confidence_interval and x_test is not None and y_std is not None:
188
  plt.fill_between(
189
  x_test.flatten(),
190
  y_pred - 1.96 * y_std,
191
  y_pred + 1.96 * y_std,
192
  color=self.plot_cmap(3),
193
  alpha=0.2,
194
- label='95% confidence interval',
195
  )
196
 
197
  if x_test is not None and sample_y:
@@ -225,14 +225,14 @@ class GpVisualizer:
225
  self,
226
  kernel: str,
227
  dataset: Dataset,
228
- mode: str,
229
  ) -> GaussianProcessRegressor:
230
  model = GaussianProcessRegressor(kernel=eval_kernel(kernel))
231
- if mode == "posterior":
232
  if dataset.x.shape[0] > 0:
233
  model.fit(dataset.x, dataset.y)
234
- elif mode != "prior":
235
- raise ValueError(f"Unknown mode: {mode}")
236
 
237
  return model
238
 
@@ -246,10 +246,10 @@ class GpVisualizer:
246
  model = self.init_model(
247
  model_state.kernel,
248
  dataset,
249
- model_state.mode,
250
  )
251
  model_state = ModelState(
252
- model=model, kernel=model_state.kernel, mode=model_state.mode
253
  )
254
 
255
  new_canvas = self.plot(dataset, model_state, plot_options)
@@ -259,7 +259,7 @@ class GpVisualizer:
259
  def update_model(
260
  self,
261
  kernel_str: str,
262
- mode: str,
263
  model_state: ModelState,
264
  dataset: Dataset,
265
  plot_options: PlotOptions,
@@ -269,10 +269,10 @@ class GpVisualizer:
269
  model = self.init_model(
270
  kernel_str,
271
  dataset,
272
- mode.lower(),
273
  )
274
  model_state = ModelState(
275
- model=model, kernel=kernel_str, mode=mode.lower()
276
  )
277
  except Exception as e:
278
  logger.error(f"Error updating kernel: {e}")
@@ -331,7 +331,7 @@ class GpVisualizer:
331
  kernel = "RBF(length_scale=1.0) + WhiteKernel(noise_level=1.0)"
332
  model = self.init_model(kernel, dataset.value, "posterior")
333
  model_state = gr.State(
334
- ModelState(model=model, kernel=kernel, mode="posterior")
335
  )
336
 
337
  # GUI elements and layout
@@ -358,29 +358,24 @@ class GpVisualizer:
358
  )
359
 
360
  with gr.Tab("Model"):
361
- gr.Textbox(
362
- label="Prior Mean",
363
- value="0",
364
- interactive=False,
365
- )
366
  kernel_box = gr.Textbox(
367
  label="Kernel",
368
  value=model_state.value.kernel,
369
  interactive=True,
370
  )
371
- mode = gr.Radio(
372
- label="Mode",
373
  choices=["Prior", "Posterior"],
374
  value="Posterior",
375
  )
376
  kernel_box.submit(
377
  fn=self.update_model,
378
- inputs=[kernel_box, mode, model_state, dataset, plot_options],
379
  outputs=[model_state, canvas],
380
  )
381
- mode.change(
382
  fn=self.update_model,
383
- inputs=[kernel_box, mode, model_state, dataset, plot_options],
384
  outputs=[model_state, canvas],
385
  )
386
 
@@ -410,7 +405,7 @@ class GpVisualizer:
410
  label="Show Mean Prediction",
411
  value=True,
412
  )
413
- show_confidence_interval = gr.Checkbox(
414
  label="Show Confidence Interval",
415
  value=True,
416
  )
@@ -429,9 +424,9 @@ class GpVisualizer:
429
  inputs=[show_mean_prediction, plot_options],
430
  outputs=[plot_options],
431
  )
432
- show_confidence_interval.change(
433
- fn=lambda val, options: options.update(show_confidence_interval=val),
434
- inputs=[show_confidence_interval, plot_options],
435
  outputs=[plot_options],
436
  )
437
  plot_options.change(
 
34
  show_training_data: bool = True
35
  show_true_function: bool = True
36
  show_mean_prediction: bool = True
37
+ show_prediction_interval: bool = True
38
 
39
  def update(self, **kwargs):
40
  return PlotOptions(
41
  show_training_data=kwargs.get("show_training_data", self.show_training_data),
42
  show_true_function=kwargs.get("show_true_function", self.show_true_function),
43
  show_mean_prediction=kwargs.get("show_mean_prediction", self.show_mean_prediction),
44
+ show_prediction_interval=kwargs.get("show_prediction_interval", self.show_prediction_interval),
45
  )
46
 
47
  def __hash__(self):
 
50
  self.show_training_data,
51
  self.show_true_function,
52
  self.show_mean_prediction,
53
+ self.show_prediction_interval,
54
  )
55
  )
56
 
 
90
  class ModelState:
91
  model: GaussianProcessRegressor
92
  kernel: str
93
+ distribution: str
94
 
95
  def __hash__(self):
96
  return hash(
97
  (
98
  self.kernel,
99
+ self.distribution,
100
  )
101
  )
102
 
 
181
  x_test.flatten(),
182
  y_pred,
183
  linestyle="--",
184
+ label='mean prediction',
185
  color=self.plot_cmap(2),
186
  )
187
+ if plot_options.show_prediction_interval and x_test is not None and y_std is not None:
188
  plt.fill_between(
189
  x_test.flatten(),
190
  y_pred - 1.96 * y_std,
191
  y_pred + 1.96 * y_std,
192
  color=self.plot_cmap(3),
193
  alpha=0.2,
194
+ label='95% prediction interval',
195
  )
196
 
197
  if x_test is not None and sample_y:
 
225
  self,
226
  kernel: str,
227
  dataset: Dataset,
228
+ distribution: str,
229
  ) -> GaussianProcessRegressor:
230
  model = GaussianProcessRegressor(kernel=eval_kernel(kernel))
231
+ if distribution == "posterior":
232
  if dataset.x.shape[0] > 0:
233
  model.fit(dataset.x, dataset.y)
234
+ elif distribution != "prior":
235
+ raise ValueError(f"Unknown distribution: {distribution}")
236
 
237
  return model
238
 
 
246
  model = self.init_model(
247
  model_state.kernel,
248
  dataset,
249
+ model_state.distribution,
250
  )
251
  model_state = ModelState(
252
+ model=model, kernel=model_state.kernel, distribution=model_state.distribution
253
  )
254
 
255
  new_canvas = self.plot(dataset, model_state, plot_options)
 
259
  def update_model(
260
  self,
261
  kernel_str: str,
262
+ distribution: str,
263
  model_state: ModelState,
264
  dataset: Dataset,
265
  plot_options: PlotOptions,
 
269
  model = self.init_model(
270
  kernel_str,
271
  dataset,
272
+ distribution.lower(),
273
  )
274
  model_state = ModelState(
275
+ model=model, kernel=kernel_str, distribution=distribution.lower()
276
  )
277
  except Exception as e:
278
  logger.error(f"Error updating kernel: {e}")
 
331
  kernel = "RBF(length_scale=1.0) + WhiteKernel(noise_level=1.0)"
332
  model = self.init_model(kernel, dataset.value, "posterior")
333
  model_state = gr.State(
334
+ ModelState(model=model, kernel=kernel, distribution="posterior")
335
  )
336
 
337
  # GUI elements and layout
 
358
  )
359
 
360
  with gr.Tab("Model"):
 
 
 
 
 
361
  kernel_box = gr.Textbox(
362
  label="Kernel",
363
  value=model_state.value.kernel,
364
  interactive=True,
365
  )
366
+ distribution = gr.Radio(
367
+ label="Distribution",
368
  choices=["Prior", "Posterior"],
369
  value="Posterior",
370
  )
371
  kernel_box.submit(
372
  fn=self.update_model,
373
+ inputs=[kernel_box, distribution, model_state, dataset, plot_options],
374
  outputs=[model_state, canvas],
375
  )
376
+ distribution.change(
377
  fn=self.update_model,
378
+ inputs=[kernel_box, distribution, model_state, dataset, plot_options],
379
  outputs=[model_state, canvas],
380
  )
381
 
 
405
  label="Show Mean Prediction",
406
  value=True,
407
  )
408
+ show_prediction_interval = gr.Checkbox(
409
  label="Show Confidence Interval",
410
  value=True,
411
  )
 
424
  inputs=[show_mean_prediction, plot_options],
425
  outputs=[plot_options],
426
  )
427
+ show_prediction_interval.change(
428
+ fn=lambda val, options: options.update(show_prediction_interval=val),
429
+ inputs=[show_prediction_interval, plot_options],
430
  outputs=[plot_options],
431
  )
432
  plot_options.change(