from dataclasses import dataclass
import io
import logging
import time

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
import torch.nn as nn

logging.basicConfig(
    level=logging.INFO,  # minimum level to capture (DEBUG, INFO, WARNING, ERROR, CRITICAL)
    format="%(asctime)s [%(levelname)s] %(message)s",  # log format
)
logger = logging.getLogger("ELVIS")

from new_architecture import Architecture, ArchitectureView
from dataset import Dataset, DatasetView, get_function
from hyperparameters import Hyperparameters, HyperparametersView


@dataclass
class TrainState:
    model: nn.Module
    optimizer: torch.optim.Optimizer
    last_loss: float = float("nan")
    steps_trained: int = 0

    def __hash__(self):
        return hash(
            (
                str(self.model),
                str(self.optimizer),
                self.last_loss,
                self.steps_trained,
            )
        )


def init_model(architecture: Architecture) -> nn.Module:
    input_size = 1
    output_size = 1
    layers = []
    for hidden_units, activation in zip(architecture.hidden_units, architecture.activations):
        layers.append(nn.Linear(input_size, hidden_units))
        # normalize the activation name, e.g. "Leaky ReLU" -> "leakyrelu"
        activation = (
            activation
            .lower()
            .replace(" ", "")
            .replace("-", "")
            .replace("_", "")
        )
        if activation == "relu":
            layers.append(nn.ReLU())
        elif activation == "sigmoid":
            layers.append(nn.Sigmoid())
        elif activation == "tanh":
            layers.append(nn.Tanh())
        elif activation == "leakyrelu":
            layers.append(nn.LeakyReLU())
        elif activation == "elu":
            layers.append(nn.ELU())
        elif activation == "gelu":
            layers.append(nn.GELU())
        elif activation == "identity":
            layers.append(nn.Identity())
        else:
            raise ValueError(f"Unknown activation: {activation}")
        input_size = hidden_units
    layers.append(nn.Linear(input_size, output_size))
    model = nn.Sequential(*layers)
    return model


def init_optimizer(
    model: nn.Module,
    hyperparameters: Hyperparameters,
) -> torch.optim.Optimizer:
    if hyperparameters.optimizer == "SGD":
        opt = torch.optim.SGD(
            model.parameters(),
            lr=hyperparameters.sgd_params.learning_rate,
            momentum=hyperparameters.sgd_params.momentum,
            weight_decay=hyperparameters.sgd_params.weight_decay,
        )
    elif hyperparameters.optimizer == "Adam":
        opt = torch.optim.Adam(
            model.parameters(),
            lr=hyperparameters.adam_params.learning_rate,
            betas=(hyperparameters.adam_params.beta1, hyperparameters.adam_params.beta2),
            weight_decay=hyperparameters.adam_params.weight_decay,
        )
    else:
        raise ValueError(f"Unknown optimizer: {hyperparameters.optimizer}")
    return opt


@dataclass(frozen=True)
class PlotOptions:
    show_training_data: bool = True
    show_true_function: bool = True
    show_prediction: bool = True

    def update(self, **kwargs):
        # return a copy with the given fields overridden
        return PlotOptions(
            show_training_data=kwargs.get("show_training_data", self.show_training_data),
            show_true_function=kwargs.get("show_true_function", self.show_true_function),
            show_prediction=kwargs.get("show_prediction", self.show_prediction),
        )

    def __hash__(self):
        return hash(
            (
                self.show_training_data,
                self.show_true_function,
                self.show_prediction,
            )
        )


class MlpVisualizer:
    def __init__(self, width, height):
        self.canvas_width = width
        self.canvas_height = height
        self.plot_cmap = plt.get_cmap("tab20")
        self.css = """
        .hidden-button { display: none; }"""

    def plot(self, dataset: Dataset, train_state: TrainState, plot_options: PlotOptions) -> Image.Image:
        logger.info("Plotting")
        t1 = time.time()
        fig = plt.figure(figsize=(self.canvas_width / 100.0, self.canvas_height / 100.0), dpi=100)
        # set entire figure to be the canvas to allow simple conversion of mouse
        # position to coordinates in the figure
        ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
        # ax.margins(x=0, y=0)  # no padding in both directions
        x_train = dataset.x
        y_train = dataset.y
        if dataset.mode == "generate":
            # sample the ground-truth function on a fixed grid
            x_test, y_test = get_function(dataset.function, xlim=(-2, 2), nsample=100)
            y_pred = train_state.model(torch.from_numpy(x_test).float()).detach().numpy()
        elif x_train.shape[0] > 0:
            # no ground truth available; evaluate around the training points
            x_test = np.linspace(x_train.min() - 1, x_train.max() + 1, 100).reshape(-1, 1)
            y_test = None
            y_pred = train_state.model(torch.from_numpy(x_test).float()).detach().numpy()
        else:
            x_test = None
            y_test = None
            y_pred = None

        # plot
        ax.set_title("")
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        if y_test is not None:
            ax.set_ylim(y_test.min() - 1, y_test.max() + 1)
        elif y_train.shape[0] > 0:
            ax.set_ylim(y_train.min() - 1, y_train.max() + 1)
        if plot_options.show_training_data:
            ax.scatter(x_train.flatten(), y_train, label='training data', color=self.plot_cmap(0))
        if plot_options.show_true_function and x_test is not None and y_test is not None:
            ax.plot(x_test.flatten(), y_test, label='true function', color=self.plot_cmap(1))
        if plot_options.show_prediction and x_test is not None and y_pred is not None:
            ax.plot(x_test.flatten(), y_pred, linestyle="--", label='prediction', color=self.plot_cmap(2))
        ax.legend()

        # render the figure to an in-memory PNG and hand it back as a PIL image
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
        plt.close(fig)
        buf.seek(0)
        img = Image.open(buf)
        t2 = time.time()
        logger.info(f"Plotting took {t2 - t1:.4f} seconds")
        return img

    def _reinit(
        self,
        dataset: Dataset,
        architecture: Architecture,
        hyperparameters: Hyperparameters,
        plot_options: PlotOptions,
    ):
        # rebuild model and optimizer from scratch, then redraw the canvas
        new_model = init_model(architecture)
        new_optimizer = init_optimizer(new_model, hyperparameters)
        new_train_state = TrainState(new_model, new_optimizer)
        new_canvas = self.plot(dataset, new_train_state, plot_options)
        return new_canvas, new_train_state

    def update_dataset(
        self,
        dataset: Dataset,
        architecture: Architecture,
        hyperparameters: Hyperparameters,
        plot_options: PlotOptions,
    ):
        logger.info("Updating dataset")
        return self._reinit(dataset, architecture, hyperparameters, plot_options)

    def update_architecture(
        self,
        dataset: Dataset,
        architecture: Architecture,
        hyperparameters: Hyperparameters,
        plot_options: PlotOptions,
    ):
        logger.info("Updating architecture")
        return self._reinit(dataset, architecture, hyperparameters, plot_options)

    def update_hyperparameters(
        self,
        dataset: Dataset,
        architecture: Architecture,
        hyperparameters: Hyperparameters,
        plot_options: PlotOptions,
    ):
        logger.info("Updating hyperparameters")
        return self._reinit(dataset, architecture, hyperparameters, plot_options)

    def train_step(
        self,
        dataset: Dataset,
        hyperparameters: Hyperparameters,
        train_state: TrainState,
        plot_options: PlotOptions,
        num_steps: int = 1,
    ):
        logger.info("Training step")
        model = train_state.model
        optimizer = train_state.optimizer
        batch_size = hyperparameters.batch_size
        model.train()
        x_train = torch.from_numpy(dataset.x).float()
        y_train = torch.from_numpy(dataset.y).float()
        loss_fn = nn.MSELoss()
        for _ in range(num_steps):
            # sample a random mini-batch without replacement
            indices = torch.randperm(x_train.shape[0])[:batch_size]
            x_batch = x_train[indices]
            y_batch = y_train[indices]
            y_pred = model(x_batch)
            loss = loss_fn(y_pred.flatten(), y_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            logger.info(f"Training loss: {loss.item():.4f}")
        train_state.last_loss = loss.item()
        train_state.steps_trained += num_steps
        new_canvas = self.plot(dataset, train_state, plot_options)
        return new_canvas, train_state
    def reset_model(
        self,
        dataset: Dataset,
        architecture: Architecture,
        hyperparameters: Hyperparameters,
        plot_options: PlotOptions,
    ):
        logger.info("Resetting model")
        return self._reinit(dataset, architecture, hyperparameters, plot_options)

    def launch(self):
        # build the Gradio interface
        with gr.Blocks(css=self.css) as demo:
            # app title
            gr.HTML("
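

if __name__ == "__main__":
    # Hypothetical entry point -- an assumption, not part of the original file.
    # It presumes launch() finishes building the Blocks UI and ends by calling
    # demo.launch(); the canvas size here is purely illustrative.
    MlpVisualizer(width=800, height=600).launch()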