regularization / old /regularization.py
joel-woodfield's picture
Refactor code to separate frontend and backend
770d448
import io
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import numpy as np
from PIL import Image
import plotly.graph_objects as go
from sklearn.datasets import make_regression
from sklearn.linear_model import ElasticNet
import logging
# Configure logging once, at import time, with a timestamped format.
logging.basicConfig(
level=logging.INFO, # minimum severity that is emitted (DEBUG < INFO < WARNING < ERROR < CRITICAL)
format="%(asctime)s [%(levelname)s] %(message)s", # timestamp, bracketed level name, then the message
)
# Named logger for this application.
# NOTE(review): nothing in the visible part of the file logs through it yet.
logger = logging.getLogger("ELVIS")
from dataset import Dataset, DatasetView
def min_corresponding_entries(W1, W2, w1, tol=0.1):
    """Return the smallest entry of ``W2`` at positions where ``W1 <= w1``.

    Args:
        W1: Array-like of values used to select positions.
        W2: Array-like with the same shape as ``W1``; candidates come from it.
        w1: Inclusive upper bound applied to ``W1``.
        tol: Unused; kept so callers that already pass it keep working.

    Returns:
        The minimum of ``W2`` over the selected positions.

    Raises:
        ValueError: If no entry of ``W1`` is less than or equal to ``w1``.
    """
    # Coerce so plain lists work as well as ndarrays.
    selector = np.asarray(W1)
    candidates = np.asarray(W2)
    selected = candidates[selector <= w1]
    if selected.size == 0:
        raise ValueError("No entries in W1 less than or equal to w1")
    return np.min(selected)
def l1_norm(W):
    """Return the L1 norm (sum of absolute values) along the last axis of ``W``."""
    magnitudes = np.abs(W)
    return magnitudes.sum(axis=-1)
def l2_norm(W):
    """Return the Euclidean (L2) norm along the last axis of ``W``."""
    euclidean_lengths = np.linalg.norm(W, ord=None, axis=-1)
    return euclidean_lengths
def l1_loss(W, y, X):
    """Mean absolute error of a linear model for every weight vector in a grid.

    Args:
        W: Weight grid of shape ``(..., n_features)``; the last axis holds one
            weight vector.
        y: Targets, one value per sample (flattened internally).
        X: Design matrix of shape ``(n_samples, n_features)``.

    Returns:
        Array of shape ``W.shape[:-1]`` holding the mean of ``|y - X @ w|``
        for each weight vector ``w``.
    """
    W = np.asarray(W)
    grid_shape = W.shape[:-1]
    targets = np.reshape(y, (1, -1))
    # One row of predictions per weight vector.
    preds = W.reshape(-1, W.shape[-1]) @ np.asarray(X).T
    # Restore the grid's own shape; the grid is not guaranteed to be square.
    return np.mean(np.abs(targets - preds), axis=1).reshape(grid_shape)
def l2_loss(W, y, X):
    """Mean squared error of a linear model for every weight vector in a grid.

    Args:
        W: Weight grid of shape ``(..., n_features)``; the last axis holds one
            weight vector.
        y: Targets, one value per sample (flattened internally).
        X: Design matrix of shape ``(n_samples, n_features)``.

    Returns:
        Array of shape ``W.shape[:-1]`` holding the mean of ``(y - X @ w) ** 2``
        for each weight vector ``w``.
    """
    W = np.asarray(W)
    grid_shape = W.shape[:-1]
    targets = np.reshape(y, (1, -1))
    # One row of predictions per weight vector.
    preds = W.reshape(-1, W.shape[-1]) @ np.asarray(X).T
    # Restore the grid's own shape; the grid is not guaranteed to be square.
    return np.mean((targets - preds) ** 2, axis=1).reshape(grid_shape)
def l2_loss_regularization_path(y, X, regularization_type):
    """Compute the coefficient path of a squared-loss linear model.

    Args:
        y: Targets passed to ``ElasticNet.path``.
        X: Design matrix passed to ``ElasticNet.path``.
        regularization_type: ``"l1"`` (l1_ratio 1) or ``"l2"`` (l1_ratio 0).

    Returns:
        The coefficient array returned by ``ElasticNet.path``, transposed.
        Callers in this file index it as ``path[:, 0]`` / ``path[:, 1]``.

    Raises:
        ValueError: If ``regularization_type`` is neither ``"l1"`` nor ``"l2"``.
    """
    if regularization_type == "l2":
        l1_ratio = 0
        # An explicit grid (including alpha = 0) is supplied for the pure-L2
        # case. np.concatenate is used because np.concat only exists in
        # NumPy >= 2.0.
        alphas = np.concatenate([np.zeros(1), np.logspace(-2, 2, 100)])
    elif regularization_type == "l1":
        l1_ratio = 1
        alphas = None  # let scikit-learn choose the grid
    else:
        raise ValueError("regularization_type must be 'l1' or 'l2'")
    _, coefs, *_ = ElasticNet.path(X, y, l1_ratio=l1_ratio, alphas=alphas)
    return coefs.T
class Regularization:
    """Gradio front end that plots loss and regularizer contours for a two-weight linear model."""

    # Identifiers accepted for the loss and for the regularizer.
    LOSS_TYPES = ["l1", "l2"]
    REGULARIZER_TYPES = ["l1", "l2"]

    # Dispatch tables from identifier to the module-level implementation.
    LOSS_FUNCTIONS = dict(l1=l1_loss, l2=l2_loss)
    REGULARIZER_FUNCTIONS = dict(l1=l1_norm, l2=l2_norm)

    # File the contour figure is written to so the export tab can offer it.
    FIGURE_NAME = "loss_and_regularization_plot.svg"
def __init__(self, width, height):
    """Store the canvas dimensions and the CSS handed to ``gr.Blocks``.

    Args:
        width: Value stored as ``canvas_width``.
        height: Value stored as ``canvas_height``.
    """
    self.canvas_width, self.canvas_height = width, height
    # CSS class applied to the helper download button of the export tab.
    self.css ="""
.hidden-button {
display: none;
}
"""
def compute_and_plot_loss_and_reg(
    self,
    dataset: Dataset,
    loss_type: str,
    reg_type: str,
    reg_levels: list,
    w1_range: list,
    w2_range: list,
    num_dots: int,
    plot_path: bool,
):
    """Evaluate loss and regularizer on a weight grid and render the contour plot.

    Args:
        dataset: Object exposing the design matrix ``X`` and targets ``y``.
        loss_type: Key into ``LOSS_FUNCTIONS``.
        reg_type: Key into ``REGULARIZER_FUNCTIONS``.
        reg_levels: Regularizer values at which contours are drawn. They are
            sorted and de-duplicated here because contour levels must be
            strictly increasing.
        w1_range: ``[low, high]`` bounds of the first weight.
        w2_range: ``[low, high]`` bounds of the second weight.
        num_dots: Number of grid points per axis before 0 is inserted.
        plot_path: Whether to overlay the regularization path (only computed
            for the ``"l2"`` loss).

    Returns:
        Whatever ``plot_loss_and_reg`` returns for the computed quantities.
    """
    X = dataset.X
    y = dataset.y

    # Unsorted or repeated levels would make the contour call fail.
    reg_levels = sorted(set(reg_levels))

    W1, W2 = self._build_parameter_grid(w1_range, w2_range, num_dots)
    losses = self._compute_losses(X, y, loss_type, W1, W2)
    reg_values = self._compute_reg_values(W1, W2, reg_type)

    # For each regularizer level: the smallest grid loss inside that region.
    # Reversed because a larger region yields a smaller (or equal) loss.
    loss_levels = [
        min_corresponding_entries(reg_values, losses, reg_level)
        for reg_level in reg_levels
    ]
    loss_levels.reverse()

    gram = X.T @ X
    try:
        unregularized_w = np.linalg.solve(gram, X.T @ y)
    except np.linalg.LinAlgError:
        # Singular system: the least-squares solutions form a line through one
        # particular solution, along the eigenvector of the smallest eigenvalue.
        eig_vals, eig_vectors = np.linalg.eigh(gram)
        line_direction = eig_vectors[:, np.argmin(eig_vals)]
        candidate_w = np.linalg.lstsq(X, y, rcond=None)[0]
        if np.isclose(line_direction[0], 0.0):
            # Vertical line: w1 is fixed and w2 is free. Computing a slope here
            # would divide by zero and leave nothing to draw.
            unregularized_w2 = np.linspace(w2_range[0], w2_range[1], num_dots)
            unregularized_w1 = np.full_like(unregularized_w2, candidate_w[0])
            mask = (unregularized_w1 <= w1_range[1]) & (unregularized_w1 >= w1_range[0])
        else:
            m = line_direction[1] / line_direction[0]
            b = candidate_w[1] - m * candidate_w[0]
            unregularized_w1 = np.linspace(w1_range[0], w1_range[1], num_dots)
            unregularized_w2 = m * unregularized_w1 + b
            mask = (unregularized_w2 <= w2_range[1]) & (unregularized_w2 >= w2_range[0])
        # Keep only the part of the line that lies inside the plotted window.
        unregularized_w = np.stack((unregularized_w1, unregularized_w2), axis=-1)[mask]

    # A path is only available for the squared loss.
    path_w = None
    if plot_path and loss_type == "l2":
        path_w = l2_loss_regularization_path(y, X, regularization_type=reg_type)

    return self.plot_loss_and_reg(
        W1,
        W2,
        losses,
        reg_values,
        loss_levels,
        reg_levels,
        unregularized_w,
        path_w,
    )
def plot_loss_and_reg(
    self,
    W1: np.ndarray,
    W2: np.ndarray,
    losses: np.ndarray,
    reg_values: np.ndarray,
    loss_levels: list,
    reg_levels: list,
    unregularized_w: np.ndarray,
    path_w: np.ndarray | None,
):
    """Draw loss and regularizer contours and return the rendered image.

    Args:
        W1: Grid of first-weight values.
        W2: Grid of second-weight values.
        losses: Loss evaluated on the grid.
        reg_values: Regularizer evaluated on the grid.
        loss_levels: Loss contour levels, paired with ``reg_levels`` in
            reverse order.
        reg_levels: Regularizer contour levels.
        unregularized_w: One solution (1-D, drawn as a marker) or a set of
            solutions (2-D, drawn as a line).
        path_w: Optional regularization path with one weight pair per row.

    Returns:
        A PIL image of the figure. The figure is also written to
        ``FIGURE_NAME`` as a side effect.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        ax.set_title("")
        ax.set_xlabel("w1")
        ax.set_ylabel("w2")

        cmap = plt.get_cmap("viridis")
        # linspace also covers a single level, where i / (N - 1) would divide
        # by zero.
        colors = [cmap(t) for t in np.linspace(0.0, 1.0, len(reg_levels))]

        # regularizer contours
        cs1 = ax.contour(
            W1, W2, reg_values, levels=reg_levels, colors=colors, linestyles="dashed"
        )
        ax.clabel(cs1, inline=True, fontsize=8)  # show contour levels

        # loss contours: each level shares its colour with the matching
        # regularizer level. Levels that do not strictly increase are skipped,
        # because contouring rejects them (this happens when several regions
        # contain the same minimum).
        kept_levels, kept_colors = [], []
        for level, color in zip(loss_levels, colors[::-1]):
            if not kept_levels or level > kept_levels[-1]:
                kept_levels.append(level)
                kept_colors.append(color)
        cs2 = ax.contour(W1, W2, losses, levels=kept_levels, colors=kept_colors)
        ax.clabel(cs2, inline=True, fontsize=8)

        # unregularized solution: a marker for one point, a line for many
        single_solution = unregularized_w.ndim == 1
        if single_solution:
            ax.plot(unregularized_w[0], unregularized_w[1], "bx", markersize=5, label="unregularized solution")
        else:
            ax.plot(unregularized_w[:, 0], unregularized_w[:, 1], "b-", label="unregularized solution")

        # regularization path
        if path_w is not None:
            ax.plot(path_w[:, 0], path_w[:, 1], "r-")

        # legend built from proxy artists so the entries are colour-independent
        handles = [
            mlines.Line2D([], [], color='black', linestyle='-', label='loss'),
            mlines.Line2D([], [], color='black', linestyle='--', label='regularization'),
        ]
        if path_w is not None:
            handles.append(
                mlines.Line2D([], [], color='red', linestyle='-', label='regularization path')
            )
        if single_solution:
            handles.append(
                mlines.Line2D([], [], color='blue', marker='x', linestyle='None', label='unregularized solution')
            )
        else:
            handles.append(
                mlines.Line2D([], [], color='blue', linestyle='-', label='unregularized solution')
            )
        ax.legend(handles=handles)
        ax.grid(True)

        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
        # Write the export file while the figure is still open.
        fig.savefig(f"{self.FIGURE_NAME}")
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
def plot_data(self, dataset: Dataset):
    """Build a 3D Plotly figure with the dataset's function surface and its samples.

    Args:
        dataset: Object providing ``get_function(nsample=...)``, ``X`` and ``y``.

    Returns:
        The ``go.Figure`` holding a surface trace and a scatter trace.
    """
    mesh_x1, mesh_x2, y = dataset.get_function(nsample=100)

    fig = go.Figure()

    # surface returned by the dataset
    surface_trace = go.Surface(
        z=y,
        x=mesh_x1,
        y=mesh_x2,
        colorscale="Viridis",
        opacity=0.8,
        name="True function",
    )
    fig.add_trace(surface_trace)

    # the samples themselves
    marker_style = {"size": 3, "color": "red", "opacity": 0.8, "symbol": "circle"}
    points_trace = go.Scatter3d(
        x=dataset.X[:, 0],
        y=dataset.X[:, 1],
        z=dataset.y,
        mode="markers",
        marker=marker_style,
        name="Data Points",
    )
    fig.add_trace(points_trace)

    # identical tick density on all three axes, fixed camera position
    scene = {
        f"{axis}axis": {"title": label, "nticks": 6}
        for axis, label in (("x", "X1"), ("y", "X2"), ("z", "Y"))
    }
    scene["camera"] = {"eye": {"x": -1.5, "y": -1.5, "z": 1.2}}

    fig.update_layout(title="Data", scene=scene, width=800, height=600)
    return fig
def plot_strength_vs_weight(self, dataset: Dataset, loss_type: str, reg_type: str):
    """Plot both weights against the regularization strength.

    Args:
        dataset: Object exposing the design matrix ``X`` and targets ``y``.
        loss_type: Only ``"l2"`` produces a plot.
        reg_type: ``"l1"`` selects l1_ratio 1; anything else selects 0.

    Returns:
        A PIL image of the plot, or a blank white 800x800 image when
        ``loss_type`` is not ``"l2"``.
    """
    if loss_type != "l2":
        # Nothing is computed for other losses; return a blank placeholder
        # before doing any work.
        return Image.new("RGB", (800, 800), color="white")

    X = dataset.X
    y = dataset.y

    # np.concatenate rather than np.concat: the latter only exists in
    # NumPy >= 2.0.
    alphas = np.concatenate([np.zeros(1), np.logspace(-2, 2, 100)])
    l1_ratio = 1 if reg_type == "l1" else 0
    alphas, coefs, *_ = ElasticNet.path(X, y, l1_ratio=l1_ratio, alphas=alphas)
    coefs = coefs.T

    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        ax.plot(alphas, coefs[:, 0], label="w1")
        ax.plot(alphas, coefs[:, 1], label="w2")
        ax.set_xscale("log")
        ax.set_xlabel("Regularization strength (alpha)")
        ax.set_ylabel("Weight value")
        ax.legend()

        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
def update_loss_type(self, loss_type: str):
    """Validate a loss identifier against ``LOSS_TYPES`` and return it unchanged.

    Raises:
        ValueError: If ``loss_type`` is not listed in ``LOSS_TYPES``.
    """
    if loss_type in self.LOSS_TYPES:
        return loss_type
    raise ValueError(f"loss_type must be one of {self.LOSS_TYPES}")
def update_reg_path_visibility(self, loss_type: str):
    """Return a Gradio update that makes the target visible only for the ``"l2"`` loss."""
    return gr.update(visible=(loss_type == "l2"))
def update_regularizer(self, reg_type: str):
    """Validate a regularizer identifier against ``REGULARIZER_TYPES`` and return it unchanged.

    Raises:
        ValueError: If ``reg_type`` is not listed in ``REGULARIZER_TYPES``.
    """
    if reg_type in self.REGULARIZER_TYPES:
        return reg_type
    raise ValueError(f"reg_type must be one of {self.REGULARIZER_TYPES}")
def update_reg_levels(self, reg_levels_input: str):
    """Parse a comma-separated string of regularizer levels into a list of floats."""
    return list(map(float, reg_levels_input.split(",")))
def update_w1_range(self, w1_range_input: str):
    """Parse the comma-separated ``w1`` bounds into a list of floats."""
    return list(map(float, w1_range_input.split(",")))
def update_w2_range(self, w2_range_input: str):
    """Parse the comma-separated ``w2`` bounds into a list of floats."""
    return list(map(float, w2_range_input.split(",")))
def update_resolution(self, num_dots: int):
    """Pass the slider value through unchanged so it can be stored in the resolution state."""
    return num_dots
def update_plot_path(self, plot_path: bool):
    """Pass the checkbox value through unchanged so it can be stored in the path state."""
    return plot_path
def _build_parameter_grid(
    self,
    w1_range: list,
    w2_range: list,
    num_dots: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Build the 2-D weight grid, making sure 0 is a tick on both axes.

    Args:
        w1_range: ``[low, high]`` bounds of the first weight.
        w2_range: ``[low, high]`` bounds of the second weight.
        num_dots: Number of evenly spaced ticks per axis before 0 is inserted.

    Returns:
        The ``(W1, W2)`` pair produced by ``np.meshgrid``.
    """

    def ticks_with_zero(bounds):
        # Evenly spaced ticks; 0 is spliced in at its sorted position if absent.
        ticks = np.linspace(bounds[0], bounds[1], num_dots)
        if not np.any(ticks == 0):
            ticks = np.insert(ticks, np.searchsorted(ticks, 0), 0)
        return ticks

    w1_ticks = ticks_with_zero(w1_range)
    w2_ticks = ticks_with_zero(w2_range)
    grid_w1, grid_w2 = np.meshgrid(w1_ticks, w2_ticks)
    return grid_w1, grid_w2
def _compute_losses(
    self,
    X: np.ndarray,
    y: np.ndarray,
    loss_type: str,
    W1: np.ndarray,
    W2: np.ndarray,
) -> np.ndarray:
    """Evaluate the selected loss at every point of the weight grid.

    ``W1`` and ``W2`` are stacked along a new last axis and handed to the
    function registered under ``loss_type`` in ``LOSS_FUNCTIONS`` together
    with ``y`` and ``X``.
    """
    weight_grid = np.stack((W1, W2), axis=-1)
    loss_fn = self.LOSS_FUNCTIONS[loss_type]
    return loss_fn(weight_grid, y, X)
def _compute_reg_values(
    self,
    W1: np.ndarray,
    W2: np.ndarray,
    reg_type: str,
) -> np.ndarray:
    """Evaluate the selected regularizer at every point of the weight grid.

    ``W1`` and ``W2`` are stacked along a new last axis and handed to the
    function registered under ``reg_type`` in ``REGULARIZER_FUNCTIONS``.
    """
    weight_grid = np.stack((W1, W2), axis=-1)
    regularizer = self.REGULARIZER_FUNCTIONS[reg_type]
    return regularizer(weight_grid)
def launch(self):
# Build the whole Gradio UI (state holders, layout, event wiring) and start
# the app with demo.launch().
# NOTE(review): leading indentation is missing in this copy of the file, so
# the nesting of the `with` layout blocks below cannot be read off the text;
# confirm it against the original before restructuring anything.
# build the Gradio interface
with gr.Blocks(css=self.css) as demo:
# app title
gr.HTML("<div style='text-align:left; font-size:40px; font-weight: bold;'>Regularization visualizer</div>")
# states: one gr.State per setting; the event handlers below write the
# parsed widget values into them and the plot functions read from them
dataset = gr.State(Dataset())
loss_type = gr.State("l2")
reg_type = gr.State("l2")
reg_levels = gr.State([10, 20, 30])
w1_range = gr.State([-100, 100])
w2_range = gr.State([-100, 100])
num_dots = gr.State(500)
plot_regularization_path = gr.State(False)
# GUI elements and layout
with gr.Row():
with gr.Column(scale=2):
with gr.Tab("Loss and Regularization"):
# the initial image is computed right here from the state defaults
self.loss_and_regularization_plot = gr.Image(
value=self.compute_and_plot_loss_and_reg(
dataset.value,
loss_type.value,
reg_type.value,
reg_levels.value,
w1_range.value,
w2_range.value,
num_dots.value,
plot_regularization_path.value,
),
container=True,
)
with gr.Tab("Data"):
self.data_3d_plot = gr.Plot(
value=self.plot_data(dataset.value), container=True
)
with gr.Tab("Strength vs weight"):
self.strength_vs_weight = gr.Image(
value=self.plot_strength_vs_weight(
dataset.value, loss_type.value, reg_type.value
),
container=True,
)
with gr.Column(scale=1):
with gr.Tab("Settings"):
with gr.Row():
# read-only display of the model being fitted
model_textbox = gr.Textbox(
label="Model",
value="y = w1 * x1 + w2 * x2",
interactive=False,
)
with gr.Row():
loss_type_selection = gr.Dropdown(
choices=['l1', 'l2'],
label='Loss type',
value='l2',
visible=True,
)
with gr.Group():
with gr.Row():
regularizer_type_selection = gr.Dropdown(
choices=['l1', 'l2'],
label='Regularizer type',
value='l2',
visible=True,
)
reg_textbox = gr.Textbox(
label="Regularizer levels",
value="10, 20, 30",
interactive=True,
)
with gr.Row():
w1_textbox = gr.Textbox(
label="w1 range",
value="-100, 100",
interactive=True,
)
w2_textbox = gr.Textbox(
label="w2 range",
value="-100, 100",
interactive=True,
)
with gr.Row():
resolution_slider = gr.Slider(
minimum=100,
maximum=1000,
value=500,
step=1,
label="Resolution (#points)",
)
submit_button = gr.Button("Submit changes")
with gr.Row():
path_checkbox = gr.Checkbox(label="Show regularization path", value=False)
with gr.Tab("Data"):
dataset_view = DatasetView()
dataset_view.build(state=dataset)
# when the dataset state changes, refresh the contour image, then the
# 3D data plot, then the strength-vs-weight image, in that order
dataset.change(
fn=self.compute_and_plot_loss_and_reg,
inputs=[
dataset,
loss_type,
reg_type,
reg_levels,
w1_range,
w2_range,
num_dots,
plot_regularization_path,
],
outputs=self.loss_and_regularization_plot,
).then(
fn=self.plot_data,
inputs=[dataset],
outputs=self.data_3d_plot,
).then(
fn=self.plot_strength_vs_weight,
inputs=[
dataset,
loss_type,
reg_type,
],
outputs=self.strength_vs_weight,
)
with gr.Tab("Export"):
# use hidden download button to generate files on the fly
# https://github.com/gradio-app/gradio/issues/9230#issuecomment-2323771634
with gr.Row():
btn_export_plot_loss_reg = gr.Button("Loss and Regularization Plot")
btn_export_plot_loss_reg_hidden = gr.DownloadButton(
label="You should not see this",
elem_id="btn_export_plot_loss_reg_hidden",
elem_classes="hidden-button"
)
with gr.Tab("Usage"):
# NOTE(review): the file object returned by open() is never closed
# explicitly; a `with` block would release it deterministically
gr.Markdown(''.join(open('usage.md', 'r').readlines()))
# event handlers for GUI elements
# settings
# loss type: store the choice, toggle the path checkbox, redraw both images
loss_type_selection.change(
fn=self.update_loss_type,
inputs=[loss_type_selection],
outputs=[loss_type],
).then(
fn=self.update_reg_path_visibility,
inputs=[loss_type_selection],
outputs=[path_checkbox],
).then(
fn=self.compute_and_plot_loss_and_reg,
inputs=[
dataset,
loss_type,
reg_type,
reg_levels,
w1_range,
w2_range,
num_dots,
plot_regularization_path,
],
outputs=self.loss_and_regularization_plot,
).then(
fn=self.plot_strength_vs_weight,
inputs=[
dataset,
loss_type,
reg_type,
],
outputs=self.strength_vs_weight,
)
# regularizer type: store the choice, then redraw both images
regularizer_type_selection.change(
fn=self.update_regularizer,
inputs=[regularizer_type_selection],
outputs=[reg_type],
).then(
fn=self.compute_and_plot_loss_and_reg,
inputs=[
dataset,
loss_type,
reg_type,
reg_levels,
w1_range,
w2_range,
num_dots,
plot_regularization_path,
],
outputs=self.loss_and_regularization_plot,
).then(
fn=self.plot_strength_vs_weight,
inputs=[
dataset,
loss_type,
reg_type,
],
outputs=self.strength_vs_weight,
)
# regularizer levels (on submit in the textbox): parse, then redraw both images
reg_textbox.submit(
self.update_reg_levels,
inputs=[reg_textbox],
outputs=[reg_levels],
).then(
fn=self.compute_and_plot_loss_and_reg,
inputs=[
dataset,
loss_type,
reg_type,
reg_levels,
w1_range,
w2_range,
num_dots,
plot_regularization_path,
],
outputs=self.loss_and_regularization_plot,
).then(
fn=self.plot_strength_vs_weight,
inputs=[
dataset,
loss_type,
reg_type,
],
outputs=self.strength_vs_weight,
)
# w1 range (on submit): parse, then redraw only the contour image
w1_textbox.submit(
self.update_w1_range,
inputs=[w1_textbox],
outputs=[w1_range],
).then(
fn=self.compute_and_plot_loss_and_reg,
inputs=[
dataset,
loss_type,
reg_type,
reg_levels,
w1_range,
w2_range,
num_dots,
plot_regularization_path,
],
outputs=self.loss_and_regularization_plot,
)
# w2 range (on submit): parse, then redraw only the contour image
w2_textbox.submit(
self.update_w2_range,
inputs=[w2_textbox],
outputs=[w2_range],
).then(
fn=self.compute_and_plot_loss_and_reg,
inputs=[
dataset,
loss_type,
reg_type,
reg_levels,
w1_range,
w2_range,
num_dots,
plot_regularization_path,
],
outputs=self.loss_and_regularization_plot,
)
# submit button: re-parse all three textboxes, then redraw the contour image
submit_button.click(
self.update_w1_range,
inputs=[w1_textbox],
outputs=[w1_range],
).then(
self.update_w2_range,
inputs=[w2_textbox],
outputs=[w2_range],
).then(
self.update_reg_levels,
inputs=[reg_textbox],
outputs=[reg_levels],
).then(
fn=self.compute_and_plot_loss_and_reg,
inputs=[
dataset,
loss_type,
reg_type,
reg_levels,
w1_range,
w2_range,
num_dots,
plot_regularization_path,
],
outputs=self.loss_and_regularization_plot,
)
# resolution slider: store the value, then redraw the contour image
resolution_slider.change(
self.update_resolution,
inputs=[resolution_slider],
outputs=[num_dots],
).then(
fn=self.compute_and_plot_loss_and_reg,
inputs=[
dataset,
loss_type,
reg_type,
reg_levels,
w1_range,
w2_range,
num_dots,
plot_regularization_path,
],
outputs=self.loss_and_regularization_plot,
)
# path checkbox: store the flag, then redraw the contour image
path_checkbox.change(
self.update_plot_path,
inputs=[path_checkbox],
outputs=[plot_regularization_path],
).then(
fn=self.compute_and_plot_loss_and_reg,
inputs=[
dataset,
loss_type,
reg_type,
reg_levels,
w1_range,
w2_range,
num_dots,
plot_regularization_path,
],
outputs=self.loss_and_regularization_plot,
)
# export: the visible button hands FIGURE_NAME to the hidden download
# button, then the JS snippet clicks that hidden button by its element id
btn_export_plot_loss_reg.click(
fn=lambda: self.FIGURE_NAME,
inputs=None,
outputs=[btn_export_plot_loss_reg_hidden],
).then(
fn=None,
inputs=None,
outputs=None,
js="() => document.querySelector('#btn_export_plot_loss_reg_hidden').click()"
)
# start the Gradio app
demo.launch()
# Build the visualizer and start the app as soon as this module runs.
# NOTE(review): there is no `if __name__ == "__main__":` guard, so importing
# this module also launches the app. Confirm that is intended.
visualizer = Regularization(width=1200, height=900)
visualizer.launch()