Spaces:
Sleeping
Sleeping
| from collections import deque | |
| from pathlib import Path | |
| import pickle | |
| import gradio as gr | |
| import inspect | |
| import io | |
| from jinja2 import Template | |
| import matplotlib.pyplot as plt | |
| import matplotlib.lines as mlines | |
| import numpy as np | |
| import numexpr | |
| import pandas as pd | |
| from PIL import Image | |
| import plotly.graph_objects as go | |
| import sklearn | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.svm import LinearSVC | |
| from sklearn.datasets import load_iris | |
| from sklearn.metrics import classification_report, mean_squared_error, mean_absolute_error | |
| from sklearn.datasets import make_regression | |
| from sklearn.linear_model import ElasticNet | |
| import ast | |
| import traceback | |
| import yaml | |
| from sklearn.preprocessing import StandardScaler, MinMaxScaler, MaxAbsScaler, Normalizer | |
| from sklearn.gaussian_process import GaussianProcessRegressor, GaussianProcessClassifier | |
| from sklearn.gaussian_process.kernels import DotProduct, WhiteKernel, ConstantKernel, RBF, Matern, RationalQuadratic, ExpSineSquared | |
| import logging | |
# Root logging: INFO and above, timestamped one-line records.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("ELVIS")

# Extra names (besides `x`) that user-entered numexpr expressions may use.
NUMEXPR_CONSTANTS = dict(pi=np.pi, PI=np.pi, e=np.e)
# Kernel constructors a user-supplied kernel expression may reference.
# Keep in sync with the namespace built inside eval_kernel.
_KERNEL_NAMES = frozenset({
    'RBF', 'Matern', 'RationalQuadratic', 'ExpSineSquared',
    'DotProduct', 'WhiteKernel', 'ConstantKernel',
})
# Arithmetic permitted in kernel expressions (kernels combine with +, *, **).
_KERNEL_BINOPS = (ast.Add, ast.Sub, ast.Mult, ast.Div, ast.Pow)
_KERNEL_UNARYOPS = (ast.UAdd, ast.USub)


def _check_kernel_node(node, operand=False):
    """Raise ValueError unless *node* uses only the whitelisted kernel grammar.

    Allowed: calls to the names in ``_KERNEL_NAMES`` (no ``*``/``**``
    unpacking), numeric/str/bool/None constants, tuples and lists as call
    arguments, ``+ - * / **`` and unary ``+``/``-``.

    Args:
        node: AST node to check.
        operand: True when *node* is an operand of an arithmetic operator.
            In that position only numbers and kernel expressions are
            accepted, so e.g. ``"a" * 10**12`` or ``(0,) * 10**12`` cannot
            allocate unbounded memory.
    """
    if isinstance(node, ast.Constant):
        value = node.value
        if operand:
            ok = isinstance(value, (int, float)) and not isinstance(value, bool)
        else:
            ok = value is None or isinstance(value, (bool, int, float, str))
        if not ok:
            raise ValueError(f"Unsupported constant in kernel: {value!r}")
    elif isinstance(node, ast.Name):
        if node.id not in _KERNEL_NAMES:
            raise ValueError(f"Unknown name in kernel: {node.id}")
    elif isinstance(node, ast.Call):
        if not (isinstance(node.func, ast.Name) and node.func.id in _KERNEL_NAMES):
            raise ValueError("Only kernel constructors may be called in a kernel")
        for arg in node.args:
            # ast.Starred (``*args``) falls through to the rejecting branch.
            _check_kernel_node(arg)
        for kw in node.keywords:
            if kw.arg is None:
                raise ValueError("'**' argument unpacking is not allowed in a kernel")
            _check_kernel_node(kw.value)
    elif isinstance(node, ast.BinOp) and isinstance(node.op, _KERNEL_BINOPS):
        # Only allow '**' when the base contains a kernel call; this blocks
        # integer power towers such as 9 ** 9 ** 9 that would hang the server.
        if isinstance(node.op, ast.Pow) and not any(
            isinstance(sub, ast.Call) for sub in ast.walk(node.left)
        ):
            raise ValueError("'**' is only allowed with a kernel on the left")
        _check_kernel_node(node.left, operand=True)
        _check_kernel_node(node.right, operand=True)
    elif isinstance(node, ast.UnaryOp) and isinstance(node.op, _KERNEL_UNARYOPS):
        _check_kernel_node(node.operand, operand=True)
    elif isinstance(node, (ast.Tuple, ast.List)) and not operand:
        for element in node.elts:
            _check_kernel_node(element)
    else:
        raise ValueError(f"Unsupported expression in kernel: {type(node).__name__}")


def eval_kernel(kernel_str):
    """Build a scikit-learn GP kernel from a user-supplied expression string.

    The string is parsed and every AST node is checked against a small
    whitelist (see ``_check_kernel_node``) *before* anything is evaluated.
    A plain ``eval`` with ``__builtins__`` removed is not a sandbox: attribute
    access such as ``RBF.__init__.__globals__`` reaches the real builtins.

    Example: ``"1.0 * RBF(length_scale=0.5) + WhiteKernel(noise_level=0.1)"``.

    Args:
        kernel_str: Python expression combining the allowed kernel constructors.

    Returns:
        Whatever the expression evaluates to (normally a kernel object).

    Raises:
        ValueError: if the string is not valid syntax, uses a disallowed
            construct or name, or raises while being evaluated.
    """
    try:
        tree = ast.parse(kernel_str, mode='eval')
    except (SyntaxError, RecursionError) as e:
        raise ValueError(f"Invalid syntax: {e}") from e
    try:
        _check_kernel_node(tree.body)
    except RecursionError:
        raise ValueError("Kernel expression is too deeply nested") from None
    # Only things in this mapping are reachable from the expression.
    allowed_names = {
        'RBF': RBF,
        'Matern': Matern,
        'RationalQuadratic': RationalQuadratic,
        'ExpSineSquared': ExpSineSquared,
        'DotProduct': DotProduct,
        'WhiteKernel': WhiteKernel,
        'ConstantKernel': ConstantKernel,
    }
    try:
        result = eval(compile(tree, '<string>', 'eval'),
                      {"__builtins__": {}},
                      allowed_names)
    except Exception as e:
        raise ValueError(f"Error evaluating kernel: {e}") from e
    return result
def get_function(function, xlim=(-1, 1), nsample=100):
    """Evaluate a numexpr expression on an evenly spaced grid.

    Args:
        function: numexpr expression in ``x``; it may also use the names in
            ``NUMEXPR_CONSTANTS``.
        xlim: ``(low, high)`` end points of the grid (both included).
        nsample: number of grid points.

    Returns:
        ``(x, y)`` with ``x`` of shape ``(nsample, 1)`` and ``y`` holding one
        value per grid point, shape ``(nsample,)``.
    """
    x = np.linspace(xlim[0], xlim[1], nsample)
    y = numexpr.evaluate(function, local_dict={'x': x, **NUMEXPR_CONSTANTS})
    # An expression that does not use x (e.g. "1") evaluates to a 0-d array;
    # broadcast it so callers can always plot y against x.
    y = np.array(np.broadcast_to(y, x.shape))
    x = x.reshape(-1, 1)
    return x, y
def get_data_points(function, xlim=(-1, 1), nsample=10, sigma=0, seed=0):
    """Draw a reproducible, optionally noisy sample of a numexpr function.

    A fixed pool of 100 x positions is drawn uniformly from *xlim*. The first
    *nsample* are kept and sorted, so for a given seed the points for a
    smaller *nsample* are a subset of those for a larger one.

    Args:
        function: numexpr expression in ``x`` (may use ``NUMEXPR_CONSTANTS``).
        xlim: ``(low, high)`` range for the uniform x positions.
        nsample: number of points to return, 0..100.
        sigma: standard deviation of the additive Gaussian noise on y.
        seed: non-negative integer seed controlling both x and noise.

    Returns:
        ``(x, y)`` with ``x`` of shape ``(nsample, 1)`` and ``y`` of shape
        ``(nsample,)``.

    Raises:
        ValueError: if *nsample* is negative or larger than 100.
    """
    num_points_to_generate = 100
    nsample = int(nsample)
    if nsample > num_points_to_generate:
        raise ValueError(f"nsample too large, limit to {num_points_to_generate}")
    if nsample < 0:
        raise ValueError("nsample must be non-negative")
    # Independent child streams for x and for the noise. Seeding two
    # generators with the same seed would derive the noise from the very
    # same random bits that produced x.
    x_seq, noise_seq = np.random.SeedSequence(seed).spawn(2)
    x = np.random.default_rng(x_seq).uniform(xlim[0], xlim[1], size=num_points_to_generate)
    x = np.sort(x[:nsample])
    noise = sigma * np.random.default_rng(noise_seq).standard_normal(nsample)
    # Adding noise (shape (nsample,)) also broadcasts constant expressions.
    y = numexpr.evaluate(function, local_dict={'x': x, **NUMEXPR_CONSTANTS}) + noise
    x = x.reshape(-1, 1)
    return x, y
def make_sine(xlim=(0, 1), nsample=20, sigma=0.1, uniform=False, sort=True, seed=42):
    """Sample noisy points from ``sin(2*pi*x)``.

    Random numbers come from a private ``np.random.RandomState(seed)``. That
    produces the same values as seeding NumPy's global generator and calling
    ``np.random.rand``/``randn``, but it leaves the global random state untouched.

    Args:
        xlim: ``(low, high)`` range of x.
        nsample: number of points.
        sigma: standard deviation of the Gaussian noise added to y.
        uniform: if True, use evenly spaced x (``np.linspace``); otherwise
            draw x uniformly at random.
        sort: if True, sort x (y is computed after sorting).
        seed: seed for the private random generator.

    Returns:
        ``(X, y)`` with ``X`` of shape ``(nsample, 1)`` and ``y`` of shape
        ``(nsample,)``.
    """
    rng = np.random.RandomState(seed)
    if uniform:
        X = np.linspace(xlim[0], xlim[1], nsample)
    else:
        X = xlim[0] + (xlim[1] - xlim[0]) * rng.rand(nsample)
    if sort:
        X = np.sort(X)
    y = np.sin(2 * np.pi * X) + sigma * rng.randn(nsample)
    X = X.reshape(-1, 1)
    return X, y
class GPVisualizer:
    """Gradio app that fits a Gaussian-process regressor and plots the fit.

    Training data are sampled from a user-entered numexpr expression (see
    ``generate_data``). The GP kernel comes from a user-entered scikit-learn
    kernel expression (see ``update_kernel_spec``).

    NOTE(review): a single instance holds the data, model and plot options
    for every connected browser session, so users share this state and
    ``on_load`` resets it for everyone. Per-session state would need
    ``gr.State``.
    """

    DEFAULT_KERNEL = "RBF() + WhiteKernel()"
    DEFAULT_FUNCTION = "sin(2 * pi * x)"

    def _init_state(self):
        """Reset data options, training data, model, plot options and GP samples."""
        self.data_options = {
            "function": self.DEFAULT_FUNCTION,
            "nsample": 30,
            "sigma": 0,
            "seed": 0,
            "x_min": -1,
            "x_max": 1,
        }
        self.kernel = eval_kernel(self.DEFAULT_KERNEL)
        self.x_train, self.y_train = self.generate_data()
        self.model = self.train_model(self.kernel, self.x_train, self.y_train)
        self.plot_options = {
            "show_training_data": True,
            "show_confidence_interval": True,
            "show_true_function": True,
            "show_predictions": True,
        }
        # How many GP samples to draw, plus the samples drawn so far (cached so
        # adding one sample does not redraw the earlier ones).
        self.num_y_samples = 0
        self._y_samples_cache = []

    def __init__(self, width, height):
        """Create the visualizer.

        Args:
            width: nominal canvas width in pixels.
            height: nominal canvas height in pixels.
        """
        # NOTE(review): stored but not used by plot(), which renders an
        # 8x8-inch figure.
        self.canvas_width = width
        self.canvas_height = height
        self._init_state()
        self.plot_cmap = plt.get_cmap("tab20")
        # Hides the helper DownloadButtons in the Export tab.
        self.css = """
.hidden-button {
display: none;
}"""

    def on_load(self):
        """Reset all state; registered with ``demo.load``."""
        self._init_state()

    def generate_data(self):
        """Sample training data according to ``self.data_options``.

        Returns:
            ``(x, y)`` as produced by ``get_data_points``.
        """
        opts = self.data_options
        return get_data_points(
            opts["function"],
            xlim=(opts["x_min"], opts["x_max"]),
            nsample=opts["nsample"],
            sigma=opts["sigma"],
            seed=opts["seed"],
        )

    def train_model(self, kernel, x_train, y_train):
        """Return a ``GaussianProcessRegressor`` using *kernel*.

        The regressor is fitted only when there is at least one training point;
        otherwise it is returned unfitted.
        """
        gpr = GaussianProcessRegressor(kernel=kernel, random_state=0)
        logger.info("fitting %s", gpr)
        if len(x_train) > 0:
            gpr.fit(x_train, y_train)
        return gpr

    def _plot_xlim(self):
        """Return the x-range over which the function and the GP are drawn.

        The data interval is widened by half its width on each side so the
        plot also shows extrapolation. For the default data range (-1, 1) this
        gives (-2, 2).
        """
        x_min = self.data_options["x_min"]
        x_max = self.data_options["x_max"]
        margin = abs(x_max - x_min) / 2 or 1.0
        return (min(x_min, x_max) - margin, max(x_min, x_max) + margin)

    def plot(self):
        """Render the current state to a PIL image.

        Depending on ``self.plot_options`` the figure shows:

        - the training data, labelled with the model's R2 when there are at
          least two points;
        - the true function;
        - the GP mean prediction;
        - a 95% band (mean +/- 1.96 std);
        - ``self.num_y_samples`` samples drawn from the GP.
        """
        logger.info("Initializing figure")
        x_test, y_test = get_function(self.data_options["function"], xlim=self._plot_xlim(), nsample=100)
        x_flat = x_test.flatten()
        # A constant expression such as "1" can evaluate to a 0-d array;
        # stretch it over the grid so it can be drawn against x.
        y_test = np.broadcast_to(y_test, x_flat.shape)
        y_pred, y_std = self.model.predict(x_test, return_std=True)

        fig, ax = plt.subplots(figsize=(8, 8))
        try:
            ax.set_title("")
            ax.set_xlabel("x")
            ax.set_ylabel("y")
            if self.plot_options["show_training_data"]:
                if len(self.x_train) > 1:
                    r2 = self.model.score(self.x_train, self.y_train)
                    label = 'training data (R2=%.2f)' % r2
                else:
                    label = 'training data'
                ax.scatter(self.x_train.flatten(), self.y_train, label=label, color=self.plot_cmap(0))
            if self.plot_options["show_true_function"]:
                ax.plot(x_flat, y_test, label='true function', color=self.plot_cmap(1))
            if self.plot_options["show_predictions"]:
                ax.plot(x_flat, y_pred, linestyle="--", label='mean prediction', color=self.plot_cmap(2))
            if self.plot_options["show_confidence_interval"]:
                ax.fill_between(
                    x_flat,
                    y_pred - 1.96 * y_std,
                    y_pred + 1.96 * y_std,
                    alpha=0.2,
                    label='95% confidence interval',
                    color=self.plot_cmap(3),
                )
            for i in range(self.num_y_samples):
                if i < len(self._y_samples_cache):
                    y_sample = self._y_samples_cache[i]
                else:
                    # Sample i is drawn with random_state=i and cached.
                    y_sample = self.model.sample_y(x_test, random_state=i).flatten()
                    self._y_samples_cache.append(y_sample)
                ax.plot(x_flat, y_sample, linestyle=":", label=f"sample {i}", color=self.plot_cmap(4))
            # Only add a legend when something labelled was drawn.
            if ax.get_legend_handles_labels()[0]:
                ax.legend()
            buf = io.BytesIO()
            fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
        finally:
            # pyplot keeps every figure alive until closed; always release it.
            plt.close(fig)
        buf.seek(0)
        return Image.open(buf)

    def _update_data_seed(self):
        """Advance the data seed, regenerate the data, refit and replot."""
        self.data_options["seed"] += 1
        self.x_train, self.y_train = self.generate_data()
        self.update_model()
        return self.plot()

    def _reset_y_samples(self):
        """Forget all GP samples without rendering a plot."""
        self.num_y_samples = 0
        self._y_samples_cache.clear()

    def update_model(self):
        """Refit the GP on the current training data and drop cached GP samples."""
        self.model = self.train_model(self.kernel, self.x_train, self.y_train)
        # Old samples belong to the old model.
        self._reset_y_samples()

    def update_data_options(self, **kwargs):
        """Update data options, then regenerate the data, refit and replot.

        Keys not present in ``self.data_options`` are ignored.

        Raises:
            ValueError: if a new ``function`` cannot be evaluated by numexpr.
        """
        for key, value in kwargs.items():
            if key not in self.data_options:
                continue
            if key == "function":
                # Validate the expression on a small grid before accepting it.
                try:
                    numexpr.evaluate(value, local_dict={'x': np.linspace(-1, 1, 10), **NUMEXPR_CONSTANTS})
                except Exception as e:
                    raise ValueError(f"Invalid function: {e}") from e
            self.data_options[key] = value
        # reset data and model
        self.x_train, self.y_train = self.generate_data()
        self.update_model()
        return self.plot()

    def update_kernel_spec(self, kernel_spec):
        """Replace the kernel with *kernel_spec*, refit and replot.

        The new kernel and model are committed only after fitting succeeds.
        A kernel that fails to fit therefore leaves the previous kernel and
        model in place.

        Raises:
            ValueError: from ``eval_kernel`` for an invalid kernel expression.
        """
        kernel = eval_kernel(kernel_spec)
        model = self.train_model(kernel, self.x_train, self.y_train)
        self.kernel, self.model = kernel, model
        self._reset_y_samples()
        return self.plot()

    def update_plot_options(self, **kwargs):
        """Update known plot options (unknown keys are ignored) and replot."""
        for key, value in kwargs.items():
            if key in self.plot_options:
                self.plot_options[key] = value
        return self.plot()

    def add_y_sample(self):
        """Request one more GP sample and replot."""
        self.num_y_samples += 1
        return self.plot()

    def clear_y_samples(self):
        """Remove all GP samples and replot."""
        self._reset_y_samples()
        return self.plot()

    def launch(self):
        """Build the Gradio interface, wire its events and start the server."""
        with gr.Blocks(css=self.css) as demo:
            # app title
            gr.HTML("<div style='text-align:left; font-size:40px; font-weight: bold;'>Gaussian Process Visualizer</div>")
            # GUI elements and layout
            with gr.Row():
                with gr.Column(scale=2):
                    self.canvas = gr.Image(value=self.plot(),
                                           show_download_button=False,
                                           container=True)
                with gr.Column(scale=1):
                    with gr.Tab("Dataset"):
                        # TODO: the "Upload" choice and file_chooser are not
                        # wired to any handler yet.
                        dataset_radio = gr.Radio(
                            ["Generate", "Upload"],
                            value="Generate",
                            label="Dataset",
                        )
                        with gr.Column():
                            function_box = gr.Textbox(
                                label="Function",
                                placeholder="function of x",
                                value=self.DEFAULT_FUNCTION,
                                interactive=True,
                            )
                            with gr.Row():
                                x_min = gr.Number(
                                    label="Min x",
                                    value=-1,
                                    interactive=True,
                                )
                                x_max = gr.Number(
                                    label="Max x",
                                    value=1,
                                    interactive=True,
                                )
                            with gr.Row():
                                noise_value = gr.Number(
                                    label="Gaussian noise standard deviation",
                                    value=0,
                                    interactive=True,
                                )
                                num_points_slider = gr.Slider(
                                    label="Number of data points",
                                    minimum=0,
                                    maximum=100,
                                    step=1,
                                    value=30,
                                    interactive=True,
                                )
                            regenerate_button = gr.Button("Regenerate Data")
                        # upload data
                        file_chooser = gr.File(label="Choose a file", visible=False, elem_id="rowheight")
                        self.file_chooser = file_chooser
                    with gr.Tab("Model"):
                        # kernel spec
                        kernel_spec = gr.Textbox(
                            label="Kernel",
                            placeholder="sklearn kernel code",
                            value=self.DEFAULT_KERNEL,
                            interactive=True,
                        )
                    with gr.Tab("Plot"):
                        # plot show options
                        with gr.Column():
                            with gr.Row():
                                show_training_data = gr.Checkbox(label="Show training data", value=True)
                                show_true_function = gr.Checkbox(label="Show true function", value=True)
                            with gr.Row():
                                show_predictions = gr.Checkbox(label="Show mean prediction", value=True)
                                show_confidence_interval = gr.Checkbox(label="Show confidence interval", value=True)
                            # TODO: optionally render kernel_examples.md here.
                            # sampling from GP
                            sample_button = gr.Button("Sample from GP")
                            clear_samples_button = gr.Button("Clear samples from GP")
                    with gr.Tab("Export"):
                        # use hidden download button to generate files on the fly
                        # https://github.com/gradio-app/gradio/issues/9230#issuecomment-2323771634
                        # TODO: these buttons have no click handlers yet.
                        btn_export_data = gr.Button("Data")
                        btn_export_data_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_data_hidden", elem_classes="hidden-button")
                        btn_export_model = gr.Button('Model')
                        btn_export_model_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_model_hidden", elem_classes="hidden-button")
                        btn_export_code = gr.Button('Code')
                        btn_export_code_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_code_hidden", elem_classes="hidden-button")
                    with gr.Tab("Usage"):
                        # read_text closes the file, unlike a bare open().
                        gr.Markdown(Path('usage.md').read_text(encoding='utf-8'))

            # data options
            function_box.submit(
                fn=lambda function: self.update_data_options(function=function),
                inputs=function_box,
                outputs=[self.canvas],
            )
            x_min.submit(
                fn=lambda xmin: self.update_data_options(x_min=xmin),
                inputs=x_min,
                outputs=[self.canvas],
            )
            x_max.submit(
                fn=lambda xmax: self.update_data_options(x_max=xmax),
                inputs=x_max,
                outputs=[self.canvas],
            )
            num_points_slider.change(
                fn=lambda nsample: self.update_data_options(nsample=nsample),
                inputs=num_points_slider,
                outputs=[self.canvas],
            )
            noise_value.submit(
                fn=lambda sigma: self.update_data_options(sigma=sigma),
                inputs=noise_value,
                outputs=[self.canvas],
            )
            regenerate_button.click(
                fn=self._update_data_seed,
                outputs=[self.canvas],
            )
            # model options
            kernel_spec.submit(
                fn=self.update_kernel_spec,
                inputs=kernel_spec,
                outputs=[self.canvas],
            )
            # plot options
            show_training_data.change(
                fn=lambda show: self.update_plot_options(show_training_data=show),
                inputs=show_training_data,
                outputs=[self.canvas],
            )
            show_confidence_interval.change(
                fn=lambda show: self.update_plot_options(show_confidence_interval=show),
                inputs=show_confidence_interval,
                outputs=[self.canvas],
            )
            show_true_function.change(
                fn=lambda show: self.update_plot_options(show_true_function=show),
                inputs=show_true_function,
                outputs=[self.canvas],
            )
            show_predictions.change(
                fn=lambda show: self.update_plot_options(show_predictions=show),
                inputs=show_predictions,
                outputs=[self.canvas],
            )
            # sampling from GP
            sample_button.click(
                fn=self.add_y_sample,
                outputs=[self.canvas],
            )
            clear_samples_button.click(
                fn=self.clear_y_samples,
                outputs=[self.canvas],
            )
            demo.load(self.on_load)
        demo.launch()
# Start the app only when run as a script (e.g. `python app.py`), so the
# module can be imported without launching a server.
if __name__ == "__main__":
    visualizer = GPVisualizer(width=1200, height=900)
    visualizer.launch()