"""Gradio front-end for the regularization visualizer."""

import sys
from pathlib import Path
from typing import Literal

import gradio as gr
from matplotlib.figure import Figure

# Make the backend package importable when this file is run directly.
root_dir = Path(__file__).resolve().parent.parent.parent
backend_src = root_dir / "backend" / "src"
if str(backend_src) not in sys.path:
    sys.path.append(str(backend_src))

from manager import Manager  # noqa: E402  (depends on the sys.path tweak above)

CSS = """
.hidden-button {
    display: none;
}
"""

# How many controls belong to each dataset mode. The order of the updates
# returned by handle_dataset_type_change must match the ``outputs`` list wired
# up in launch(): first the "Generate" controls, then the "CSV" controls.
_NUM_GENERATE_CONTROLS = 6
_NUM_CSV_CONTROLS = 5


def handle_dataset_type_change(dataset_type: Literal["Generate", "CSV"]):
    """Show the controls for the selected dataset mode and hide the others.

    Args:
        dataset_type: "Generate" to show the synthetic-data controls; any
            other value shows the CSV-upload controls instead.

    Returns:
        A tuple of ``gr.update`` objects: six for the generate controls
        (function, x1 range, x2 range, selection method, sigma, number of
        points) followed by five for the CSV controls (file, header flag,
        x1/x2/y column indices).
    """
    show_generate = dataset_type == "Generate"
    # Build a fresh update object per output rather than sharing one instance.
    return (
        *(gr.update(visible=show_generate) for _ in range(_NUM_GENERATE_CONTROLS)),
        *(gr.update(visible=not show_generate) for _ in range(_NUM_CSV_CONTROLS)),
    )


def handle_generate_plots(
    manager: Manager,
    dataset_type: str,
    function: str,
    x1_range_input: str,
    x2_range_input: str,
    x_selection_method: str,
    sigma: float,
    nsample: int,
    csv_file: str,
    has_header: bool,
    x1_col: int,
    x2_col: int,
    y_col: int,
    loss_type: str,
    regularizer_type: str,
    resolution: int,
) -> tuple[Manager, Figure, Figure, Figure]:
    """Regenerate all plots from the current UI state.

    Every argument is forwarded unchanged, in order, to
    ``Manager.handle_generate_plots``.

    Returns:
        Whatever ``Manager.handle_generate_plots`` returns. The caller treats
        it as (manager, contour plot, data plot, strength plot).

    Raises:
        gr.Error: If plot generation fails for any reason. Gradio shows the
            message to the user; the original exception is kept as __cause__.
    """
    try:
        return manager.handle_generate_plots(
            dataset_type,
            function,
            x1_range_input,
            x2_range_input,
            x_selection_method,
            sigma,
            nsample,
            csv_file,
            has_header,
            x1_col,
            x2_col,
            y_col,
            loss_type,
            regularizer_type,
            resolution,
        )
    except Exception as e:
        raise gr.Error("Error generating plots: " + str(e)) from e


def launch():
    """Build the Gradio UI, render the initial plots, and start the server."""
    default_dataset_type = "Generate"
    default_function = "-50 * x1 + 30 * x2"
    default_x1_range = "-1, 1"
    default_x2_range = "-1, 1"
    default_x_selection_method = "Grid"
    default_sigma = 0.1
    default_num_points = 100
    default_csv_file = ""
    default_has_header = False
    default_x1_col = 0
    default_x2_col = 1
    default_y_col = 2
    default_loss_type = "l2"
    default_regularizer_type = "l2"
    default_resolution = 100

    # Render the initial plots up front so the page opens with content.
    manager = Manager()
    manager, default_contour_plot, default_data_plot, default_strength_plot = (
        manager.handle_generate_plots(
            default_dataset_type,
            default_function,
            default_x1_range,
            default_x2_range,
            default_x_selection_method,
            default_sigma,
            default_num_points,
            default_csv_file,
            default_has_header,
            default_x1_col,
            default_x2_col,
            default_y_col,
            default_loss_type,
            default_regularizer_type,
            default_resolution,
        )
    )

    with gr.Blocks() as demo:
        gr.HTML("<h1>Regularization visualizer</h1>")
        manager_state = gr.State(manager)

        with gr.Row():
            with gr.Column(scale=2):
                with gr.Tab("Contours"):
                    main_plot = gr.Plot(value=default_contour_plot)
                with gr.Tab("Data"):
                    data_plot = gr.Plot(value=default_data_plot)
                with gr.Tab("Strength"):
                    strength_plot = gr.Plot(value=default_strength_plot)

            with gr.Column(scale=1):
                with gr.Tab("Data"):
                    with gr.Row():
                        dataset_type = gr.Radio(
                            label="Dataset type",
                            choices=["Generate", "CSV"],
                            value=default_dataset_type,
                            interactive=True,
                        )

                    # Controls for synthetic ("Generate") datasets.
                    with gr.Row():
                        function = gr.Textbox(
                            label="Function (in terms of x1 and x2)",
                            value=default_function,
                            interactive=True,
                        )
                    with gr.Row():
                        x1_textbox = gr.Textbox(
                            label="x1 range",
                            value=default_x1_range,
                            interactive=True,
                        )
                        x2_textbox = gr.Textbox(
                            label="x2 range",
                            value=default_x2_range,
                            interactive=True,
                        )
                    with gr.Row():
                        x_selection_method = gr.Radio(
                            label="How to select x points",
                            choices=["Grid", "Random"],
                            value=default_x_selection_method,
                            interactive=True,
                        )
                    with gr.Row():
                        sigma = gr.Number(
                            label="Gaussian noise standard deviation",
                            value=default_sigma,
                            interactive=True,
                        )
                    with gr.Row():
                        nsample = gr.Slider(
                            label="Number of points",
                            value=default_num_points,
                            interactive=True,
                            minimum=2,  # TODO: set to 1 after fixing weird cases
                            maximum=100,
                            step=1,
                        )

                    # Controls for uploaded CSV datasets (hidden by default).
                    with gr.Row():
                        csv_file = gr.File(
                            label="Upload CSV file - must have columns: (x1, x2, y)",
                            file_types=[".csv"],
                            visible=False,
                        )
                    with gr.Row():
                        has_header = gr.Checkbox(
                            label="CSV has header row",
                            value=default_has_header,
                            visible=False,
                        )
                    # precision=0 makes Gradio deliver ints, as the column
                    # indices are declared int in handle_generate_plots.
                    with gr.Row():
                        x1_col = gr.Number(
                            label="x1 column index (0-based)",
                            value=default_x1_col,
                            precision=0,
                            visible=False,
                        )
                        x2_col = gr.Number(
                            label="x2 column index (0-based)",
                            value=default_x2_col,
                            precision=0,
                            visible=False,
                        )
                    with gr.Row():
                        y_col = gr.Number(
                            label="y column index (0-based)",
                            value=default_y_col,
                            precision=0,
                            visible=False,
                        )

                    # Order must match handle_dataset_type_change: generate
                    # controls first, then CSV controls.
                    dataset_type.change(
                        fn=handle_dataset_type_change,
                        inputs=[dataset_type],
                        outputs=[
                            function,
                            x1_textbox,
                            x2_textbox,
                            x_selection_method,
                            sigma,
                            nsample,
                            csv_file,
                            has_header,
                            x1_col,
                            x2_col,
                            y_col,
                        ],
                    )

                    regenerate_plots_button1 = gr.Button("Regenerate Plots")

                with gr.Tab("Regularization"):
                    with gr.Row():
                        loss_type_dropdown = gr.Dropdown(
                            label="Loss type",
                            choices=["l1", "l2"],
                            value=default_loss_type,
                            interactive=True,
                        )
                        regularizer_type_dropdown = gr.Dropdown(
                            label="Regularizer type",
                            choices=["l1", "l2"],
                            value=default_regularizer_type,
                            interactive=True,
                        )
                    resolution_slider = gr.Slider(
                        label="Grid resolution",
                        value=default_resolution,
                        minimum=100,
                        maximum=400,
                        step=1,
                        interactive=True,
                    )
                    regenerate_plots_button2 = gr.Button("Regenerate Plots")

                with gr.Tab("Usage"):
                    with open(root_dir / "usage.md", "r", encoding="utf-8") as f:
                        gr.Markdown(f.read())

        # Both "Regenerate Plots" buttons trigger the same computation.
        plot_inputs = [
            manager_state,
            dataset_type,
            function,
            x1_textbox,
            x2_textbox,
            x_selection_method,
            sigma,
            nsample,
            csv_file,
            has_header,
            x1_col,
            x2_col,
            y_col,
            loss_type_dropdown,
            regularizer_type_dropdown,
            resolution_slider,
        ]
        plot_outputs = [manager_state, main_plot, data_plot, strength_plot]
        for button in (regenerate_plots_button1, regenerate_plots_button2):
            button.click(
                fn=handle_generate_plots,
                inputs=plot_inputs,
                outputs=plot_outputs,
            )

    demo.launch(css=CSS)


if __name__ == "__main__":
    launch()