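"""MLP Training Visualizer.

A Gradio app for fitting a small multilayer perceptron to 1-D regression
data and watching the prediction evolve as training progresses.
"""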
from dataclasses import dataclass
import io
import logging
import time

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
import torch.nn as nn

from new_architecture import Architecture, ArchitectureView
from dataset import Dataset, DatasetView, get_function
from hyperparameters import Hyperparameters, HyperparametersView

logging.basicConfig(
    level=logging.INFO,  # minimum level to capture (DEBUG, INFO, WARNING, ERROR, CRITICAL)
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger("ELVIS")
@dataclass
class TrainState:
    """Model, optimizer, and training bookkeeping shared across callbacks."""

    model: nn.Module
    optimizer: torch.optim.Optimizer
    last_loss: float = float("nan")
    steps_trained: int = 0

    def __hash__(self):
        # Explicit hash keeps the dataclass hashable and changes whenever
        # the loss or step count does.
        return hash(
            (
                str(self.model),
                str(self.optimizer),
                self.last_loss,
                self.steps_trained,
            )
        )
def init_model(architecture: Architecture) -> nn.Module:
    """Build a 1-D-in, 1-D-out MLP from the given Architecture."""
    input_size = 1
    output_size = 1
    layers = []
    for hidden_units, activation in zip(architecture.hidden_units, architecture.activations):
        layers.append(nn.Linear(input_size, hidden_units))
        # Normalize the activation name, e.g. "Leaky ReLU" -> "leakyrelu".
        activation = (
            activation
            .lower()
            .replace(" ", "")
            .replace("-", "")
            .replace("_", "")
        )
        if activation == "relu":
            layers.append(nn.ReLU())
        elif activation == "sigmoid":
            layers.append(nn.Sigmoid())
        elif activation == "tanh":
            layers.append(nn.Tanh())
        elif activation == "leakyrelu":
            layers.append(nn.LeakyReLU())
        elif activation == "elu":
            layers.append(nn.ELU())
        elif activation == "gelu":
            layers.append(nn.GELU())
        elif activation == "identity":
            layers.append(nn.Identity())
        else:
            raise ValueError(f"Unknown activation: {activation}")
        input_size = hidden_units
    layers.append(nn.Linear(input_size, output_size))
    return nn.Sequential(*layers)
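# Example (hypothetical values): Architecture(hidden_units=[16, 16],
# activations=["ReLU", "Tanh"]) yields
# Sequential(Linear(1, 16), ReLU(), Linear(16, 16), Tanh(), Linear(16, 1)).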
def init_optimizer(
    model: nn.Module,
    hyperparameters: Hyperparameters,
) -> torch.optim.Optimizer:
    """Create the optimizer selected in the hyperparameters."""
    if hyperparameters.optimizer == "SGD":
        opt = torch.optim.SGD(
            model.parameters(),
            lr=hyperparameters.sgd_params.learning_rate,
            momentum=hyperparameters.sgd_params.momentum,
            weight_decay=hyperparameters.sgd_params.weight_decay,
        )
    elif hyperparameters.optimizer == "Adam":
        opt = torch.optim.Adam(
            model.parameters(),
            lr=hyperparameters.adam_params.learning_rate,
            betas=(hyperparameters.adam_params.beta1, hyperparameters.adam_params.beta2),
            weight_decay=hyperparameters.adam_params.weight_decay,
        )
    else:
        raise ValueError(f"Unknown optimizer: {hyperparameters.optimizer}")
    return opt
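# Only "SGD" and "Adam" are recognized above; the matching sgd_params /
# adam_params fields on Hyperparameters supply the optimizer keyword arguments.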
@dataclass
class PlotOptions:
    """Flags controlling which series are drawn; update() returns a copy."""

    show_training_data: bool = True
    show_true_function: bool = True
    show_prediction: bool = True

    def update(self, **kwargs):
        return PlotOptions(
            show_training_data=kwargs.get("show_training_data", self.show_training_data),
            show_true_function=kwargs.get("show_true_function", self.show_true_function),
            show_prediction=kwargs.get("show_prediction", self.show_prediction),
        )

    def __hash__(self):
        return hash(
            (
                self.show_training_data,
                self.show_true_function,
                self.show_prediction,
            )
        )
class MlpVisualizer:
    def __init__(self, width, height):
        self.canvas_width = width
        self.canvas_height = height
        self.plot_cmap = plt.get_cmap("tab20")
        self.css = """
        .hidden-button {
            display: none;
        }"""
    def plot(self, dataset: Dataset, train_state: TrainState, plot_options: PlotOptions) -> Image.Image:
        logger.info("Plotting")
        t1 = time.time()
        fig = plt.figure(figsize=(self.canvas_width / 100.0, self.canvas_height / 100.0), dpi=100)
        # Make the axes fill the entire figure to allow simple conversion of
        # mouse position on the canvas to coordinates in the figure.
        ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
        ax.margins(x=0, y=0)  # no padding in either direction
        x_train = dataset.x
        y_train = dataset.y
        if dataset.mode == "generate":
            x_test, y_test = get_function(dataset.function, xlim=(-2, 2), nsample=100)
            y_pred = train_state.model(torch.from_numpy(x_test).float()).detach().numpy()
        elif x_train.shape[0] > 0:
            x_test = np.linspace(x_train.min() - 1, x_train.max() + 1, 100).reshape(-1, 1)
            y_test = None
            y_pred = train_state.model(torch.from_numpy(x_test).float()).detach().numpy()
        else:
            x_test = None
            y_test = None
            y_pred = None
        # plot
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        if y_test is not None:
            ax.set_ylim(y_test.min() - 1, y_test.max() + 1)
        elif y_train.shape[0] > 0:
            ax.set_ylim(y_train.min() - 1, y_train.max() + 1)
        if plot_options.show_training_data:
            ax.scatter(x_train.flatten(), y_train, label="training data", color=self.plot_cmap(0))
        if plot_options.show_true_function and x_test is not None and y_test is not None:
            ax.plot(x_test.flatten(), y_test, label="true function", color=self.plot_cmap(1))
        if plot_options.show_prediction and x_test is not None and y_pred is not None:
            ax.plot(x_test.flatten(), y_pred, linestyle="--", label="prediction", color=self.plot_cmap(2))
        ax.legend()
        buf = io.BytesIO()
        # No tight bounding box: the full-canvas layout keeps the saved image
        # at exactly canvas_width x canvas_height pixels.
        fig.savefig(buf, format="png")
        plt.close(fig)
        buf.seek(0)
        img = Image.open(buf)
        t2 = time.time()
        logger.info(f"Plotting took {t2 - t1:.4f} seconds")
        return img
    def update_dataset(
        self,
        dataset: Dataset,
        architecture: Architecture,
        hyperparameters: Hyperparameters,
        plot_options: PlotOptions,
    ):
        # Changing the dataset rebuilds the model and optimizer from scratch
        # so the new fit starts fresh.
        logger.info("Updating dataset")
        new_model = init_model(architecture)
        new_optimizer = init_optimizer(new_model, hyperparameters)
        new_train_state = TrainState(new_model, new_optimizer)
        new_canvas = self.plot(dataset, new_train_state, plot_options)
        return new_canvas, new_train_state

    def update_architecture(
        self,
        dataset: Dataset,
        architecture: Architecture,
        hyperparameters: Hyperparameters,
        plot_options: PlotOptions,
    ):
        logger.info("Updating architecture")
        new_model = init_model(architecture)
        new_optimizer = init_optimizer(new_model, hyperparameters)
        new_train_state = TrainState(new_model, new_optimizer)
        new_canvas = self.plot(dataset, new_train_state, plot_options)
        return new_canvas, new_train_state

    def update_hyperparameters(
        self,
        dataset: Dataset,
        architecture: Architecture,
        hyperparameters: Hyperparameters,
        plot_options: PlotOptions,
    ):
        logger.info("Updating hyperparameters")
        new_model = init_model(architecture)
        new_optimizer = init_optimizer(new_model, hyperparameters)
        new_train_state = TrainState(new_model, new_optimizer)
        new_canvas = self.plot(dataset, new_train_state, plot_options)
        return new_canvas, new_train_state
    def train_step(
        self,
        dataset: Dataset,
        hyperparameters: Hyperparameters,
        train_state: TrainState,
        plot_options: PlotOptions,
        num_steps: int = 1,
    ):
        logger.info("Training step")
        model = train_state.model
        optimizer = train_state.optimizer
        batch_size = hyperparameters.batch_size
        model.train()
        x_train = torch.from_numpy(dataset.x).float()
        y_train = torch.from_numpy(dataset.y).float()
        loss_fn = nn.MSELoss()
        for _ in range(num_steps):
            # Sample a random mini-batch without replacement.
            indices = torch.randperm(x_train.shape[0])[:batch_size]
            x_batch = x_train[indices]
            y_batch = y_train[indices]
            y_pred = model(x_batch)
            loss = loss_fn(y_pred.flatten(), y_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            logger.info(f"Training loss: {loss.item():.4f}")
        train_state.last_loss = loss.item()
        train_state.steps_trained += num_steps
        new_canvas = self.plot(dataset, train_state, plot_options)
        return new_canvas, train_state
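        # Note: randperm sampling yields batches of min(len(dataset.x),
        # batch_size) points; an empty dataset would produce a NaN mean loss.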
    def reset_model(
        self,
        dataset: Dataset,
        architecture: Architecture,
        hyperparameters: Hyperparameters,
        plot_options: PlotOptions,
    ):
        logger.info("Resetting model")
        new_model = init_model(architecture)
        new_optimizer = init_optimizer(new_model, hyperparameters)
        new_train_state = TrainState(new_model, new_optimizer)
        new_canvas = self.plot(dataset, new_train_state, plot_options)
        return new_canvas, new_train_state
    def launch(self):
        # build the Gradio interface
        with gr.Blocks(css=self.css) as demo:
            # app title
            gr.HTML("<div style='text-align:left; font-size:40px; font-weight: bold;'>MLP Training Visualizer</div>")
            # states
            dataset = gr.State(Dataset())
            architecture = gr.State(Architecture())
            hyperparameters = gr.State(Hyperparameters())
            model = init_model(architecture.value)
            optimizer = init_optimizer(model, hyperparameters.value)
            train_state = gr.State(TrainState(model, optimizer))
            plot_options = gr.State(PlotOptions())
            # GUI elements and layout
            with gr.Row():
                with gr.Column(scale=2):
                    canvas = gr.Image(
                        value=self.plot(dataset.value, train_state.value, plot_options.value),
                        # show_download_button=False,
                        container=True,
                    )
                with gr.Column(scale=1):
                    with gr.Tab("Dataset"):
                        dataset_view = DatasetView()
                        dataset_view.build(state=dataset)
                        dataset.change(
                            fn=self.update_dataset,
                            inputs=[dataset, architecture, hyperparameters, plot_options],
                            outputs=[canvas, train_state],
                        )
                    with gr.Tab("Architecture"):
                        architecture_view = ArchitectureView()
                        architecture_view.build(state=architecture)
                        architecture.change(
                            fn=self.update_architecture,
                            inputs=[dataset, architecture, hyperparameters, plot_options],
                            outputs=[canvas, train_state],
                        )
                    with gr.Tab("Train"):
                        hyperparameters_view = HyperparametersView()
                        hyperparameters_view.build(state=hyperparameters)
                        hyperparameters.change(
                            fn=self.update_hyperparameters,
                            inputs=[dataset, architecture, hyperparameters, plot_options],
                            outputs=[canvas, train_state],
                        )
                        train_increment = gr.Number(
                            label="Step increment",
                            value=1,
                            precision=0,
                        )
                        train_button = gr.Button("Train step")
                        reset_button = gr.Button("Reset Model")
                        with gr.Row():
                            # Named to avoid shadowing the train_step() method.
                            train_step_display = gr.Number(
                                label="Steps trained",
                                value=train_state.value.steps_trained,
                                interactive=False,
                            )
                            train_loss = gr.Number(
                                label="Last Loss",
                                value=train_state.value.last_loss,
                                interactive=False,
                            )
                        train_button.click(
                            fn=self.train_step,
                            inputs=[dataset, hyperparameters, train_state, plot_options, train_increment],
                            outputs=[canvas, train_state],
                        )
                        reset_button.click(
                            fn=self.reset_model,
                            inputs=[dataset, architecture, hyperparameters, plot_options],
                            outputs=[canvas, train_state],
                        )
                        train_state.change(
                            fn=lambda state: (state.steps_trained, state.last_loss),
                            inputs=[train_state],
                            outputs=[train_step_display, train_loss],
                        )
                    with gr.Tab("Plot Options"):
                        show_training_data = gr.Checkbox(
                            label="Show Training Data",
                            value=True,
                        )
                        show_true_function = gr.Checkbox(
                            label="Show True Function",
                            value=True,
                        )
                        show_prediction = gr.Checkbox(
                            label="Show Prediction",
                            value=True,
                        )
                        show_training_data.change(
                            fn=lambda val, options: options.update(show_training_data=val),
                            inputs=[show_training_data, plot_options],
                            outputs=[plot_options],
                        )
                        show_true_function.change(
                            fn=lambda val, options: options.update(show_true_function=val),
                            inputs=[show_true_function, plot_options],
                            outputs=[plot_options],
                        )
                        show_prediction.change(
                            fn=lambda val, options: options.update(show_prediction=val),
                            inputs=[show_prediction, plot_options],
                            outputs=[plot_options],
                        )
                        plot_options.change(
                            fn=self.plot,
                            inputs=[dataset, train_state, plot_options],
                            outputs=[canvas],
                        )
                    with gr.Tab("Usage"):
                        with open("usage.md") as f:
                            usage_text = f.read()
                        gr.Markdown(usage_text)
        demo.launch()
if __name__ == "__main__":
    visualizer = MlpVisualizer(width=1200, height=900)
    visualizer.launch()