gp_visualizer / old_code /gp_visualizer.py
joel-woodfield's picture
Refactor file structure
03e72c7
from dataclasses import dataclass
import time
import ast
import gradio as gr
import io
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import (
DotProduct,
WhiteKernel,
ConstantKernel,
RBF,
Matern,
RationalQuadratic,
ExpSineSquared,
Kernel,
)
import logging
# Configure the root logger at import time and create the logger used by this file.
logging.basicConfig(
    level=logging.INFO,  # minimum severity that is emitted (DEBUG < INFO < WARNING < ERROR < CRITICAL)
    format="%(asctime)s [%(levelname)s] %(message)s",  # timestamp, level name in brackets, then the message
)
# Named logger through which this file reports timings and errors.
logger = logging.getLogger("ELVIS")
from dataset import Dataset, DatasetView, get_function
@dataclass(frozen=True)
class PlotOptions:
    """Immutable set of switches selecting which layers appear in the plot."""

    # Scatter the training points.
    show_training_data: bool = True
    # Draw the true function curve when one is available.
    show_true_function: bool = True
    # Draw the model's mean prediction.
    show_mean_prediction: bool = True
    # Shade the 95% prediction interval around the mean.
    show_prediction_interval: bool = True

    def update(self, **kwargs):
        """Return a new ``PlotOptions`` with the given flags overridden.

        Flags not mentioned in ``kwargs`` keep their current value; keyword
        arguments that do not name one of the four flags are ignored.
        """
        flag_names = (
            "show_training_data",
            "show_true_function",
            "show_mean_prediction",
            "show_prediction_interval",
        )
        merged = {
            name: kwargs.get(name, getattr(self, name)) for name in flag_names
        }
        return PlotOptions(**merged)

    def __hash__(self):
        """Hash the four flag values as a tuple, in declaration order."""
        flags = (
            self.show_training_data,
            self.show_true_function,
            self.show_mean_prediction,
            self.show_prediction_interval,
        )
        return hash(flags)
def eval_kernel(kernel_str) -> Kernel:
    """Build a scikit-learn kernel from a user-supplied expression string.

    The expression may call the whitelisted kernel constructors with literal
    arguments and combine them with arithmetic operators, for example
    ``"ConstantKernel(2.0) * RBF(length_scale=0.5) + WhiteKernel()"``.

    Args:
        kernel_str: Python expression describing the kernel.

    Returns:
        The kernel object the expression evaluates to.

    Raises:
        ValueError: If the expression is not valid syntax, uses a construct
            or name outside the whitelist, fails during evaluation, or does
            not evaluate to a ``Kernel`` instance.
    """
    # Only these kernel constructors are visible to the expression.
    allowed_names = {
        'RBF': RBF,
        'Matern': Matern,
        'RationalQuadratic': RationalQuadratic,
        'ExpSineSquared': ExpSineSquared,
        'DotProduct': DotProduct,
        'WhiteKernel': WhiteKernel,
        'ConstantKernel': ConstantKernel,
    }

    # Syntax a kernel expression legitimately needs: constructor calls with
    # literal (keyword) arguments, and arithmetic combining the results.
    allowed_nodes = (
        ast.Expression,
        ast.Call,
        ast.keyword,
        ast.Name,
        ast.Load,
        ast.Constant,
        ast.Tuple,
        ast.List,
        ast.BinOp,
        ast.Add,
        ast.Sub,
        ast.Mult,
        ast.Div,
        ast.Pow,
        ast.UnaryOp,
        ast.UAdd,
        ast.USub,
    )

    # Parse first so that syntax errors are reported separately.
    try:
        tree = ast.parse(kernel_str, mode='eval')
    except SyntaxError as e:
        raise ValueError(f"Invalid syntax: {e}") from e

    # Removing __builtins__ alone does not sandbox eval(): attribute access on
    # any literal (e.g. ``().__class__``) can still reach arbitrary objects.
    # Reject every construct outside the whitelist before evaluating.
    for node in ast.walk(tree):
        if not isinstance(node, allowed_nodes):
            raise ValueError(
                "Error evaluating kernel: "
                f"{type(node).__name__} is not allowed in a kernel expression"
            )
        if isinstance(node, ast.Name) and node.id not in allowed_names:
            raise ValueError(
                f"Error evaluating kernel: unknown name '{node.id}'"
            )
        if isinstance(node, ast.Call) and not isinstance(node.func, ast.Name):
            raise ValueError(
                "Error evaluating kernel: only kernel constructors may be called"
            )

    # NOTE(security): this still calls eval() on text typed by the user. It is
    # only acceptable because the tree was restricted to the whitelist above;
    # keep the two in sync if either is changed.
    try:
        result = eval(
            compile(tree, '<string>', 'eval'),
            {"__builtins__": None},  # no access to Python builtins like open
            allowed_names,  # the only names the expression can resolve
        )
    except Exception as e:
        raise ValueError(f"Error evaluating kernel: {e}") from e

    # An expression such as "1.0" is valid Python but not a kernel; fail here
    # rather than later inside the regressor.
    if not isinstance(result, Kernel):
        raise ValueError(
            "Error evaluating kernel: expression produced "
            f"{type(result).__name__}, not a kernel"
        )

    return result
@dataclass
class ModelState:
    """A Gaussian process regressor together with the settings it was built from."""

    # The regressor instance used for prediction and sampling.
    model: GaussianProcessRegressor
    # Kernel expression string the model was constructed from.
    kernel: str
    # Distribution name the model was initialised for.
    distribution: str

    def __hash__(self):
        """Hash on the kernel and distribution strings; the model is not included."""
        key = (self.kernel, self.distribution)
        return hash(key)
class GpVisualizer:
    """Gradio application that fits a Gaussian process to a dataset and plots it.

    The class owns the rendering code (``plot``), the model construction
    (``init_model``), the event handlers wired to the UI, and the UI layout
    itself (``launch``).
    """

    def __init__(self, width, height):
        """Store the requested canvas size and set up plotting resources.

        Args:
            width: Stored as ``canvas_width``. ``plot`` renders a fixed
                8x8 inch figure and does not consult this value.
            height: Stored as ``canvas_height``; likewise not consulted by
                ``plot``.
        """
        self.canvas_width = width
        self.canvas_height = height

        # Colour map from which each plot layer takes a fixed index.
        self.plot_cmap = plt.get_cmap("tab20")
        self.css = """
.hidden-button {
display: none;
}"""

    def plot(
        self,
        dataset: Dataset,
        model_state: ModelState,
        plot_options: PlotOptions,
        sample_y: bool = False,
        sample_y_seed: int = 0,
    ) -> Image.Image:
        """Render the dataset and the model's predictions as a PIL image.

        When ``dataset.mode`` is ``"generate"`` the evaluation grid and true
        function values come from ``get_function(dataset.function,
        xlim=(-2, 2), nsample=100)``. Otherwise, if there is training data,
        the grid is 100 evenly spaced points from one unit below the smallest
        training input to one unit above the largest. With neither, only the
        (empty) training scatter is drawn.

        Args:
            dataset: Source of the training data and, in generate mode, the
                true function.
            model_state: Holds the regressor used for prediction and sampling.
            plot_options: Flags selecting which layers are drawn.
            sample_y: If true, also draw one sample from the model.
            sample_y_seed: ``random_state`` passed to the model's ``sample_y``.

        Returns:
            The rendered figure as a PNG-backed ``PIL.Image.Image``.
        """
        logger.info("Plotting")
        t1 = time.perf_counter()

        x_train = dataset.x
        y_train = dataset.y

        if dataset.mode == "generate":
            x_test, y_test = get_function(dataset.function, xlim=(-2, 2), nsample=100)
            y_pred, y_std = model_state.model.predict(x_test, return_std=True)
        elif x_train.shape[0] > 0:
            x_test = np.linspace(x_train.min() - 1, x_train.max() + 1, 100).reshape(-1, 1)
            y_test = None
            y_pred, y_std = model_state.model.predict(x_test, return_std=True)
        else:
            x_test = None
            y_test = None
            y_pred = None
            y_std = None

        # Bounds of the 95% interval, reused for the axis limits and the band.
        if y_pred is not None:
            lower = y_pred - 1.96 * y_std
            upper = y_pred + 1.96 * y_std
        else:
            lower = None
            upper = None

        # Values the y-limits must cover in addition to the interval: the true
        # function when known, otherwise the training targets.
        if y_test is not None:
            reference = y_test
        elif y_train.shape[0] > 0 and y_pred is not None:
            reference = y_train
        else:
            reference = None

        # A single figure is created per call and always closed, so repeated
        # renders do not accumulate open figures.
        fig, ax = plt.subplots(figsize=(8, 8))
        try:
            ax.set_title("")
            ax.set_xlabel("x")
            ax.set_ylabel("y")

            if reference is not None:
                min_y = min(reference.min(), lower.min())
                max_y = max(reference.max(), upper.max())
                ax.set_ylim(min_y - 1, max_y + 1)

            if plot_options.show_training_data:
                ax.scatter(
                    x_train.flatten(),
                    y_train,
                    label='training data',
                    color=self.plot_cmap(0),
                )

            if plot_options.show_true_function and x_test is not None and y_test is not None:
                ax.plot(
                    x_test.flatten(),
                    y_test,
                    label='true function',
                    color=self.plot_cmap(1),
                )

            if plot_options.show_mean_prediction and x_test is not None and y_pred is not None:
                ax.plot(
                    x_test.flatten(),
                    y_pred,
                    linestyle="--",
                    label='mean prediction',
                    color=self.plot_cmap(2),
                )

            if plot_options.show_prediction_interval and x_test is not None and y_std is not None:
                ax.fill_between(
                    x_test.flatten(),
                    lower,
                    upper,
                    color=self.plot_cmap(3),
                    alpha=0.2,
                    label='95% prediction interval',
                )

            if x_test is not None and sample_y:
                y_sample = model_state.model.sample_y(
                    x_test, random_state=sample_y_seed
                ).flatten()
                ax.plot(
                    x_test.flatten(),
                    y_sample,
                    linestyle=":",
                    label="model sample",
                    color=self.plot_cmap(4),
                )

            # Only add a legend when at least one layer was drawn.
            handles, _ = ax.get_legend_handles_labels()
            if handles:
                ax.legend()

            buf = io.BytesIO()
            fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
        finally:
            plt.close(fig)

        buf.seek(0)
        img = Image.open(buf)

        t2 = time.perf_counter()
        logger.info(f"Plotting took {t2 - t1:.4f} seconds")

        return img

    def init_model(
        self,
        kernel: str,
        dataset: Dataset,
        distribution: str,
    ) -> GaussianProcessRegressor:
        """Create a regressor for the given kernel expression and distribution.

        Args:
            kernel: Kernel expression string, parsed by ``eval_kernel``.
            dataset: Training data used when fitting.
            distribution: ``"posterior"`` fits the model to the dataset if it
                has at least one point; ``"prior"`` leaves the model unfitted.

        Returns:
            The constructed (and possibly fitted) regressor.

        Raises:
            ValueError: If ``distribution`` is neither ``"posterior"`` nor
                ``"prior"``. Errors raised by ``eval_kernel`` propagate.
        """
        model = GaussianProcessRegressor(kernel=eval_kernel(kernel))

        if distribution == "posterior":
            # Fitting needs data; with an empty dataset the model stays unfitted.
            if dataset.x.shape[0] > 0:
                model.fit(dataset.x, dataset.y)
        elif distribution != "prior":
            raise ValueError(f"Unknown distribution: {distribution}")

        return model

    def update_dataset(
        self,
        dataset: Dataset,
        model_state: ModelState,
        plot_options: PlotOptions,
    ) -> tuple[ModelState, Image.Image]:
        """Rebuild the model for a changed dataset and re-render the plot.

        The kernel and distribution are carried over from ``model_state``.

        Returns:
            The new model state and the new canvas image.
        """
        logger.info("updating dataset")

        model = self.init_model(
            model_state.kernel,
            dataset,
            model_state.distribution,
        )
        model_state = ModelState(
            model=model, kernel=model_state.kernel, distribution=model_state.distribution
        )
        new_canvas = self.plot(dataset, model_state, plot_options)

        return model_state, new_canvas

    def update_model(
        self,
        kernel_str: str,
        distribution: str,
        model_state: ModelState,
        dataset: Dataset,
        plot_options: PlotOptions,
    ) -> tuple[ModelState, Image.Image]:
        """Rebuild the model from a new kernel string and/or distribution.

        ``distribution`` is lower-cased before use. If building the model
        fails, the error is logged and shown to the user, the previous
        ``model_state`` is kept, and the plot is re-rendered from it.

        Returns:
            The resulting model state and the new canvas image.
        """
        logger.info("updating kernel")

        try:
            model = self.init_model(
                kernel_str,
                dataset,
                distribution.lower(),
            )
            model_state = ModelState(
                model=model, kernel=kernel_str, distribution=distribution.lower()
            )
        except Exception as e:
            # UI boundary: report the problem instead of failing the event.
            logger.error(f"Error updating kernel: {e}")
            gr.Info(f" ⚠️ Error updating kernel: {e}")

        new_canvas = self.plot(dataset, model_state, plot_options)

        return model_state, new_canvas

    def sample(
        self,
        model_state: ModelState,
        dataset: Dataset,
        plot_options: PlotOptions,
    ) -> Image.Image:
        """Re-render the plot with one additional sample drawn from the model."""
        logger.info("sampling from model")
        # Time-derived seed so that successive clicks give different samples.
        seed = int(time.time() * 100) % 10000
        new_canvas = self.plot(
            dataset,
            model_state,
            plot_options,
            sample_y=True,
            sample_y_seed=seed,
        )
        return new_canvas

    def clear_sample(
        self,
        model_state: ModelState,
        dataset: Dataset,
        plot_options: PlotOptions,
    ) -> Image.Image:
        """Re-render the plot without a model sample."""
        logger.info("clearing sample from model")
        new_canvas = self.plot(
            dataset,
            model_state,
            plot_options,
            sample_y=False,
        )
        return new_canvas

    def launch(self):
        """Build the Gradio interface, wire its event handlers and start it.

        Reads ``usage.md`` from the current working directory to fill the
        "Usage" tab.
        """
        # build the Gradio interface
        with gr.Blocks(css=self.css) as demo:
            # app title
            gr.HTML("<div style='text-align:left; font-size:40px; font-weight: bold;'>Gaussian Process Visualizer</div>")

            # states
            dataset = gr.State(Dataset())
            plot_options = gr.State(PlotOptions())
            kernel = "RBF() + WhiteKernel()"
            model = self.init_model(kernel, dataset.value, "posterior")
            model_state = gr.State(
                ModelState(model=model, kernel=kernel, distribution="posterior")
            )

            # GUI elements and layout
            with gr.Row():
                with gr.Column(scale=2):
                    canvas = gr.Image(
                        value=self.plot(
                            dataset.value,
                            model_state.value,
                            plot_options.value,
                        ),
                        # show_download_button=False,
                        container=True,
                    )

                with gr.Column(scale=1):
                    with gr.Tab("Dataset"):
                        dataset_view = DatasetView()
                        dataset_view.build(state=dataset)
                        # Any change to the dataset refits the model and redraws.
                        dataset.change(
                            fn=self.update_dataset,
                            inputs=[dataset, model_state, plot_options],
                            outputs=[model_state, canvas],
                        )

                    with gr.Tab("Model"):
                        kernel_box = gr.Textbox(
                            label="Kernel",
                            value=model_state.value.kernel,
                            interactive=True,
                        )
                        kernel_submit = gr.Button("Update Kernel")

                        distribution = gr.Radio(
                            label="Distribution",
                            choices=["Prior", "Posterior"],
                            value="Posterior",
                        )

                        # Submitting the box, clicking the button and switching
                        # the distribution all rebuild the model the same way.
                        kernel_box.submit(
                            fn=self.update_model,
                            inputs=[kernel_box, distribution, model_state, dataset, plot_options],
                            outputs=[model_state, canvas],
                        )
                        kernel_submit.click(
                            fn=self.update_model,
                            inputs=[kernel_box, distribution, model_state, dataset, plot_options],
                            outputs=[model_state, canvas],
                        )
                        distribution.change(
                            fn=self.update_model,
                            inputs=[kernel_box, distribution, model_state, dataset, plot_options],
                            outputs=[model_state, canvas],
                        )

                        sample_button = gr.Button("Sample")
                        clear_sample_button = gr.Button("Clear Sample")

                        sample_button.click(
                            fn=self.sample,
                            inputs=[model_state, dataset, plot_options],
                            outputs=[canvas],
                        )
                        clear_sample_button.click(
                            fn=self.clear_sample,
                            inputs=[model_state, dataset, plot_options],
                            outputs=[canvas],
                        )

                    with gr.Tab("Plot Options"):
                        show_training_data = gr.Checkbox(
                            label="Show Training Data",
                            value=True,
                        )
                        show_true_function = gr.Checkbox(
                            label="Show True Function",
                            value=True,
                        )
                        show_mean_prediction = gr.Checkbox(
                            label="Show Mean Prediction",
                            value=True,
                        )
                        show_prediction_interval = gr.Checkbox(
                            label="Show Prediction Interval",
                            value=True,
                        )

                        # Each checkbox replaces the options state with an
                        # updated copy; the redraw is triggered by the state
                        # change handler below.
                        show_training_data.change(
                            fn=lambda val, options: options.update(show_training_data=val),
                            inputs=[show_training_data, plot_options],
                            outputs=[plot_options],
                        )
                        show_true_function.change(
                            fn=lambda val, options: options.update(show_true_function=val),
                            inputs=[show_true_function, plot_options],
                            outputs=[plot_options],
                        )
                        show_mean_prediction.change(
                            fn=lambda val, options: options.update(show_mean_prediction=val),
                            inputs=[show_mean_prediction, plot_options],
                            outputs=[plot_options],
                        )
                        show_prediction_interval.change(
                            fn=lambda val, options: options.update(show_prediction_interval=val),
                            inputs=[show_prediction_interval, plot_options],
                            outputs=[plot_options],
                        )
                        plot_options.change(
                            fn=self.plot,
                            inputs=[dataset, model_state, plot_options],
                            outputs=[canvas],
                        )

                    with gr.Tab("Usage"):
                        # Explicit encoding so the text does not depend on the
                        # platform's default.
                        with open("usage.md", "r", encoding="utf-8") as f:
                            usage_md = f.read()

                        gr.Markdown(usage_md)

        demo.launch()
# Build the application and start the Gradio server. There is no
# ``if __name__ == "__main__"`` guard, so these statements run whenever this
# file is executed or imported.
visualizer = GpVisualizer(width=1200, height=900)
visualizer.launch()