Spaces:
Sleeping
Sleeping
Commit ·
b3a54f6
1
Parent(s): d0be49c
Add button to sample from GP
Browse files- gp_visualizer.py +30 -0
gp_visualizer.py
CHANGED
|
@@ -144,6 +144,8 @@ class GPVisualizer:
|
|
| 144 |
"show_predictions": True,
|
| 145 |
}
|
| 146 |
|
|
|
|
|
|
|
| 147 |
|
| 148 |
def __init__(self, width, height):
|
| 149 |
self.canvas_width = width
|
|
@@ -222,6 +224,11 @@ class GPVisualizer:
|
|
| 222 |
color=self.plot_cmap(3)
|
| 223 |
)
|
| 224 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
plt.legend()
|
| 226 |
|
| 227 |
buf = io.BytesIO()
|
|
@@ -301,6 +308,14 @@ class GPVisualizer:
|
|
| 301 |
self.plot_options[key] = value
|
| 302 |
return self.plot()
|
| 303 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
def launch(self):
|
| 305 |
# build the Gradio interface
|
| 306 |
with gr.Blocks(css=self.css) as demo:
|
|
@@ -378,6 +393,10 @@ class GPVisualizer:
|
|
| 378 |
show_predictions = gr.Checkbox(label="Show mean prediction", value=True)
|
| 379 |
show_confidence_interval = gr.Checkbox(label="Show confidence interval", value=True)
|
| 380 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
#gr.Markdown(''.join(open('kernel_examples.md', 'r').readlines()))
|
| 382 |
|
| 383 |
with gr.Tab("Export"):
|
|
@@ -456,6 +475,17 @@ class GPVisualizer:
|
|
| 456 |
outputs=[self.canvas],
|
| 457 |
)
|
| 458 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 459 |
demo.launch()
|
| 460 |
|
| 461 |
visualizer = GPVisualizer(width=1200, height=900)
|
|
|
|
| 144 |
"show_predictions": True,
|
| 145 |
}
|
| 146 |
|
| 147 |
+
self.num_y_samples = 0
|
| 148 |
+
|
| 149 |
|
| 150 |
def __init__(self, width, height):
|
| 151 |
self.canvas_width = width
|
|
|
|
| 224 |
color=self.plot_cmap(3)
|
| 225 |
)
|
| 226 |
|
| 227 |
+
for i in range(self.num_y_samples):
|
| 228 |
+
y_sample = self.model.sample_y(x_test, random_state=i).flatten()
|
| 229 |
+
plt.plot(x_test.flatten(), y_sample, linestyle=":", label=f"sample {i}", color=self.plot_cmap(4))
|
| 230 |
+
|
| 231 |
+
|
| 232 |
plt.legend()
|
| 233 |
|
| 234 |
buf = io.BytesIO()
|
|
|
|
| 308 |
self.plot_options[key] = value
|
| 309 |
return self.plot()
|
| 310 |
|
| 311 |
+
def add_y_sample(self):
|
| 312 |
+
self.num_y_samples += 1
|
| 313 |
+
return self.plot()
|
| 314 |
+
|
| 315 |
+
def clear_y_samples(self):
|
| 316 |
+
self.num_y_samples = 0
|
| 317 |
+
return self.plot()
|
| 318 |
+
|
| 319 |
def launch(self):
|
| 320 |
# build the Gradio interface
|
| 321 |
with gr.Blocks(css=self.css) as demo:
|
|
|
|
| 393 |
show_predictions = gr.Checkbox(label="Show mean prediction", value=True)
|
| 394 |
show_confidence_interval = gr.Checkbox(label="Show confidence interval", value=True)
|
| 395 |
|
| 396 |
+
# sampling from GP
|
| 397 |
+
sample_button = gr.Button("Sample from GP")
|
| 398 |
+
clear_samples_button = gr.Button("Clear samples from GP")
|
| 399 |
+
|
| 400 |
#gr.Markdown(''.join(open('kernel_examples.md', 'r').readlines()))
|
| 401 |
|
| 402 |
with gr.Tab("Export"):
|
|
|
|
| 475 |
outputs=[self.canvas],
|
| 476 |
)
|
| 477 |
|
| 478 |
+
# sampling from GP
|
| 479 |
+
sample_button.click(
|
| 480 |
+
fn=self.add_y_sample,
|
| 481 |
+
outputs=[self.canvas],
|
| 482 |
+
)
|
| 483 |
+
clear_samples_button.click(
|
| 484 |
+
fn=self.clear_y_samples,
|
| 485 |
+
outputs=[self.canvas],
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
|
| 489 |
demo.launch()
|
| 490 |
|
| 491 |
visualizer = GPVisualizer(width=1200, height=900)
|