import io

import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import numpy as np
from PIL import Image
import plotly.graph_objects as go
from sklearn.datasets import make_regression  # currently unused in this file
from sklearn.linear_model import ElasticNet
import logging

logging.basicConfig(
    level=logging.INFO,  # set minimum level to capture (DEBUG, INFO, WARNING, ERROR, CRITICAL)
    format="%(asctime)s [%(levelname)s] %(message)s",  # log format
)
logger = logging.getLogger("ELVIS")

# Imported after the logging setup, as in the original layout, so that any
# logging the project module does at import time uses the configuration above.
from dataset import Dataset, DatasetView


def min_corresponding_entries(W1, W2, w1, tol=0.1):
    """Return the smallest entry of ``W2`` over the positions where ``W1 <= w1``.

    In this app ``W1`` holds regularizer values and ``W2`` losses on the same
    parameter grid, so the result is the lowest loss reachable inside the
    regularizer sub-level set ``{w : reg(w) <= w1}``.

    Args:
        W1: Array used to build the mask.
        W2: Array with the same shape as ``W1``; its masked minimum is returned.
        w1: Threshold applied to ``W1``.
        tol: Unused; kept only for backward compatibility.

    Raises:
        ValueError: If no entry of ``W1`` is less than or equal to ``w1``.
    """
    values = W2[W1 <= w1]
    if values.size == 0:
        raise ValueError("No entries in W1 less than equal to w1")
    return np.min(values)


def l1_norm(W):
    """Return the L1 norm of ``W`` along its last axis."""
    return np.sum(np.abs(W), axis=-1)


def l2_norm(W):
    """Return the Euclidean (L2) norm of ``W`` along its last axis."""
    return np.linalg.norm(W, axis=-1)


def _predict_on_grid(W, X):
    """Return linear-model predictions for every parameter vector in ``W``.

    ``W`` has shape ``(..., n_features)`` and ``X`` shape ``(n_samples, n_features)``;
    the result has shape ``(prod(W.shape[:-1]), n_samples)``.
    """
    return W.reshape(-1, W.shape[-1]) @ X.T


def l1_loss(W, y, X):
    """Mean absolute error of every parameter vector in the grid ``W``.

    Returns an array of shape ``W.shape[:-1]``. The grid need not be square:
    the parameter grid builder may add a zero to only one of the axes.
    """
    preds = _predict_on_grid(W, X)
    return np.mean(np.abs(y.reshape(1, -1) - preds), axis=1).reshape(W.shape[:-1])


def l2_loss(W, y, X):
    """Mean squared error of every parameter vector in the grid ``W``.

    Returns an array of shape ``W.shape[:-1]`` (the grid need not be square).
    """
    preds = _predict_on_grid(W, X)
    return np.mean((y.reshape(1, -1) - preds) ** 2, axis=1).reshape(W.shape[:-1])


def l2_loss_regularization_path(y, X, regularization_type):
    """Compute the regularization path of a squared-loss linear model.

    Args:
        y: Targets.
        X: Design matrix, one column per weight.
        regularization_type: ``"l1"`` (lasso) or ``"l2"`` (ridge-style penalty).

    Returns:
        Array of shape ``(n_alphas, n_features)``: the weights for each alpha
        along the path, in the order returned by ``ElasticNet.path``.

    Raises:
        ValueError: If ``regularization_type`` is neither ``"l1"`` nor ``"l2"``.
    """
    if regularization_type == "l2":
        l1_ratio = 0
        # sklearn cannot build its automatic alpha grid for l1_ratio=0, so an
        # explicit grid (including the unregularized alpha=0) is supplied.
        alphas = np.concatenate([np.zeros(1), np.logspace(-2, 2, 100)])
    elif regularization_type == "l1":
        l1_ratio = 1
        alphas = None
    else:
        raise ValueError("regularization_type must be 'l1' or 'l2'")

    _, coefs, *_ = ElasticNet.path(X, y, l1_ratio=l1_ratio, alphas=alphas)
    return coefs.T


class Regularization:
    """Gradio app that visualizes loss and regularizer contours of a 2-weight linear model."""

    LOSS_TYPES = ['l1', 'l2']
    REGULARIZER_TYPES = ['l1', 'l2']
    LOSS_FUNCTIONS = {
        'l1': l1_loss,
        'l2': l2_loss,
    }
    REGULARIZER_FUNCTIONS = {
        'l1': l1_norm,
        'l2': l2_norm,
    }
    # Path of the SVG copy of the contour plot served by the Export tab.
    FIGURE_NAME = "loss_and_regularization_plot.svg"

    def __init__(self, width, height):
        # Stored for callers; not read elsewhere in this file.
        self.canvas_width = width
        self.canvas_height = height
        # Hides the helper DownloadButton used by the Export tab.
        self.css = """
        .hidden-button {
            display: none;
        }
        """

    # ------------------------------------------------------------------ #
    # Computation + plotting
    # ------------------------------------------------------------------ #
    def compute_and_plot_loss_and_reg(
        self,
        dataset: Dataset,
        loss_type: str,
        reg_type: str,
        reg_levels: list,
        w1_range: list,
        w2_range: list,
        num_dots: int,
        plot_path: bool,
    ):
        """Compute loss/regularizer surfaces on a weight grid and plot their contours.

        Args:
            dataset: Provides ``X`` (two feature columns) and targets ``y``.
            loss_type: Key of ``LOSS_FUNCTIONS`` (``"l1"`` or ``"l2"``).
            reg_type: Key of ``REGULARIZER_FUNCTIONS`` (``"l1"`` or ``"l2"``).
            reg_levels: Regularizer values at which to draw contours.
            w1_range: ``[min, max]`` of the w1 axis.
            w2_range: ``[min, max]`` of the w2 axis.
            num_dots: Number of grid points per axis (before adding 0).
            plot_path: Whether to draw the regularization path (squared loss only).

        Returns:
            A PIL image of the contour plot.
        """
        X = dataset.X
        y = dataset.y
        # contour() requires strictly increasing levels; user input may be
        # unordered or contain repeats.
        reg_levels = sorted(set(reg_levels))

        W1, W2 = self._build_parameter_grid(w1_range, w2_range, num_dots)
        losses = self._compute_losses(X, y, loss_type, W1, W2)
        reg_values = self._compute_reg_values(W1, W2, reg_type)

        # For each regularizer budget, the lowest loss reachable inside it;
        # the loss contour at that value meets the matching regularizer contour.
        loss_levels = [
            min_corresponding_entries(reg_values, losses, level)
            for level in reg_levels
        ]
        # Larger budgets give lower-or-equal minimal losses, so reversing
        # puts the loss levels in the non-decreasing order contour() needs.
        loss_levels.reverse()

        unregularized_w = self._unregularized_solution(X, y, w1_range, w2_range, num_dots)

        path_w = None
        if plot_path and loss_type == "l2":
            path_w = l2_loss_regularization_path(y, X, regularization_type=reg_type)
        # TODO: no path is computed for L1 loss. A rough option is a grid
        # search: for increasing regularizer levels, take the grid point with the
        # lowest loss among points whose regularizer value is within the level.

        return self.plot_loss_and_reg(
            W1,
            W2,
            losses,
            reg_values,
            loss_levels,
            reg_levels,
            unregularized_w,
            path_w,
        )

    def _unregularized_solution(self, X, y, w1_range, w2_range, num_dots):
        """Return the unregularized least-squares solution(s).

        Returns a ``(2,)`` array when ``X.T @ X`` is invertible. Otherwise the
        minimizers form a line ``w_lstsq + t * v`` (``v`` spanning the null space
        of ``X.T @ X``), returned as a ``(k, 2)`` array of points on that line
        inside the plotted window.
        """
        gram = X.T @ X
        try:
            return np.linalg.solve(gram, X.T @ y)
        except np.linalg.LinAlgError:
            logger.info("X.T @ X is singular; the unregularized solutions form a line")

        eig_vals, eig_vectors = np.linalg.eigh(gram)
        direction = eig_vectors[:, np.argmin(eig_vals)]
        anchor = np.linalg.lstsq(X, y, rcond=None)[0]

        if np.isclose(direction[0], 0.0):
            # Vertical solution line: w1 is fixed and w2 is free.
            w2 = np.linspace(w2_range[0], w2_range[1], num_dots)
            w1 = np.full_like(w2, anchor[0])
            return np.stack((w1, w2), axis=-1)

        slope = direction[1] / direction[0]
        intercept = anchor[1] - slope * anchor[0]
        w1 = np.linspace(w1_range[0], w1_range[1], num_dots)
        w2 = slope * w1 + intercept
        line = np.stack((w1, w2), axis=-1)
        inside = (w2 <= w2_range[1]) & (w2 >= w2_range[0])
        return line[inside]

    def plot_loss_and_reg(
        self,
        W1: np.ndarray,
        W2: np.ndarray,
        losses: np.ndarray,
        reg_values: np.ndarray,
        loss_levels: list,
        reg_levels: list,
        unregularized_w: np.ndarray,
        path_w: np.ndarray | None,
    ):
        """Draw regularizer (dashed) and loss (solid) contours in weight space.

        Each loss level is paired by color with the regularizer level it was
        derived from (``loss_levels`` is expected in reversed order relative to
        ``reg_levels``). Also draws the unregularized solution (a point or a
        line) and, if given, the regularization path. An SVG copy is written to
        ``FIGURE_NAME`` for the Export tab.

        Returns:
            A PIL image of the rendered plot.
        """
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.set_title("")
        ax.set_xlabel("w1")
        ax.set_ylabel("w2")

        # One color per regularizer level; linspace also handles a single level.
        cmap = plt.get_cmap("viridis")
        colors = [cmap(t) for t in np.linspace(0.0, 1.0, len(reg_levels))]

        # regularizer contours
        cs1 = ax.contour(W1, W2, reg_values, levels=reg_levels, colors=colors, linestyles="dashed")
        ax.clabel(cs1, inline=True, fontsize=8)  # show contour levels

        # loss contours: contour() rejects repeated levels, which occur when
        # several budgets contain the grid's loss minimizer; keep the first one.
        loss_colors = colors[::-1]
        unique_levels, first_idx = np.unique(loss_levels, return_index=True)
        cs2 = ax.contour(
            W1,
            W2,
            losses,
            levels=unique_levels,
            colors=[loss_colors[i] for i in first_idx],
        )
        ax.clabel(cs2, inline=True, fontsize=8)

        # unregularized solution: a single point, or a line of minimizers
        if unregularized_w.ndim == 1:
            ax.plot(unregularized_w[0], unregularized_w[1], "bx", markersize=5, label="unregularized solution")
        else:
            ax.plot(unregularized_w[:, 0], unregularized_w[:, 1], "b-", label="unregularized solution")

        # regularization path
        if path_w is not None:
            ax.plot(path_w[:, 0], path_w[:, 1], "r-")

        # legend
        loss_line = mlines.Line2D([], [], color='black', linestyle='-', label='loss')
        reg_line = mlines.Line2D([], [], color='black', linestyle='--', label='regularization')
        handles = [loss_line, reg_line]
        if path_w is not None:
            handles.append(
                mlines.Line2D([], [], color='red', linestyle='-', label='regularization path')
            )
        if unregularized_w.ndim == 1:
            handles.append(
                mlines.Line2D([], [], color='blue', marker='x', linestyle='None', label='unregularized solution')
            )
        else:
            handles.append(
                mlines.Line2D([], [], color='blue', linestyle='-', label='unregularized solution')
            )
        ax.legend(handles=handles)
        ax.grid(True)

        # Save the SVG before the figure is closed by _figure_to_image.
        fig.savefig(self.FIGURE_NAME)
        return self._figure_to_image(fig)

    @staticmethod
    def _figure_to_image(fig):
        """Render ``fig`` to an in-memory PNG, close the figure, and return it as a PIL image."""
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
        plt.close(fig)
        buf.seek(0)
        return Image.open(buf)

    def plot_data(self, dataset: Dataset):
        """Return a Plotly 3D figure with the dataset's function surface and its sample points."""
        # NOTE(review): assumes get_function returns (x1 mesh, x2 mesh, y values)
        # suitable for go.Surface — confirm against the dataset module.
        mesh_x1, mesh_x2, y = dataset.get_function(nsample=100)

        fig = go.Figure()
        fig.add_trace(
            go.Surface(
                z=y,
                x=mesh_x1,
                y=mesh_x2,
                colorscale='Viridis',
                opacity=0.8,
                name='True function',
            )
        )
        fig.add_trace(
            go.Scatter3d(
                x=dataset.X[:, 0],
                y=dataset.X[:, 1],
                z=dataset.y,
                mode='markers',
                marker=dict(
                    size=3,
                    color='red',
                    opacity=0.8,
                    symbol='circle',
                ),
                name='Data Points',
            )
        )
        fig.update_layout(
            title="Data",
            scene={
                "xaxis": {"title": "X1", "nticks": 6},
                "yaxis": {"title": "X2", "nticks": 6},
                "zaxis": {"title": "Y", "nticks": 6},
                "camera": {"eye": {"x": -1.5, "y": -1.5, "z": 1.2}},
            },
            width=800,
            height=600,
        )
        return fig

    def plot_strength_vs_weight(self, dataset: Dataset, loss_type: str, reg_type: str):
        """Plot each weight as a function of regularization strength alpha.

        Only squared loss is supported; for any other loss a blank 800x800
        white image is returned.
        """
        if loss_type != "l2":
            return Image.new("RGB", (800, 800), color="white")

        alphas = np.concatenate([np.zeros(1), np.logspace(-2, 2, 100)])
        l1_ratio = 1 if reg_type == "l1" else 0
        alphas, coefs, *_ = ElasticNet.path(dataset.X, dataset.y, l1_ratio=l1_ratio, alphas=alphas)
        coefs = coefs.T  # (n_alphas, n_features)

        fig, ax = plt.subplots(figsize=(8, 8))
        ax.plot(alphas, coefs[:, 0], label="w1")
        ax.plot(alphas, coefs[:, 1], label="w2")
        ax.set_xscale("log")
        ax.set_xlabel("Regularization strength (alpha)")
        ax.set_ylabel("Weight value")
        ax.legend()
        return self._figure_to_image(fig)

    # ------------------------------------------------------------------ #
    # Input parsing / state updates
    # ------------------------------------------------------------------ #
    def update_loss_type(self, loss_type: str):
        """Validate and return the selected loss type."""
        if loss_type not in self.LOSS_TYPES:
            raise ValueError(f"loss_type must be one of {self.LOSS_TYPES}")
        return loss_type

    def update_reg_path_visibility(self, loss_type: str):
        """Show the regularization-path checkbox only for squared loss."""
        return gr.update(visible=loss_type == "l2")

    def update_regularizer(self, reg_type: str):
        """Validate and return the selected regularizer type."""
        if reg_type not in self.REGULARIZER_TYPES:
            raise ValueError(f"reg_type must be one of {self.REGULARIZER_TYPES}")
        return reg_type

    def update_reg_levels(self, reg_levels_input: str):
        """Parse a comma-separated string into a list of float regularizer levels."""
        return [float(reg_level) for reg_level in reg_levels_input.split(",")]

    def update_w1_range(self, w1_range_input: str):
        """Parse a ``"min, max"`` string into the w1 axis range."""
        return [float(w1) for w1 in w1_range_input.split(",")]

    def update_w2_range(self, w2_range_input: str):
        """Parse a ``"min, max"`` string into the w2 axis range."""
        return [float(w2) for w2 in w2_range_input.split(",")]

    def update_resolution(self, num_dots: int):
        """Return the grid resolution unchanged (slider value -> state)."""
        return num_dots

    def update_plot_path(self, plot_path: bool):
        """Return the checkbox value unchanged (checkbox -> state)."""
        return plot_path

    # ------------------------------------------------------------------ #
    # Grid helpers
    # ------------------------------------------------------------------ #
    def _build_parameter_grid(
        self,
        w1_range: list,
        w2_range: list,
        num_dots: int,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Build a meshgrid over weight space that always contains the origin.

        Including (0, 0) guarantees that every non-negative regularizer level
        has at least one grid point inside it. Because 0 is inserted into each
        axis independently, the grid may have one more point on one axis than
        on the other.
        """
        w1 = np.linspace(w1_range[0], w1_range[1], num_dots)
        w2 = np.linspace(w2_range[0], w2_range[1], num_dots)

        if 0 not in w1:
            w1 = np.insert(w1, np.searchsorted(w1, 0), 0)
        if 0 not in w2:
            w2 = np.insert(w2, np.searchsorted(w2, 0), 0)

        W1, W2 = np.meshgrid(w1, w2)
        return W1, W2

    def _compute_losses(
        self,
        X: np.ndarray,
        y: np.ndarray,
        loss_type: str,
        W1: np.ndarray,
        W2: np.ndarray,
    ) -> np.ndarray:
        """Evaluate the selected loss at every grid point; result has ``W1.shape``."""
        stacked = np.stack((W1, W2), axis=-1)
        return self.LOSS_FUNCTIONS[loss_type](stacked, y, X)

    def _compute_reg_values(
        self,
        W1: np.ndarray,
        W2: np.ndarray,
        reg_type: str,
    ) -> np.ndarray:
        """Evaluate the selected regularizer at every grid point; result has ``W1.shape``."""
        stacked = np.stack((W1, W2), axis=-1)
        return self.REGULARIZER_FUNCTIONS[reg_type](stacked)

    # ------------------------------------------------------------------ #
    # UI
    # ------------------------------------------------------------------ #
    def launch(self):
        """Build the Gradio interface, wire its event handlers, and launch it."""
        with gr.Blocks(css=self.css) as demo:
            # app title
            gr.HTML("\nRegularization visualizer\n")

            # states
            dataset = gr.State(Dataset())
            loss_type = gr.State("l2")
            reg_type = gr.State("l2")
            reg_levels = gr.State([10, 20, 30])
            w1_range = gr.State([-100, 100])
            w2_range = gr.State([-100, 100])
            num_dots = gr.State(500)
            plot_regularization_path = gr.State(False)

            # Inputs shared by the redraw handlers; a fresh copy is passed to
            # each event registration.
            loss_reg_inputs = [
                dataset,
                loss_type,
                reg_type,
                reg_levels,
                w1_range,
                w2_range,
                num_dots,
                plot_regularization_path,
            ]
            strength_inputs = [dataset, loss_type, reg_type]

            # GUI elements and layout
            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Tab("Loss and Regularization"):
                        self.loss_and_regularization_plot = gr.Image(
                            value=self.compute_and_plot_loss_and_reg(
                                dataset.value,
                                loss_type.value,
                                reg_type.value,
                                reg_levels.value,
                                w1_range.value,
                                w2_range.value,
                                num_dots.value,
                                plot_regularization_path.value,
                            ),
                            container=True,
                        )
                    with gr.Tab("Data"):
                        self.data_3d_plot = gr.Plot(
                            value=self.plot_data(dataset.value),
                            container=True,
                        )
                    with gr.Tab("Strength vs weight"):
                        self.strength_vs_weight = gr.Image(
                            value=self.plot_strength_vs_weight(
                                dataset.value, loss_type.value, reg_type.value
                            ),
                            container=True,
                        )

                with gr.Column(scale=1):
                    with gr.Tab("Settings"):
                        with gr.Row():
                            model_textbox = gr.Textbox(
                                label="Model",
                                value="y = w1 * x1 + w2 * x2",
                                interactive=False,
                            )
                        with gr.Row():
                            loss_type_selection = gr.Dropdown(
                                choices=['l1', 'l2'],
                                label='Loss type',
                                value='l2',
                                visible=True,
                            )
                        with gr.Group():
                            with gr.Row():
                                regularizer_type_selection = gr.Dropdown(
                                    choices=['l1', 'l2'],
                                    label='Regularizer type',
                                    value='l2',
                                    visible=True,
                                )
                                reg_textbox = gr.Textbox(
                                    label="Regularizer levels",
                                    value="10, 20, 30",
                                    interactive=True,
                                )
                        with gr.Row():
                            w1_textbox = gr.Textbox(
                                label="w1 range",
                                value="-100, 100",
                                interactive=True,
                            )
                            w2_textbox = gr.Textbox(
                                label="w2 range",
                                value="-100, 100",
                                interactive=True,
                            )
                        with gr.Row():
                            resolution_slider = gr.Slider(
                                minimum=100,
                                maximum=1000,
                                value=500,
                                step=1,
                                label="Resolution (#points)",
                            )
                        submit_button = gr.Button("Submit changes")
                        with gr.Row():
                            path_checkbox = gr.Checkbox(label="Show regularization path", value=False)

                    with gr.Tab("Data"):
                        dataset_view = DatasetView()
                        dataset_view.build(state=dataset)
                        dataset.change(
                            fn=self.compute_and_plot_loss_and_reg,
                            inputs=list(loss_reg_inputs),
                            outputs=self.loss_and_regularization_plot,
                        ).then(
                            fn=self.plot_data,
                            inputs=[dataset],
                            outputs=self.data_3d_plot,
                        ).then(
                            fn=self.plot_strength_vs_weight,
                            inputs=list(strength_inputs),
                            outputs=self.strength_vs_weight,
                        )

                    with gr.Tab("Export"):
                        # use hidden download button to generate files on the fly
                        # https://github.com/gradio-app/gradio/issues/9230#issuecomment-2323771634
                        with gr.Row():
                            btn_export_plot_loss_reg = gr.Button("Loss and Regularization Plot")
                            btn_export_plot_loss_reg_hidden = gr.DownloadButton(
                                label="You should not see this",
                                elem_id="btn_export_plot_loss_reg_hidden",
                                elem_classes="hidden-button",
                            )

                    with gr.Tab("Usage"):
                        # Close the file deterministically instead of leaking the handle.
                        with open('usage.md', 'r', encoding="utf-8") as usage_file:
                            usage_text = usage_file.read()
                        gr.Markdown(usage_text)

            # event handlers for GUI elements
            # settings
            loss_type_selection.change(
                fn=self.update_loss_type,
                inputs=[loss_type_selection],
                outputs=[loss_type],
            ).then(
                fn=self.update_reg_path_visibility,
                inputs=[loss_type_selection],
                outputs=[path_checkbox],
            ).then(
                fn=self.compute_and_plot_loss_and_reg,
                inputs=list(loss_reg_inputs),
                outputs=self.loss_and_regularization_plot,
            ).then(
                fn=self.plot_strength_vs_weight,
                inputs=list(strength_inputs),
                outputs=self.strength_vs_weight,
            )

            regularizer_type_selection.change(
                fn=self.update_regularizer,
                inputs=[regularizer_type_selection],
                outputs=[reg_type],
            ).then(
                fn=self.compute_and_plot_loss_and_reg,
                inputs=list(loss_reg_inputs),
                outputs=self.loss_and_regularization_plot,
            ).then(
                fn=self.plot_strength_vs_weight,
                inputs=list(strength_inputs),
                outputs=self.strength_vs_weight,
            )

            reg_textbox.submit(
                self.update_reg_levels,
                inputs=[reg_textbox],
                outputs=[reg_levels],
            ).then(
                fn=self.compute_and_plot_loss_and_reg,
                inputs=list(loss_reg_inputs),
                outputs=self.loss_and_regularization_plot,
            ).then(
                fn=self.plot_strength_vs_weight,
                inputs=list(strength_inputs),
                outputs=self.strength_vs_weight,
            )

            w1_textbox.submit(
                self.update_w1_range,
                inputs=[w1_textbox],
                outputs=[w1_range],
            ).then(
                fn=self.compute_and_plot_loss_and_reg,
                inputs=list(loss_reg_inputs),
                outputs=self.loss_and_regularization_plot,
            )

            w2_textbox.submit(
                self.update_w2_range,
                inputs=[w2_textbox],
                outputs=[w2_range],
            ).then(
                fn=self.compute_and_plot_loss_and_reg,
                inputs=list(loss_reg_inputs),
                outputs=self.loss_and_regularization_plot,
            )

            submit_button.click(
                self.update_w1_range,
                inputs=[w1_textbox],
                outputs=[w1_range],
            ).then(
                self.update_w2_range,
                inputs=[w2_textbox],
                outputs=[w2_range],
            ).then(
                self.update_reg_levels,
                inputs=[reg_textbox],
                outputs=[reg_levels],
            ).then(
                fn=self.compute_and_plot_loss_and_reg,
                inputs=list(loss_reg_inputs),
                outputs=self.loss_and_regularization_plot,
            )

            resolution_slider.change(
                self.update_resolution,
                inputs=[resolution_slider],
                outputs=[num_dots],
            ).then(
                fn=self.compute_and_plot_loss_and_reg,
                inputs=list(loss_reg_inputs),
                outputs=self.loss_and_regularization_plot,
            )

            path_checkbox.change(
                self.update_plot_path,
                inputs=[path_checkbox],
                outputs=[plot_regularization_path],
            ).then(
                fn=self.compute_and_plot_loss_and_reg,
                inputs=list(loss_reg_inputs),
                outputs=self.loss_and_regularization_plot,
            )

            # export: point the hidden download button at the SVG written by
            # plot_loss_and_reg, then click it from the browser.
            btn_export_plot_loss_reg.click(
                fn=lambda: self.FIGURE_NAME,
                inputs=None,
                outputs=[btn_export_plot_loss_reg_hidden],
            ).then(
                fn=None,
                inputs=None,
                outputs=None,
                js="() => document.querySelector('#btn_export_plot_loss_reg_hidden').click()",
            )

        demo.launch()


visualizer = Regularization(width=1200, height=900)
visualizer.launch()