Spaces:
Sleeping
Sleeping
| from typing import Literal | |
| import gradio as gr | |
| from matplotlib.figure import Figure | |
| import sys | |
| from pathlib import Path | |
# Put ``<root>/backend/src`` on the import path so the backend's ``manager``
# module can be imported. ``root_dir`` is three ``.parent`` steps above this
# file (i.e. this file sits at ``<root>/<a>/<b>/<file>.py``); ``launch`` also
# reads ``usage.md`` from it.
root_dir = Path(__file__).resolve().parent.parent.parent
backend_src = root_dir.joinpath("backend", "src")
_backend_src_entry = str(backend_src)
if _backend_src_entry not in sys.path:
    # Appended rather than prepended: an already-importable module named
    # ``manager`` earlier on sys.path would take precedence over this one.
    sys.path.append(_backend_src_entry)

# Deliberately placed after the sys.path edit above; it cannot move to the
# top-of-file import block.
from manager import Manager  # noqa: E402

# Stylesheet handed to ``demo.launch``. NOTE(review): no component in this
# file uses the ``hidden-button`` class — confirm it is still needed.
CSS = """
.hidden-button {
display: none;
}
"""
def handle_dataset_type_change(dataset_type: Literal["Generate", "CSV"]):
    """Show either the synthetic-data controls or the CSV controls.

    The returned tuple lines up with the ``outputs`` list wired to
    ``dataset_type.change`` in ``launch``. It holds the six synthetic-data
    controls first: function, x1 range, x2 range, x-selection method, sigma
    and number of points. The five CSV controls follow: file, header flag,
    and the x1/x2/y column indices.

    Args:
        dataset_type: ``"Generate"`` makes the synthetic-data controls
            visible; any other value (normally ``"CSV"``) makes the CSV
            controls visible instead.

    Returns:
        An 11-tuple of ``gr.update(visible=...)`` objects.
    """
    generating = dataset_type == "Generate"
    generate_updates = tuple(gr.update(visible=generating) for _ in range(6))
    csv_updates = tuple(gr.update(visible=not generating) for _ in range(5))
    return generate_updates + csv_updates
def handle_generate_plots(
    manager: Manager,
    dataset_type: str,
    function: str,
    x1_range_input: str,
    x2_range_input: str,
    x_selection_method: str,
    sigma: float,
    nsample: int,
    csv_file: str,
    has_header: bool,
    x1_col: int,
    x2_col: int,
    y_col: int,
    loss_type: str,
    regularizer_type: str,
    resolution: int,
) -> tuple[Manager, Figure, Figure, Figure]:
    """Gradio callback behind the "Regenerate Plots" buttons.

    Forwards every UI value, in order, to ``manager.handle_generate_plots``.
    The result is returned unchanged: the manager to store back into
    ``gr.State``, then the contour, data and strength figures.

    Args:
        manager: Per-session ``Manager`` held in ``gr.State``.
        dataset_type: Value of the "Dataset type" radio (``"Generate"`` or ``"CSV"``).
        function: User-entered function of ``x1`` and ``x2``.
        x1_range_input: Text from the "x1 range" box, e.g. ``"-1, 1"``.
        x2_range_input: Text from the "x2 range" box, e.g. ``"-1, 1"``.
        x_selection_method: ``"Grid"`` or ``"Random"``.
        sigma: Gaussian noise standard deviation.
        nsample: Number of points.
        csv_file: Value of the CSV ``gr.File`` component. NOTE(review):
            presumably ``None`` when nothing was uploaded — confirm the
            manager handles that.
        has_header: Whether the CSV has a header row.
        x1_col: 0-based CSV column index for ``x1``.
        x2_col: 0-based CSV column index for ``x2``.
        y_col: 0-based CSV column index for ``y``.
        loss_type: ``"l1"`` or ``"l2"``.
        regularizer_type: ``"l1"`` or ``"l2"``.
        resolution: Value of the "Grid resolution" slider.

    Returns:
        ``(manager, contour_figure, data_figure, strength_figure)``.

    Raises:
        gr.Error: Wraps any exception raised by the manager so Gradio can
            report it to the user. The original exception is chained as
            ``__cause__``.
    """
    try:
        return manager.handle_generate_plots(
            dataset_type,
            function,
            x1_range_input,
            x2_range_input,
            x_selection_method,
            sigma,
            nsample,
            csv_file,
            has_header,
            x1_col,
            x2_col,
            y_col,
            loss_type,
            regularizer_type,
            resolution,
        )
    except Exception as e:
        # UI/backend boundary: report any failure to the user, but chain the
        # original exception so its traceback is not lost in server logs.
        raise gr.Error(f"Error generating plots: {e}") from e
def launch():
    """Build the "Regularization visualizer" Gradio app and launch it.

    The initial plots are computed once, up front, from the ``default_*``
    values below so the page is populated on first load. The returned
    ``Manager`` seeds the per-session ``gr.State``.

    Layout:
        * Left column: the Contours, Data and Strength plot tabs.
        * Right column: three control tabs.
            - Data: data-source controls, whose visibility is toggled by
              ``handle_dataset_type_change``.
            - Regularization: loss and regularizer controls.
            - Usage: notes read from ``root_dir / "usage.md"``.

    Both "Regenerate Plots" buttons call ``handle_generate_plots`` with the
    same inputs and outputs.

    Raises:
        OSError: If ``usage.md`` cannot be read.
        UnicodeDecodeError: If ``usage.md`` is not valid UTF-8.
    """
    default_dataset_type = "Generate"
    default_function = "-50 * x1 + 30 * x2"
    default_x1_range = "-1, 1"
    default_x2_range = "-1, 1"
    default_x_selection_method = "Grid"
    default_sigma = 0.1
    default_num_points = 100
    default_csv_file = ""
    default_has_header = False
    default_x1_col = 0
    default_x2_col = 1
    default_y_col = 2
    default_loss_type = "l2"
    default_regularizer_type = "l2"
    default_resolution = 100

    # Compute the initial figures eagerly. Unlike the button callback this is
    # not wrapped in gr.Error, so any failure with the defaults propagates and
    # the app never starts.
    manager = Manager()
    manager, default_contour_plot, default_data_plot, default_strength_plot = (
        manager.handle_generate_plots(
            default_dataset_type,
            default_function,
            default_x1_range,
            default_x2_range,
            default_x_selection_method,
            default_sigma,
            default_num_points,
            default_csv_file,
            default_has_header,
            default_x1_col,
            default_x2_col,
            default_y_col,
            default_loss_type,
            default_regularizer_type,
            default_resolution,
        )
    )

    with gr.Blocks() as demo:
        gr.HTML(
            "<div style='text-align:left; font-size:40px; font-weight: bold;'>"
            "Regularization visualizer</div>"
        )
        manager_state = gr.State(manager)

        with gr.Row():
            with gr.Column(scale=2):
                with gr.Tab("Contours"):
                    main_plot = gr.Plot(value=default_contour_plot)
                with gr.Tab("Data"):
                    data_plot = gr.Plot(value=default_data_plot)
                with gr.Tab("Strength"):
                    strength_plot = gr.Plot(value=default_strength_plot)

            with gr.Column(scale=1):
                with gr.Tab("Data"):
                    with gr.Row():
                        dataset_type = gr.Radio(
                            label="Dataset type",
                            choices=["Generate", "CSV"],
                            value=default_dataset_type,
                            interactive=True,
                        )
                    with gr.Row():
                        function = gr.Textbox(
                            label="Function (in terms of x1 and x2)",
                            value=default_function,
                            interactive=True,
                        )
                    with gr.Row():
                        x1_textbox = gr.Textbox(
                            label="x1 range",
                            value=default_x1_range,
                            interactive=True,
                        )
                        x2_textbox = gr.Textbox(
                            label="x2 range",
                            value=default_x2_range,
                            interactive=True,
                        )
                    with gr.Row():
                        x_selection_method = gr.Radio(
                            label="How to select x points",
                            choices=["Grid", "Random"],
                            value=default_x_selection_method,
                            interactive=True,
                        )
                    with gr.Row():
                        sigma = gr.Number(
                            label="Gaussian noise standard deviation",
                            value=default_sigma,
                            minimum=0,  # a standard deviation cannot be negative
                            interactive=True,
                        )
                    with gr.Row():
                        nsample = gr.Slider(
                            label="Number of points",
                            value=default_num_points,
                            interactive=True,
                            minimum=2,  # TODO: set to 1 after fixing weird cases
                            maximum=100,
                            step=1,
                        )
                    # CSV controls start hidden because the default dataset
                    # type is "Generate"; handle_dataset_type_change toggles them.
                    with gr.Row():
                        csv_file = gr.File(
                            label="Upload CSV file - must have columns: (x1, x2, y)",
                            file_types=[".csv"],
                            visible=False,
                        )
                    with gr.Row():
                        has_header = gr.Checkbox(
                            label="CSV has header row",
                            value=default_has_header,
                            visible=False,
                        )
                    # precision=0 makes Gradio round and convert these values
                    # to int, matching handle_generate_plots' int annotations.
                    with gr.Row():
                        x1_col = gr.Number(
                            label="x1 column index (0-based)",
                            value=default_x1_col,
                            precision=0,
                            visible=False,
                        )
                        x2_col = gr.Number(
                            label="x2 column index (0-based)",
                            value=default_x2_col,
                            precision=0,
                            visible=False,
                        )
                    with gr.Row():
                        y_col = gr.Number(
                            label="y column index (0-based)",
                            value=default_y_col,
                            precision=0,
                            visible=False,
                        )
                    regenerate_plots_button1 = gr.Button("Regenerate Plots")

                with gr.Tab("Regularization"):
                    with gr.Row():
                        loss_type_dropdown = gr.Dropdown(
                            label="Loss type",
                            choices=["l1", "l2"],
                            value=default_loss_type,
                            interactive=True,
                        )
                        regularizer_type_dropdown = gr.Dropdown(
                            label="Regularizer type",
                            choices=["l1", "l2"],
                            value=default_regularizer_type,
                            interactive=True,
                        )
                    resolution_slider = gr.Slider(
                        label="Grid resolution",
                        value=default_resolution,
                        minimum=100,
                        maximum=400,
                        step=1,
                        interactive=True,
                    )
                    regenerate_plots_button2 = gr.Button("Regenerate Plots")

                with gr.Tab("Usage"):
                    # Read as UTF-8 explicitly so non-ASCII markdown does not
                    # depend on the platform's default encoding.
                    gr.Markdown((root_dir / "usage.md").read_text(encoding="utf-8"))

        # Order must match the 11-tuple returned by handle_dataset_type_change.
        dataset_type.change(
            fn=handle_dataset_type_change,
            inputs=[dataset_type],
            outputs=[
                function,
                x1_textbox,
                x2_textbox,
                x_selection_method,
                sigma,
                nsample,
                csv_file,
                has_header,
                x1_col,
                x2_col,
                y_col,
            ],
        )

        # Order must match handle_generate_plots' positional parameters.
        plot_inputs = [
            manager_state,
            dataset_type,
            function,
            x1_textbox,
            x2_textbox,
            x_selection_method,
            sigma,
            nsample,
            csv_file,
            has_header,
            x1_col,
            x2_col,
            y_col,
            loss_type_dropdown,
            regularizer_type_dropdown,
            resolution_slider,
        ]
        plot_outputs = [manager_state, main_plot, data_plot, strength_plot]
        # Both tabs have their own button; they share one wiring.
        for regenerate_button in (regenerate_plots_button1, regenerate_plots_button2):
            regenerate_button.click(
                fn=handle_generate_plots,
                inputs=list(plot_inputs),
                outputs=list(plot_outputs),
            )

    demo.launch(css=CSS)
if __name__ == "__main__":
    # Start the app only when this file is run as a script, not on import.
    launch()