Spaces:
Sleeping
Sleeping
Commit ·
09ce41d
1
Parent(s): 7430bec
Add plot options
Browse files- gp_visualizer.py +57 -9
gp_visualizer.py
CHANGED
|
@@ -88,6 +88,13 @@ class GPVisualizer:
|
|
| 88 |
|
| 89 |
self.kernel = eval_kernel(self.DEFAULT_KERNEL)
|
| 90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
self.css = """
|
| 92 |
.hidden-button {
|
| 93 |
display: none;
|
|
@@ -127,16 +134,21 @@ class GPVisualizer:
|
|
| 127 |
|
| 128 |
R2 = gpr.score(X_tr, y_tr)
|
| 129 |
|
| 130 |
-
if
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
-
|
| 136 |
-
|
| 137 |
|
| 138 |
-
|
| 139 |
-
|
|
|
|
| 140 |
|
| 141 |
plt.legend()
|
| 142 |
|
|
@@ -190,6 +202,12 @@ class GPVisualizer:
|
|
| 190 |
self.kernel = eval_kernel(kernel_spec)
|
| 191 |
return self.plot()
|
| 192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
def launch(self):
|
| 194 |
# build the Gradio interface
|
| 195 |
with gr.Blocks(css=self.css) as demo:
|
|
@@ -220,6 +238,13 @@ class GPVisualizer:
|
|
| 220 |
interactive=True,
|
| 221 |
)
|
| 222 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
#gr.Markdown(''.join(open('kernel_examples.md', 'r').readlines()))
|
| 224 |
|
| 225 |
with gr.Tab("Export"):
|
|
@@ -250,10 +275,33 @@ class GPVisualizer:
|
|
| 250 |
kernel_spec.submit(
|
| 251 |
fn=self.update_kernel_spec,
|
| 252 |
inputs=kernel_spec,
|
| 253 |
-
outputs=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
)
|
| 255 |
|
| 256 |
demo.launch()
|
| 257 |
|
| 258 |
visualizer = GPVisualizer(width=1200, height=900)
|
| 259 |
visualizer.launch()
|
|
|
|
|
|
| 88 |
|
| 89 |
self.kernel = eval_kernel(self.DEFAULT_KERNEL)
|
| 90 |
|
| 91 |
+
self.plot_options = {
|
| 92 |
+
"show_training_data": True,
|
| 93 |
+
"show_confidence_interval": True,
|
| 94 |
+
"show_true_function": True,
|
| 95 |
+
"show_predictions": True,
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
self.css = """
|
| 99 |
.hidden-button {
|
| 100 |
display: none;
|
|
|
|
| 134 |
|
| 135 |
R2 = gpr.score(X_tr, y_tr)
|
| 136 |
|
| 137 |
+
if self.plot_options["show_training_data"]:
|
| 138 |
+
if len(X_tr) > 1:
|
| 139 |
+
plt.scatter(X_tr.flatten(), y_tr, label='training data (R2=%.2f)' % (R2))
|
| 140 |
+
else:
|
| 141 |
+
plt.scatter(X_tr.flatten(), y_tr, label='training data')
|
| 142 |
+
|
| 143 |
+
if self.plot_options["show_true_function"]:
|
| 144 |
+
plt.plot(X_ts.flatten(), np.sin(2*np.pi*X_ts.flatten()), color='red', label='true function')
|
| 145 |
|
| 146 |
+
if self.plot_options["show_predictions"]:
|
| 147 |
+
plt.scatter(X_ts.flatten(), y_pred, marker='+', label='predictions')
|
| 148 |
|
| 149 |
+
if self.plot_options["show_confidence_interval"]:
|
| 150 |
+
plt.fill_between(X_ts.flatten(), y_pred - 1.96*y_std, y_pred + 1.96*y_std, alpha=0.5,
|
| 151 |
+
label='95% confidence interval')
|
| 152 |
|
| 153 |
plt.legend()
|
| 154 |
|
|
|
|
| 202 |
self.kernel = eval_kernel(kernel_spec)
|
| 203 |
return self.plot()
|
| 204 |
|
| 205 |
+
def update_plot_options(self, **kwargs):
|
| 206 |
+
for key, value in kwargs.items():
|
| 207 |
+
if key in self.plot_options:
|
| 208 |
+
self.plot_options[key] = value
|
| 209 |
+
return self.plot()
|
| 210 |
+
|
| 211 |
def launch(self):
|
| 212 |
# build the Gradio interface
|
| 213 |
with gr.Blocks(css=self.css) as demo:
|
|
|
|
| 238 |
interactive=True,
|
| 239 |
)
|
| 240 |
|
| 241 |
+
# plot show options
|
| 242 |
+
with gr.Group():
|
| 243 |
+
show_training_data = gr.Checkbox(label="Show training data", value=True)
|
| 244 |
+
show_confidence_interval = gr.Checkbox(label="Show confidence interval", value=True)
|
| 245 |
+
show_true_function = gr.Checkbox(label="Show true function", value=True)
|
| 246 |
+
show_predictions = gr.Checkbox(label="Show predictions", value=True)
|
| 247 |
+
|
| 248 |
#gr.Markdown(''.join(open('kernel_examples.md', 'r').readlines()))
|
| 249 |
|
| 250 |
with gr.Tab("Export"):
|
|
|
|
| 275 |
kernel_spec.submit(
|
| 276 |
fn=self.update_kernel_spec,
|
| 277 |
inputs=kernel_spec,
|
| 278 |
+
outputs=[self.canvas],
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
# plot options
|
| 282 |
+
show_training_data.change(
|
| 283 |
+
fn=lambda show: self.update_plot_options(show_training_data=show),
|
| 284 |
+
inputs=show_training_data,
|
| 285 |
+
outputs=[self.canvas],
|
| 286 |
+
)
|
| 287 |
+
show_confidence_interval.change(
|
| 288 |
+
fn=lambda show: self.update_plot_options(show_confidence_interval=show),
|
| 289 |
+
inputs=show_confidence_interval,
|
| 290 |
+
outputs=[self.canvas],
|
| 291 |
+
)
|
| 292 |
+
show_true_function.change(
|
| 293 |
+
fn=lambda show: self.update_plot_options(show_true_function=show),
|
| 294 |
+
inputs=show_true_function,
|
| 295 |
+
outputs=[self.canvas],
|
| 296 |
+
)
|
| 297 |
+
show_predictions.change(
|
| 298 |
+
fn=lambda show: self.update_plot_options(show_predictions=show),
|
| 299 |
+
inputs=show_predictions,
|
| 300 |
+
outputs=[self.canvas],
|
| 301 |
)
|
| 302 |
|
| 303 |
demo.launch()
|
| 304 |
|
| 305 |
visualizer = GPVisualizer(width=1200, height=900)
|
| 306 |
visualizer.launch()
|
| 307 |
+
|