import ast
import io
import logging
import time
from dataclasses import dataclass, fields, replace

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from PIL import Image
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import (
    RBF,
    ConstantKernel,
    DotProduct,
    ExpSineSquared,
    Kernel,
    Matern,
    RationalQuadratic,
    WhiteKernel,
)

logging.basicConfig(
    level=logging.INFO,  # minimum level to capture (DEBUG, INFO, WARNING, ERROR, CRITICAL)
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger("ELVIS")

# NOTE(review): this project import is kept after logging.basicConfig, matching the
# original ordering, in case `dataset` logs or configures logging at import time.
from dataset import Dataset, DatasetView, get_function  # noqa: E402

# z-score of the two-sided 95% interval of a standard normal distribution.
_Z_95 = 1.96

# Kernel constructors that a user-supplied kernel expression may call.
_KERNEL_CONSTRUCTORS = {
    "RBF": RBF,
    "Matern": Matern,
    "RationalQuadratic": RationalQuadratic,
    "ExpSineSquared": ExpSineSquared,
    "DotProduct": DotProduct,
    "WhiteKernel": WhiteKernel,
    "ConstantKernel": ConstantKernel,
}

# AST node types permitted in a kernel expression: calls of the allowed constructors,
# literals, keyword arguments, arithmetic, and tuples/lists (e.g. for
# ``length_scale_bounds=(1e-2, 1e2)``). Attribute access, subscripts, lambdas,
# comprehensions, etc. are deliberately absent.
_ALLOWED_KERNEL_NODES = (
    ast.Expression,
    ast.Call,
    ast.Name,
    ast.Load,
    ast.Constant,
    ast.keyword,
    ast.BinOp,
    ast.UnaryOp,
    ast.Tuple,
    ast.List,
    ast.Add,
    ast.Sub,
    ast.Mult,
    ast.Div,
    ast.Pow,
    ast.UAdd,
    ast.USub,
)


@dataclass(frozen=True)
class PlotOptions:
    """Toggles controlling which layers ``GpVisualizer.plot`` draws.

    Being a frozen dataclass, instances are immutable and hashable; the
    dataclass-generated ``__hash__`` covers every field.
    """

    show_training_data: bool = True
    show_true_function: bool = True
    show_mean_prediction: bool = True
    show_prediction_interval: bool = True

    def update(self, **kwargs) -> "PlotOptions":
        """Return a copy with the given fields replaced.

        Keyword arguments that are not fields of ``PlotOptions`` are ignored.
        """
        known = {f.name for f in fields(self)}
        return replace(self, **{k: v for k, v in kwargs.items() if k in known})


def _check_kernel_ast(tree: ast.AST) -> None:
    """Raise ValueError if *tree* uses anything beyond the kernel mini-language."""
    for node in ast.walk(tree):
        if not isinstance(node, _ALLOWED_KERNEL_NODES):
            raise ValueError(f"Unsupported syntax in kernel: {type(node).__name__}")
        if isinstance(node, ast.Name) and node.id not in _KERNEL_CONSTRUCTORS:
            raise ValueError(f"Unknown name in kernel: {node.id!r}")
        if isinstance(node, ast.Call) and not isinstance(node.func, ast.Name):
            raise ValueError("Only the allowed kernel constructors may be called")
        if isinstance(node, ast.keyword) and node.arg is None:
            raise ValueError("'**' argument unpacking is not supported in kernels")
        if isinstance(node, ast.Constant) and not (
            node.value is None or isinstance(node.value, (int, float, str))
        ):
            raise ValueError(f"Unsupported literal in kernel: {node.value!r}")


def eval_kernel(kernel_str: str) -> Kernel:
    """Parse a user-supplied kernel expression into a scikit-learn kernel.

    The expression may only combine the constructors in ``_KERNEL_CONSTRUCTORS``
    with numeric/string literals, tuples/lists, keyword arguments and the
    operators ``+ - * / **``, e.g. ``"1.0 * RBF(1.0) + WhiteKernel()"``.

    Args:
        kernel_str: The kernel expression typed by the user.

    Returns:
        The constructed kernel.

    Raises:
        ValueError: If the string is not valid Python syntax, uses disallowed
            syntax or names, fails to evaluate, or does not evaluate to a Kernel.
    """
    try:
        tree = ast.parse(kernel_str, mode="eval")
    except SyntaxError as e:
        raise ValueError(f"Invalid syntax: {e}") from e

    # The string comes straight from a web textbox, so it is untrusted. Removing
    # builtins alone is not a sandbox: attribute chains such as
    # ``RBF.__init__.__globals__`` reach the real builtins. Whitelisting the AST
    # restricts evaluation to plain constructor calls, literals and arithmetic.
    # NOTE(review): literal arithmetic (e.g. huge ``**`` exponents) is still
    # evaluated and could burn CPU/memory; bound it if the app is public.
    _check_kernel_ast(tree)

    try:
        result = eval(  # noqa: S307 - input is AST-whitelisted above
            compile(tree, "<kernel>", "eval"),
            {"__builtins__": {}},  # no access to Python builtins like open
            dict(_KERNEL_CONSTRUCTORS),  # only the allowed constructors are in scope
        )
    except Exception as e:
        raise ValueError(f"Error evaluating kernel: {e}") from e

    # e.g. a bare "RBF" (the class, not an instance) or "1.0" parses fine but is
    # not usable as a GP kernel.
    if not isinstance(result, Kernel):
        raise ValueError(f"Expression does not evaluate to a kernel: {kernel_str!r}")
    return result


@dataclass
class ModelState:
    """A GP model together with the settings it was built from."""

    model: GaussianProcessRegressor
    kernel: str  # kernel expression as entered by the user
    distribution: str  # "prior" or "posterior" (lower-case, see init_model)

    def __hash__(self):
        # Hash only the configuration; the model object is not part of the key.
        return hash((self.kernel, self.distribution))


class GpVisualizer:
    """Gradio app for interactively exploring Gaussian process regression."""

    def __init__(self, width, height):
        # NOTE(review): canvas_width/canvas_height are stored but plot() renders a
        # fixed 8x8-inch figure; confirm whether they should drive the figure size.
        self.canvas_width = width
        self.canvas_height = height
        self.plot_cmap = plt.get_cmap("tab20")
        self.css = """
.hidden-button {
    display: none;
}"""

    def _predict_curve(self, dataset: Dataset, model_state: ModelState):
        """Return ``(x_test, y_test, y_pred, y_std)`` for the curves to draw.

        In "generate" mode the grid and ``y_test`` come from the dataset's true
        function on [-2, 2] (100 points). Otherwise, if there is training data, a
        100-point grid spans the training inputs padded by 1 on each side and
        ``y_test`` is None. With neither, all four values are None.
        """
        if dataset.mode == "generate":
            x_test, y_test = get_function(dataset.function, xlim=(-2, 2), nsample=100)
        elif dataset.x.shape[0] > 0:
            x_test = np.linspace(dataset.x.min() - 1, dataset.x.max() + 1, 100).reshape(-1, 1)
            y_test = None
        else:
            return None, None, None, None
        y_pred, y_std = model_state.model.predict(x_test, return_std=True)
        return x_test, y_test, y_pred, y_std

    def plot(
        self,
        dataset: Dataset,
        model_state: ModelState,
        plot_options: PlotOptions,
        sample_y: bool = False,
        sample_y_seed: int = 0,
    ) -> Image.Image:
        """Render the dataset and the GP fit as a PIL image (PNG-encoded).

        Args:
            dataset: Training data (``x``, ``y``) plus its mode/true function.
            model_state: GP model used for the mean, interval and sample curves.
            plot_options: Which layers to draw.
            sample_y: If True, overlay one function sampled from the model.
            sample_y_seed: ``random_state`` passed to ``model.sample_y``.

        Returns:
            The rendered figure.
        """
        logger.info("Plotting")
        t_start = time.perf_counter()

        x_train = dataset.x
        y_train = dataset.y
        x_test, y_test, y_pred, y_std = self._predict_curve(dataset, model_state)

        # A standalone Figure instead of pyplot: Gradio runs callbacks on worker
        # threads, and pyplot's global "current figure" could mix concurrent
        # renders. A bare Figure is not registered with pyplot, so nothing leaks.
        fig = Figure(figsize=(8, 8))
        ax = fig.subplots()
        ax.set_title("")
        ax.set_xlabel("x")
        ax.set_ylabel("y")

        # Fit the y-range to the reference values and the 95% band, plus a margin.
        reference = y_test if y_test is not None else y_train
        if y_pred is not None and reference.shape[0] > 0:
            lower = y_pred - _Z_95 * y_std
            upper = y_pred + _Z_95 * y_std
            ax.set_ylim(
                min(reference.min(), lower.min()) - 1,
                max(reference.max(), upper.max()) + 1,
            )

        if plot_options.show_training_data:
            ax.scatter(
                x_train.flatten(),
                y_train,
                label="training data",
                color=self.plot_cmap(0),
            )
        if plot_options.show_true_function and x_test is not None and y_test is not None:
            ax.plot(
                x_test.flatten(),
                y_test,
                label="true function",
                color=self.plot_cmap(1),
            )
        if plot_options.show_mean_prediction and x_test is not None and y_pred is not None:
            ax.plot(
                x_test.flatten(),
                y_pred,
                linestyle="--",
                label="mean prediction",
                color=self.plot_cmap(2),
            )
        if plot_options.show_prediction_interval and x_test is not None and y_std is not None:
            ax.fill_between(
                x_test.flatten(),
                y_pred - _Z_95 * y_std,
                y_pred + _Z_95 * y_std,
                color=self.plot_cmap(3),
                alpha=0.2,
                label="95% prediction interval",
            )
        if sample_y and x_test is not None:
            y_sample = model_state.model.sample_y(x_test, random_state=sample_y_seed).flatten()
            ax.plot(
                x_test.flatten(),
                y_sample,
                linestyle=":",
                label="model sample",
                color=self.plot_cmap(4),
            )

        # Only draw a legend if something labelled was plotted (avoids matplotlib's
        # "No artists with labels found" warning on an empty canvas).
        if ax.get_legend_handles_labels()[0]:
            ax.legend()

        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
        buf.seek(0)
        img = Image.open(buf)

        logger.info("Plotting took %.4f seconds", time.perf_counter() - t_start)
        return img

    def init_model(
        self,
        kernel: str,
        dataset: Dataset,
        distribution: str,
    ) -> GaussianProcessRegressor:
        """Build a GP with the given kernel expression.

        For ``distribution == "posterior"`` the model is fitted on the dataset when
        it is non-empty; for ``"prior"`` it is left unfitted.

        Raises:
            ValueError: For an invalid kernel expression or an unknown distribution.
        """
        model = GaussianProcessRegressor(kernel=eval_kernel(kernel))
        if distribution == "posterior":
            if dataset.x.shape[0] > 0:
                model.fit(dataset.x, dataset.y)
        elif distribution != "prior":
            raise ValueError(f"Unknown distribution: {distribution}")
        return model

    def update_dataset(
        self,
        dataset: Dataset,
        model_state: ModelState,
        plot_options: PlotOptions,
    ) -> tuple[ModelState, Image.Image]:
        """Rebuild the model for a changed dataset and re-render the canvas."""
        logger.info("updating dataset")
        model = self.init_model(
            model_state.kernel,
            dataset,
            model_state.distribution,
        )
        model_state = ModelState(
            model=model,
            kernel=model_state.kernel,
            distribution=model_state.distribution,
        )
        new_canvas = self.plot(dataset, model_state, plot_options)
        return model_state, new_canvas

    def update_model(
        self,
        kernel_str: str,
        distribution: str,
        model_state: ModelState,
        dataset: Dataset,
        plot_options: PlotOptions,
    ) -> tuple[ModelState, Image.Image]:
        """Rebuild the model from a new kernel/distribution and re-render.

        On failure (e.g. an invalid kernel expression) the error is shown to the
        user and the previous ``model_state`` is kept.
        """
        logger.info("updating kernel")
        try:
            model = self.init_model(
                kernel_str,
                dataset,
                distribution.lower(),
            )
            model_state = ModelState(
                model=model,
                kernel=kernel_str,
                distribution=distribution.lower(),
            )
        except Exception as e:  # UI boundary: report and keep the previous model
            logger.error("Error updating kernel: %s", e)
            gr.Info(f" ⚠️ Error updating kernel: {e}")
        new_canvas = self.plot(dataset, model_state, plot_options)
        return model_state, new_canvas

    def sample(
        self,
        model_state: ModelState,
        dataset: Dataset,
        plot_options: PlotOptions,
    ) -> Image.Image:
        """Re-render with one function sampled from the model overlaid."""
        logger.info("sampling from model")
        # Time-derived seed so repeated clicks draw different samples.
        seed = int(time.time() * 100) % 10000
        return self.plot(
            dataset,
            model_state,
            plot_options,
            sample_y=True,
            sample_y_seed=seed,
        )

    def clear_sample(
        self,
        model_state: ModelState,
        dataset: Dataset,
        plot_options: PlotOptions,
    ) -> Image.Image:
        """Re-render without a model sample."""
        logger.info("clearing sample from model")
        return self.plot(
            dataset,
            model_state,
            plot_options,
            sample_y=False,
        )

    def launch(self):
        """Build the Gradio Blocks interface and start it with ``demo.launch()``."""
        with gr.Blocks(css=self.css) as demo:
            # app title
            gr.HTML("<h1>Gaussian Process Visualizer</h1>")

            # states
            dataset = gr.State(Dataset())
            plot_options = gr.State(PlotOptions())
            kernel = "RBF() + WhiteKernel()"
            model = self.init_model(kernel, dataset.value, "posterior")
            model_state = gr.State(
                ModelState(model=model, kernel=kernel, distribution="posterior")
            )

            # GUI elements and layout
            with gr.Row():
                with gr.Column(scale=2):
                    canvas = gr.Image(
                        value=self.plot(
                            dataset.value,
                            model_state.value,
                            plot_options.value,
                        ),
                        container=True,
                    )

                with gr.Column(scale=1):
                    with gr.Tab("Dataset"):
                        dataset_view = DatasetView()
                        dataset_view.build(state=dataset)
                        dataset.change(
                            fn=self.update_dataset,
                            inputs=[dataset, model_state, plot_options],
                            outputs=[model_state, canvas],
                        )

                    with gr.Tab("Model"):
                        kernel_box = gr.Textbox(
                            label="Kernel",
                            value=model_state.value.kernel,
                            interactive=True,
                        )
                        kernel_submit = gr.Button("Update Kernel")
                        distribution = gr.Radio(
                            label="Distribution",
                            choices=["Prior", "Posterior"],
                            value="Posterior",
                        )
                        # Submitting the textbox, clicking the button and switching
                        # the distribution all rebuild the model the same way.
                        model_inputs = [kernel_box, distribution, model_state, dataset, plot_options]
                        for trigger in (kernel_box.submit, kernel_submit.click, distribution.change):
                            trigger(
                                fn=self.update_model,
                                inputs=model_inputs,
                                outputs=[model_state, canvas],
                            )

                        sample_button = gr.Button("Sample")
                        clear_sample_button = gr.Button("Clear Sample")
                        sample_button.click(
                            fn=self.sample,
                            inputs=[model_state, dataset, plot_options],
                            outputs=[canvas],
                        )
                        clear_sample_button.click(
                            fn=self.clear_sample,
                            inputs=[model_state, dataset, plot_options],
                            outputs=[canvas],
                        )

                    with gr.Tab("Plot Options"):
                        # Checkbox defaults mirror the initial PlotOptions state.
                        defaults = plot_options.value
                        show_training_data = gr.Checkbox(
                            label="Show Training Data",
                            value=defaults.show_training_data,
                        )
                        show_true_function = gr.Checkbox(
                            label="Show True Function",
                            value=defaults.show_true_function,
                        )
                        show_mean_prediction = gr.Checkbox(
                            label="Show Mean Prediction",
                            value=defaults.show_mean_prediction,
                        )
                        show_prediction_interval = gr.Checkbox(
                            label="Show Prediction Interval",
                            value=defaults.show_prediction_interval,
                        )

                        show_training_data.change(
                            fn=lambda val, options: options.update(show_training_data=val),
                            inputs=[show_training_data, plot_options],
                            outputs=[plot_options],
                        )
                        show_true_function.change(
                            fn=lambda val, options: options.update(show_true_function=val),
                            inputs=[show_true_function, plot_options],
                            outputs=[plot_options],
                        )
                        show_mean_prediction.change(
                            fn=lambda val, options: options.update(show_mean_prediction=val),
                            inputs=[show_mean_prediction, plot_options],
                            outputs=[plot_options],
                        )
                        show_prediction_interval.change(
                            fn=lambda val, options: options.update(show_prediction_interval=val),
                            inputs=[show_prediction_interval, plot_options],
                            outputs=[plot_options],
                        )

                        plot_options.change(
                            fn=self.plot,
                            inputs=[dataset, model_state, plot_options],
                            outputs=[canvas],
                        )

                    with gr.Tab("Usage"):
                        with open("usage.md", "r", encoding="utf-8") as f:
                            usage_md = f.read()
                        gr.Markdown(usage_md)

        demo.launch()


visualizer = GpVisualizer(width=1200, height=900)
visualizer.launch()