# mlp_visualizer/old/mlp_visualizer.py
from dataclasses import dataclass
import io
import logging
import time

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
import torch.nn as nn

logging.basicConfig(
    level=logging.INFO,  # minimum level to capture (DEBUG, INFO, WARNING, ERROR, CRITICAL)
    format="%(asctime)s [%(levelname)s] %(message)s",  # log format
)
logger = logging.getLogger("ELVIS")

from new_architecture import Architecture, ArchitectureView
from dataset import Dataset, DatasetView, get_function
from hyperparameters import Hyperparameters, HyperparametersView
@dataclass
class TrainState:
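    """Mutable training state: the model, its optimizer, and running statistics."""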
model: nn.Module
optimizer: torch.optim.Optimizer
last_loss: float = float("nan")
steps_trained: int = 0
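    # The custom hash includes the mutable fields so that in-place updates
    # (new loss / step count) register as a changed state in the UI layer.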
def __hash__(self):
return hash(
(
str(self.model),
str(self.optimizer),
self.last_loss,
self.steps_trained,
)
)
def init_model(architecture: Architecture) -> nn.Module:
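    """Build a 1-D-input, 1-D-output MLP from the hidden sizes and activations in `architecture`."""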
input_size = 1
output_size = 1
layers = []
for hidden_units, activation in zip(architecture.hidden_units, architecture.activations):
layers.append(nn.Linear(input_size, hidden_units))
activation = (
activation
.lower()
.replace(" ", "")
.replace("-", "")
.replace("_", "")
)
if activation == "relu":
layers.append(nn.ReLU())
elif activation == "sigmoid":
layers.append(nn.Sigmoid())
elif activation == "tanh":
layers.append(nn.Tanh())
elif activation == "leakyrelu":
layers.append(nn.LeakyReLU())
elif activation == "elu":
layers.append(nn.ELU())
elif activation == "gelu":
layers.append(nn.GELU())
elif activation == "identity":
layers.append(nn.Identity())
else:
raise ValueError(f"Unknown activation: {activation}")
input_size = hidden_units
layers.append(nn.Linear(input_size, output_size))
model = nn.Sequential(*layers)
return model
def init_optimizer(
model: nn.Module,
hyperparameters: Hyperparameters,
) -> torch.optim.Optimizer:
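    """Create the SGD or Adam optimizer selected in `hyperparameters` for `model`."""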
if hyperparameters.optimizer == "SGD":
opt = torch.optim.SGD(
model.parameters(),
lr=hyperparameters.sgd_params.learning_rate,
momentum=hyperparameters.sgd_params.momentum,
weight_decay=hyperparameters.sgd_params.weight_decay,
)
elif hyperparameters.optimizer == "Adam":
opt = torch.optim.Adam(
model.parameters(),
lr=hyperparameters.adam_params.learning_rate,
betas=(hyperparameters.adam_params.beta1, hyperparameters.adam_params.beta2),
weight_decay=hyperparameters.adam_params.weight_decay,
)
else:
raise ValueError(f"Unknown optimizer: {hyperparameters.optimizer}")
return opt
@dataclass(frozen=True)
class PlotOptions:
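    """Immutable flags controlling which series are drawn on the canvas."""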
show_training_data: bool = True
show_true_function: bool = True
show_prediction: bool = True
def update(self, **kwargs):
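        """Return a copy with the given flags overridden (the dataclass is frozen)."""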
return PlotOptions(
show_training_data=kwargs.get("show_training_data", self.show_training_data),
show_true_function=kwargs.get("show_true_function", self.show_true_function),
show_prediction=kwargs.get("show_prediction", self.show_prediction),
)
def __hash__(self):
return hash(
(
self.show_training_data,
self.show_true_function,
self.show_prediction,
)
)
class MlpVisualizer:
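    """Gradio front end for visualizing MLP training on 1-D regression data."""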
def __init__(self, width, height):
self.canvas_width = width
self.canvas_height = height
self.plot_cmap = plt.get_cmap("tab20")
self.css = """
.hidden-button {
display: none;
}"""
def plot(self, dataset: Dataset, train_state: TrainState, plot_options: PlotOptions) -> Image.Image:
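        """Render the training data, true function, and model prediction to a PIL image."""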
print("Plotting")
t1 = time.time()
x_train = dataset.x
y_train = dataset.y
if dataset.mode == "generate":
x_test, y_test = get_function(dataset.function, xlim=(-2, 2), nsample=100)
y_pred = train_state.model(torch.from_numpy(x_test).float()).detach().numpy()
elif x_train.shape[0] > 0:
x_test = np.linspace(x_train.min() - 1, x_train.max() + 1, 100).reshape(-1, 1)
y_test = None
y_pred = train_state.model(torch.from_numpy(x_test).float()).detach().numpy()
else:
x_test = None
y_test = None
y_pred = None
# plot
        # size the figure to the canvas dimensions passed to the constructor
        fig, ax = plt.subplots(
            figsize=(self.canvas_width / 100.0, self.canvas_height / 100.0), dpi=100
        )
ax.set_title("")
ax.set_xlabel("x")
ax.set_ylabel("y")
if y_test is not None:
ax.set_ylim(y_test.min() - 1, y_test.max() + 1)
elif y_train.shape[0] > 0:
ax.set_ylim(y_train.min() - 1, y_train.max() + 1)
        if plot_options.show_training_data:
            ax.scatter(x_train.flatten(), y_train, label="training data", color=self.plot_cmap(0))
        if plot_options.show_true_function and x_test is not None and y_test is not None:
            ax.plot(x_test.flatten(), y_test, label="true function", color=self.plot_cmap(1))
        if plot_options.show_prediction and x_test is not None and y_pred is not None:
            ax.plot(x_test.flatten(), y_pred, linestyle="--", label="prediction", color=self.plot_cmap(2))
        ax.legend()
buf = io.BytesIO()
fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
plt.close(fig)
buf.seek(0)
img = Image.open(buf)
t2 = time.time()
logger.info(f"Plotting took {t2 - t1:.4f} seconds")
return img
def update_dataset(
self,
dataset: Dataset,
architecture: Architecture,
hyperparameters: Hyperparameters,
plot_options: PlotOptions,
):
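        """Rebuild the model and optimizer and re-plot after the dataset changes."""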
print("Updating dataset")
new_model = init_model(architecture)
new_optimizer = init_optimizer(new_model, hyperparameters)
new_train_state = TrainState(new_model, new_optimizer)
new_canvas = self.plot(dataset, new_train_state, plot_options)
return new_canvas, new_train_state
def update_architecture(
self,
dataset: Dataset,
architecture: Architecture,
hyperparameters: Hyperparameters,
plot_options: PlotOptions,
):
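        """Rebuild the model and optimizer and re-plot after the architecture changes."""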
print("Updating architecture")
new_model = init_model(architecture)
new_optimizer = init_optimizer(new_model, hyperparameters)
new_train_state = TrainState(new_model, new_optimizer)
new_canvas = self.plot(dataset, new_train_state, plot_options)
return new_canvas, new_train_state
def update_hyperparameters(
self,
dataset: Dataset,
architecture: Architecture,
hyperparameters: Hyperparameters,
plot_options: PlotOptions,
):
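        """Rebuild the model and optimizer and re-plot after the hyperparameters change."""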
print("Updating hyperparameters")
new_model = init_model(architecture)
new_optimizer = init_optimizer(new_model, hyperparameters)
new_train_state = TrainState(new_model, new_optimizer)
new_canvas = self.plot(dataset, new_train_state, plot_options)
return new_canvas, new_train_state
def train_step(
self,
dataset: Dataset,
hyperparameters: Hyperparameters,
train_state: TrainState,
plot_options: PlotOptions,
num_steps: int = 1,
):
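        """Run `num_steps` mini-batch updates and return the refreshed plot and state."""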
print("Training step")
model = train_state.model
optimizer = train_state.optimizer
batch_size = hyperparameters.batch_size
model.train()
x_train = torch.from_numpy(dataset.x).float()
        y_train = torch.from_numpy(dataset.y).float()
        if x_train.shape[0] == 0:
            # No training data yet (e.g. nothing drawn); skip the update to avoid NaN losses.
            logger.warning("Train step requested with an empty dataset; skipping")
            new_canvas = self.plot(dataset, train_state, plot_options)
            return new_canvas, train_state
for _ in range(num_steps):
indices = torch.randperm(x_train.shape[0])[:batch_size]
x_batch = x_train[indices]
y_batch = y_train[indices]
y_pred = model(x_batch)
loss = nn.MSELoss()(y_pred.flatten(), y_batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Training loss: {loss.item():.4f}")
train_state.last_loss = loss.item()
train_state.steps_trained += num_steps
new_canvas = self.plot(dataset, train_state, plot_options)
return new_canvas, train_state
def reset_model(
self,
dataset: Dataset,
architecture: Architecture,
hyperparameters: Hyperparameters,
plot_options: PlotOptions,
):
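        """Reinitialize the model and optimizer from the current settings and re-plot."""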
print("Resetting model")
new_model = init_model(architecture)
new_optimizer = init_optimizer(new_model, hyperparameters)
new_train_state = TrainState(new_model, new_optimizer)
new_canvas = self.plot(dataset, new_train_state, plot_options)
return new_canvas, new_train_state
def launch(self):
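        """Build the Gradio Blocks interface, wire up the event handlers, and launch it."""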
# build the Gradio interface
with gr.Blocks(css=self.css) as demo:
# app title
gr.HTML("<div style='text-align:left; font-size:40px; font-weight: bold;'>MLP Training Visualizer</div>")
# states
dataset = gr.State(Dataset())
architecture = gr.State(Architecture())
hyperparameters = gr.State(Hyperparameters())
model = init_model(architecture.value)
optimizer = init_optimizer(model, hyperparameters.value)
train_state = gr.State(TrainState(model, optimizer))
plot_options = gr.State(PlotOptions())
# GUI elements and layout
with gr.Row():
with gr.Column(scale=2):
canvas = gr.Image(
value=self.plot(dataset.value, train_state.value, plot_options.value),
# show_download_button=False,
container=True,
)
with gr.Column(scale=1):
with gr.Tab("Dataset"):
dataset_view = DatasetView()
dataset_view.build(state=dataset)
dataset.change(
fn=self.update_dataset,
inputs=[dataset, architecture, hyperparameters, plot_options],
outputs=[canvas, train_state],
)
with gr.Tab("Architecture"):
architecture_view = ArchitectureView()
architecture_view.build(state=architecture)
architecture.change(
fn=self.update_architecture,
inputs=[dataset, architecture, hyperparameters, plot_options],
outputs=[canvas, train_state],
)
with gr.Tab("Train"):
hyperparameters_view = HyperparametersView()
hyperparameters_view.build(state=hyperparameters)
hyperparameters.change(
fn=self.update_hyperparameters,
inputs=[dataset, architecture, hyperparameters, plot_options],
outputs=[canvas, train_state],
)
train_increment = gr.Number(
label="Step increment",
value=1,
precision=0,
)
train_button = gr.Button("Train step")
reset_button = gr.Button("Reset Model")
with gr.Row():
train_step = gr.Number(
label="Steps trained",
value=train_state.value.steps_trained,
interactive=False,
)
train_loss = gr.Number(
label="Last Loss",
value=train_state.value.last_loss,
interactive=False,
)
train_button.click(
fn=self.train_step,
inputs=[dataset, hyperparameters, train_state, plot_options, train_increment],
outputs=[canvas, train_state],
)
reset_button.click(
fn=self.reset_model,
inputs=[dataset, architecture, hyperparameters, plot_options],
outputs=[canvas, train_state],
)
train_state.change(
fn=lambda state: (state.steps_trained, state.last_loss),
inputs=[train_state],
outputs=[train_step, train_loss],
)
with gr.Tab("Plot Options"):
show_training_data = gr.Checkbox(
label="Show Training Data",
value=True,
)
show_true_function = gr.Checkbox(
label="Show True Function",
value=True,
)
show_prediction = gr.Checkbox(
label="Show Prediction",
value=True,
)
show_training_data.change(
fn=lambda val, options: options.update(show_training_data=val),
inputs=[show_training_data, plot_options],
outputs=[plot_options],
)
show_true_function.change(
fn=lambda val, options: options.update(show_true_function=val),
inputs=[show_true_function, plot_options],
outputs=[plot_options],
)
show_prediction.change(
fn=lambda val, options: options.update(show_prediction=val),
inputs=[show_prediction, plot_options],
outputs=[plot_options],
)
plot_options.change(
fn=self.plot,
inputs=[dataset, train_state, plot_options],
outputs=[canvas],
)
with gr.Tab("Usage"):
with open("usage.md") as f:
usage_text = f.read()
gr.Markdown(usage_text)
demo.launch()
if __name__ == "__main__":
    visualizer = MlpVisualizer(width=1200, height=900)
    visualizer.launch()