Spaces:
Sleeping
Sleeping
Commit ·
eaae1a3
1
Parent(s): bd528c1
Use more informative labels
Browse files- gp_visualizer.py +27 -32
gp_visualizer.py
CHANGED
|
@@ -34,14 +34,14 @@ class PlotOptions:
|
|
| 34 |
show_training_data: bool = True
|
| 35 |
show_true_function: bool = True
|
| 36 |
show_mean_prediction: bool = True
|
| 37 |
-
|
| 38 |
|
| 39 |
def update(self, **kwargs):
|
| 40 |
return PlotOptions(
|
| 41 |
show_training_data=kwargs.get("show_training_data", self.show_training_data),
|
| 42 |
show_true_function=kwargs.get("show_true_function", self.show_true_function),
|
| 43 |
show_mean_prediction=kwargs.get("show_mean_prediction", self.show_mean_prediction),
|
| 44 |
-
|
| 45 |
)
|
| 46 |
|
| 47 |
def __hash__(self):
|
|
@@ -50,7 +50,7 @@ class PlotOptions:
|
|
| 50 |
self.show_training_data,
|
| 51 |
self.show_true_function,
|
| 52 |
self.show_mean_prediction,
|
| 53 |
-
self.
|
| 54 |
)
|
| 55 |
)
|
| 56 |
|
|
@@ -90,13 +90,13 @@ def eval_kernel(kernel_str) -> Kernel:
|
|
| 90 |
class ModelState:
|
| 91 |
model: GaussianProcessRegressor
|
| 92 |
kernel: str
|
| 93 |
-
|
| 94 |
|
| 95 |
def __hash__(self):
|
| 96 |
return hash(
|
| 97 |
(
|
| 98 |
self.kernel,
|
| 99 |
-
self.
|
| 100 |
)
|
| 101 |
)
|
| 102 |
|
|
@@ -181,17 +181,17 @@ class GpVisualizer:
|
|
| 181 |
x_test.flatten(),
|
| 182 |
y_pred,
|
| 183 |
linestyle="--",
|
| 184 |
-
label='prediction',
|
| 185 |
color=self.plot_cmap(2),
|
| 186 |
)
|
| 187 |
-
if plot_options.
|
| 188 |
plt.fill_between(
|
| 189 |
x_test.flatten(),
|
| 190 |
y_pred - 1.96 * y_std,
|
| 191 |
y_pred + 1.96 * y_std,
|
| 192 |
color=self.plot_cmap(3),
|
| 193 |
alpha=0.2,
|
| 194 |
-
label='95%
|
| 195 |
)
|
| 196 |
|
| 197 |
if x_test is not None and sample_y:
|
|
@@ -225,14 +225,14 @@ class GpVisualizer:
|
|
| 225 |
self,
|
| 226 |
kernel: str,
|
| 227 |
dataset: Dataset,
|
| 228 |
-
|
| 229 |
) -> GaussianProcessRegressor:
|
| 230 |
model = GaussianProcessRegressor(kernel=eval_kernel(kernel))
|
| 231 |
-
if
|
| 232 |
if dataset.x.shape[0] > 0:
|
| 233 |
model.fit(dataset.x, dataset.y)
|
| 234 |
-
elif
|
| 235 |
-
raise ValueError(f"Unknown
|
| 236 |
|
| 237 |
return model
|
| 238 |
|
|
@@ -246,10 +246,10 @@ class GpVisualizer:
|
|
| 246 |
model = self.init_model(
|
| 247 |
model_state.kernel,
|
| 248 |
dataset,
|
| 249 |
-
model_state.
|
| 250 |
)
|
| 251 |
model_state = ModelState(
|
| 252 |
-
model=model, kernel=model_state.kernel,
|
| 253 |
)
|
| 254 |
|
| 255 |
new_canvas = self.plot(dataset, model_state, plot_options)
|
|
@@ -259,7 +259,7 @@ class GpVisualizer:
|
|
| 259 |
def update_model(
|
| 260 |
self,
|
| 261 |
kernel_str: str,
|
| 262 |
-
|
| 263 |
model_state: ModelState,
|
| 264 |
dataset: Dataset,
|
| 265 |
plot_options: PlotOptions,
|
|
@@ -269,10 +269,10 @@ class GpVisualizer:
|
|
| 269 |
model = self.init_model(
|
| 270 |
kernel_str,
|
| 271 |
dataset,
|
| 272 |
-
|
| 273 |
)
|
| 274 |
model_state = ModelState(
|
| 275 |
-
model=model, kernel=kernel_str,
|
| 276 |
)
|
| 277 |
except Exception as e:
|
| 278 |
logger.error(f"Error updating kernel: {e}")
|
|
@@ -331,7 +331,7 @@ class GpVisualizer:
|
|
| 331 |
kernel = "RBF(length_scale=1.0) + WhiteKernel(noise_level=1.0)"
|
| 332 |
model = self.init_model(kernel, dataset.value, "posterior")
|
| 333 |
model_state = gr.State(
|
| 334 |
-
ModelState(model=model, kernel=kernel,
|
| 335 |
)
|
| 336 |
|
| 337 |
# GUI elements and layout
|
|
@@ -358,29 +358,24 @@ class GpVisualizer:
|
|
| 358 |
)
|
| 359 |
|
| 360 |
with gr.Tab("Model"):
|
| 361 |
-
gr.Textbox(
|
| 362 |
-
label="Prior Mean",
|
| 363 |
-
value="0",
|
| 364 |
-
interactive=False,
|
| 365 |
-
)
|
| 366 |
kernel_box = gr.Textbox(
|
| 367 |
label="Kernel",
|
| 368 |
value=model_state.value.kernel,
|
| 369 |
interactive=True,
|
| 370 |
)
|
| 371 |
-
|
| 372 |
-
label="
|
| 373 |
choices=["Prior", "Posterior"],
|
| 374 |
value="Posterior",
|
| 375 |
)
|
| 376 |
kernel_box.submit(
|
| 377 |
fn=self.update_model,
|
| 378 |
-
inputs=[kernel_box,
|
| 379 |
outputs=[model_state, canvas],
|
| 380 |
)
|
| 381 |
-
|
| 382 |
fn=self.update_model,
|
| 383 |
-
inputs=[kernel_box,
|
| 384 |
outputs=[model_state, canvas],
|
| 385 |
)
|
| 386 |
|
|
@@ -410,7 +405,7 @@ class GpVisualizer:
|
|
| 410 |
label="Show Mean Prediction",
|
| 411 |
value=True,
|
| 412 |
)
|
| 413 |
-
|
| 414 |
label="Show Confidence Interval",
|
| 415 |
value=True,
|
| 416 |
)
|
|
@@ -429,9 +424,9 @@ class GpVisualizer:
|
|
| 429 |
inputs=[show_mean_prediction, plot_options],
|
| 430 |
outputs=[plot_options],
|
| 431 |
)
|
| 432 |
-
|
| 433 |
-
fn=lambda val, options: options.update(
|
| 434 |
-
inputs=[
|
| 435 |
outputs=[plot_options],
|
| 436 |
)
|
| 437 |
plot_options.change(
|
|
|
|
| 34 |
show_training_data: bool = True
|
| 35 |
show_true_function: bool = True
|
| 36 |
show_mean_prediction: bool = True
|
| 37 |
+
show_prediction_interval: bool = True
|
| 38 |
|
| 39 |
def update(self, **kwargs):
|
| 40 |
return PlotOptions(
|
| 41 |
show_training_data=kwargs.get("show_training_data", self.show_training_data),
|
| 42 |
show_true_function=kwargs.get("show_true_function", self.show_true_function),
|
| 43 |
show_mean_prediction=kwargs.get("show_mean_prediction", self.show_mean_prediction),
|
| 44 |
+
show_prediction_interval=kwargs.get("show_prediction_interval", self.show_prediction_interval),
|
| 45 |
)
|
| 46 |
|
| 47 |
def __hash__(self):
|
|
|
|
| 50 |
self.show_training_data,
|
| 51 |
self.show_true_function,
|
| 52 |
self.show_mean_prediction,
|
| 53 |
+
self.show_prediction_interval,
|
| 54 |
)
|
| 55 |
)
|
| 56 |
|
|
|
|
| 90 |
class ModelState:
|
| 91 |
model: GaussianProcessRegressor
|
| 92 |
kernel: str
|
| 93 |
+
distribution: str
|
| 94 |
|
| 95 |
def __hash__(self):
|
| 96 |
return hash(
|
| 97 |
(
|
| 98 |
self.kernel,
|
| 99 |
+
self.distribution,
|
| 100 |
)
|
| 101 |
)
|
| 102 |
|
|
|
|
| 181 |
x_test.flatten(),
|
| 182 |
y_pred,
|
| 183 |
linestyle="--",
|
| 184 |
+
label='mean prediction',
|
| 185 |
color=self.plot_cmap(2),
|
| 186 |
)
|
| 187 |
+
if plot_options.show_prediction_interval and x_test is not None and y_std is not None:
|
| 188 |
plt.fill_between(
|
| 189 |
x_test.flatten(),
|
| 190 |
y_pred - 1.96 * y_std,
|
| 191 |
y_pred + 1.96 * y_std,
|
| 192 |
color=self.plot_cmap(3),
|
| 193 |
alpha=0.2,
|
| 194 |
+
label='95% prediction interval',
|
| 195 |
)
|
| 196 |
|
| 197 |
if x_test is not None and sample_y:
|
|
|
|
| 225 |
self,
|
| 226 |
kernel: str,
|
| 227 |
dataset: Dataset,
|
| 228 |
+
distribution: str,
|
| 229 |
) -> GaussianProcessRegressor:
|
| 230 |
model = GaussianProcessRegressor(kernel=eval_kernel(kernel))
|
| 231 |
+
if distribution == "posterior":
|
| 232 |
if dataset.x.shape[0] > 0:
|
| 233 |
model.fit(dataset.x, dataset.y)
|
| 234 |
+
elif distribution != "prior":
|
| 235 |
+
raise ValueError(f"Unknown distribution: {distribution}")
|
| 236 |
|
| 237 |
return model
|
| 238 |
|
|
|
|
| 246 |
model = self.init_model(
|
| 247 |
model_state.kernel,
|
| 248 |
dataset,
|
| 249 |
+
model_state.distribution,
|
| 250 |
)
|
| 251 |
model_state = ModelState(
|
| 252 |
+
model=model, kernel=model_state.kernel, distribution=model_state.distribution
|
| 253 |
)
|
| 254 |
|
| 255 |
new_canvas = self.plot(dataset, model_state, plot_options)
|
|
|
|
| 259 |
def update_model(
|
| 260 |
self,
|
| 261 |
kernel_str: str,
|
| 262 |
+
distribution: str,
|
| 263 |
model_state: ModelState,
|
| 264 |
dataset: Dataset,
|
| 265 |
plot_options: PlotOptions,
|
|
|
|
| 269 |
model = self.init_model(
|
| 270 |
kernel_str,
|
| 271 |
dataset,
|
| 272 |
+
distribution.lower(),
|
| 273 |
)
|
| 274 |
model_state = ModelState(
|
| 275 |
+
model=model, kernel=kernel_str, distribution=distribution.lower()
|
| 276 |
)
|
| 277 |
except Exception as e:
|
| 278 |
logger.error(f"Error updating kernel: {e}")
|
|
|
|
| 331 |
kernel = "RBF(length_scale=1.0) + WhiteKernel(noise_level=1.0)"
|
| 332 |
model = self.init_model(kernel, dataset.value, "posterior")
|
| 333 |
model_state = gr.State(
|
| 334 |
+
ModelState(model=model, kernel=kernel, distribution="posterior")
|
| 335 |
)
|
| 336 |
|
| 337 |
# GUI elements and layout
|
|
|
|
| 358 |
)
|
| 359 |
|
| 360 |
with gr.Tab("Model"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
kernel_box = gr.Textbox(
|
| 362 |
label="Kernel",
|
| 363 |
value=model_state.value.kernel,
|
| 364 |
interactive=True,
|
| 365 |
)
|
| 366 |
+
distribution = gr.Radio(
|
| 367 |
+
label="Distribution",
|
| 368 |
choices=["Prior", "Posterior"],
|
| 369 |
value="Posterior",
|
| 370 |
)
|
| 371 |
kernel_box.submit(
|
| 372 |
fn=self.update_model,
|
| 373 |
+
inputs=[kernel_box, distribution, model_state, dataset, plot_options],
|
| 374 |
outputs=[model_state, canvas],
|
| 375 |
)
|
| 376 |
+
distribution.change(
|
| 377 |
fn=self.update_model,
|
| 378 |
+
inputs=[kernel_box, distribution, model_state, dataset, plot_options],
|
| 379 |
outputs=[model_state, canvas],
|
| 380 |
)
|
| 381 |
|
|
|
|
| 405 |
label="Show Mean Prediction",
|
| 406 |
value=True,
|
| 407 |
)
|
| 408 |
+
show_prediction_interval = gr.Checkbox(
|
| 409 |
label="Show Confidence Interval",
|
| 410 |
value=True,
|
| 411 |
)
|
|
|
|
| 424 |
inputs=[show_mean_prediction, plot_options],
|
| 425 |
outputs=[plot_options],
|
| 426 |
)
|
| 427 |
+
show_prediction_interval.change(
|
| 428 |
+
fn=lambda val, options: options.update(show_prediction_interval=val),
|
| 429 |
+
inputs=[show_prediction_interval, plot_options],
|
| 430 |
outputs=[plot_options],
|
| 431 |
)
|
| 432 |
plot_options.change(
|