joel-woodfield commited on
Commit
b3a54f6
·
1 Parent(s): d0be49c

Add button to sample from GP

Browse files
Files changed (1) hide show
  1. gp_visualizer.py +30 -0
gp_visualizer.py CHANGED
@@ -144,6 +144,8 @@ class GPVisualizer:
144
  "show_predictions": True,
145
  }
146
 
 
 
147
 
148
  def __init__(self, width, height):
149
  self.canvas_width = width
@@ -222,6 +224,11 @@ class GPVisualizer:
222
  color=self.plot_cmap(3)
223
  )
224
 
 
 
 
 
 
225
  plt.legend()
226
 
227
  buf = io.BytesIO()
@@ -301,6 +308,14 @@ class GPVisualizer:
301
  self.plot_options[key] = value
302
  return self.plot()
303
 
 
 
 
 
 
 
 
 
304
  def launch(self):
305
  # build the Gradio interface
306
  with gr.Blocks(css=self.css) as demo:
@@ -378,6 +393,10 @@ class GPVisualizer:
378
  show_predictions = gr.Checkbox(label="Show mean prediction", value=True)
379
  show_confidence_interval = gr.Checkbox(label="Show confidence interval", value=True)
380
 
 
 
 
 
381
  #gr.Markdown(''.join(open('kernel_examples.md', 'r').readlines()))
382
 
383
  with gr.Tab("Export"):
@@ -456,6 +475,17 @@ class GPVisualizer:
456
  outputs=[self.canvas],
457
  )
458
 
 
 
 
 
 
 
 
 
 
 
 
459
  demo.launch()
460
 
461
  visualizer = GPVisualizer(width=1200, height=900)
 
144
  "show_predictions": True,
145
  }
146
 
147
+ self.num_y_samples = 0
148
+
149
 
150
  def __init__(self, width, height):
151
  self.canvas_width = width
 
224
  color=self.plot_cmap(3)
225
  )
226
 
227
+ for i in range(self.num_y_samples):
228
+ y_sample = self.model.sample_y(x_test, random_state=i).flatten()
229
+ plt.plot(x_test.flatten(), y_sample, linestyle=":", label=f"sample {i}", color=self.plot_cmap(4))
230
+
231
+
232
  plt.legend()
233
 
234
  buf = io.BytesIO()
 
308
  self.plot_options[key] = value
309
  return self.plot()
310
 
311
+ def add_y_sample(self):
312
+ self.num_y_samples += 1
313
+ return self.plot()
314
+
315
+ def clear_y_samples(self):
316
+ self.num_y_samples = 0
317
+ return self.plot()
318
+
319
  def launch(self):
320
  # build the Gradio interface
321
  with gr.Blocks(css=self.css) as demo:
 
393
  show_predictions = gr.Checkbox(label="Show mean prediction", value=True)
394
  show_confidence_interval = gr.Checkbox(label="Show confidence interval", value=True)
395
 
396
+ # sampling from GP
397
+ sample_button = gr.Button("Sample from GP")
398
+ clear_samples_button = gr.Button("Clear samples from GP")
399
+
400
  #gr.Markdown(''.join(open('kernel_examples.md', 'r').readlines()))
401
 
402
  with gr.Tab("Export"):
 
475
  outputs=[self.canvas],
476
  )
477
 
478
+ # sampling from GP
479
+ sample_button.click(
480
+ fn=self.add_y_sample,
481
+ outputs=[self.canvas],
482
+ )
483
+ clear_samples_button.click(
484
+ fn=self.clear_y_samples,
485
+ outputs=[self.canvas],
486
+ )
487
+
488
+
489
  demo.launch()
490
 
491
  visualizer = GPVisualizer(width=1200, height=900)