joel-woodfield's picture
Add usage
6cac300
from typing import Literal
import gradio as gr
from matplotlib.figure import Figure
import sys
from pathlib import Path
# Make the backend sources (<root>/backend/src) importable so that
# `from manager import Manager` below resolves. The project root is three
# `.parent` steps up from this file's resolved location.
_this_file = Path(__file__).resolve()
root_dir = _this_file.parent.parent.parent
backend_src = root_dir.joinpath("backend", "src")
_backend_src_str = str(backend_src)
# Appended (not prepended), and only once, so repeated imports of this
# module don't grow sys.path.
if _backend_src_str not in sys.path:
    sys.path.append(_backend_src_str)
from manager import Manager
CSS = """
.hidden-button {
display: none;
}
"""
def handle_dataset_type_change(dataset_type: Literal["Generate", "CSV"]):
    """Show the controls for the selected data source and hide the others.

    The returned tuple is positional and must line up with the ``outputs``
    list wired to ``dataset_type.change`` in ``launch``:

    * the first 6 entries are the "Generate" controls (function, x1 range,
      x2 range, x-selection method, noise sigma, number of points);
    * the last 5 entries are the "CSV" controls (file, header flag, x1/x2/y
      column indices).

    Any value other than ``"Generate"`` is treated as ``"CSV"``.
    """
    generating = dataset_type == "Generate"
    generate_controls = [gr.update(visible=generating) for _ in range(6)]
    csv_controls = [gr.update(visible=not generating) for _ in range(5)]
    return tuple(generate_controls + csv_controls)
def _as_column_index(value: object, name: str) -> int:
    """Convert a ``gr.Number`` value into an integer CSV column index.

    ``gr.Number`` without ``precision=0`` delivers floats (e.g. ``1.0``), and
    an emptied field delivers ``None``. Neither can be used directly as a
    column index.

    Raises:
        ValueError: if the value is missing or not a whole number.
    """
    if value is None:
        raise ValueError(f"{name} column index is required")
    as_float = float(value)
    if not as_float.is_integer():
        raise ValueError(f"{name} column index must be a whole number, got {value}")
    return int(as_float)


def handle_generate_plots(
    manager: Manager,
    dataset_type: str,
    function: str,
    x1_range_input: str,
    x2_range_input: str,
    x_selection_method: str,
    sigma: float,
    nsample: int,
    csv_file: str,
    has_header: bool,
    x1_col: int,
    x2_col: int,
    y_col: int,
    loss_type: str,
    regularizer_type: str,
    resolution: int,
) -> tuple[Manager, Figure, Figure, Figure]:
    """Gradio callback for the "Regenerate Plots" buttons.

    Forwards the current UI values to ``manager.handle_generate_plots`` and
    returns its result unchanged. The result is annotated as (manager,
    contour figure, data figure, strength figure), matching the
    ``[manager_state, main_plot, data_plot, strength_plot]`` outputs wired
    in ``launch``.

    The x1/x2/y column indices are normalised to ``int`` first (see
    ``_as_column_index``). ``csv_file`` is passed through as received.
    Presumably it is ``None`` when nothing was uploaded (Gradio ``File``
    behaviour); confirm that ``Manager`` handles that case.

    Raises:
        gr.Error: wrapping any exception from validation or plot generation,
            so the message is shown in the UI instead of a generic failure.
    """
    try:
        return manager.handle_generate_plots(
            dataset_type,
            function,
            x1_range_input,
            x2_range_input,
            x_selection_method,
            sigma,
            nsample,
            csv_file,
            has_header,
            _as_column_index(x1_col, "x1"),
            _as_column_index(x2_col, "x2"),
            _as_column_index(y_col, "y"),
            loss_type,
            regularizer_type,
            resolution,
        )
    except Exception as e:
        # Broad catch is deliberate: this is the UI boundary, and every failure
        # should reach the user as a gr.Error. Chain so the traceback is kept.
        raise gr.Error("Error generating plots: " + str(e)) from e
def launch():
    """Build the Regularization visualizer Gradio UI and start it.

    The initial plots are computed from the defaults below, so the page opens
    with content. The method also:

    * wires the dataset-type radio to ``handle_dataset_type_change`` and both
      "Regenerate Plots" buttons to ``handle_generate_plots``;
    * renders ``root_dir / "usage.md"`` in the Usage tab;
    * calls ``demo.launch(css=CSS)``.

    Exceptions from the initial plot generation are not caught here. They
    propagate to the caller at startup.
    """
    default_dataset_type = "Generate"
    default_function = "-50 * x1 + 30 * x2"
    default_x1_range = "-1, 1"
    default_x2_range = "-1, 1"
    default_x_selection_method = "Grid"
    default_sigma = 0.1
    default_num_points = 100
    default_csv_file = ""
    default_has_header = False
    default_x1_col = 0
    default_x2_col = 1
    default_y_col = 2
    default_loss_type = "l2"
    default_regularizer_type = "l2"
    default_resolution = 100

    manager = Manager()
    manager, default_contour_plot, default_data_plot, default_strength_plot = manager.handle_generate_plots(
        default_dataset_type,
        default_function,
        default_x1_range,
        default_x2_range,
        default_x_selection_method,
        default_sigma,
        default_num_points,
        default_csv_file,
        default_has_header,
        default_x1_col,
        default_x2_col,
        default_y_col,
        default_loss_type,
        default_regularizer_type,
        default_resolution,
    )

    with gr.Blocks() as demo:
        gr.HTML("<div style='text-align:left; font-size:40px; font-weight: bold;'>Regularization visualizer</div>")
        manager_state = gr.State(manager)

        with gr.Row():
            # Left: output plots.
            with gr.Column(scale=2):
                with gr.Tab("Contours"):
                    main_plot = gr.Plot(value=default_contour_plot)
                with gr.Tab("Data"):
                    data_plot = gr.Plot(value=default_data_plot)
                with gr.Tab("Strength"):
                    strength_plot = gr.Plot(value=default_strength_plot)

            # Right: controls.
            with gr.Column(scale=1):
                with gr.Tab("Data"):
                    with gr.Row():
                        dataset_type = gr.Radio(
                            label="Dataset type",
                            choices=["Generate", "CSV"],
                            value=default_dataset_type,
                            interactive=True,
                        )

                    # --- "Generate" controls (visible by default) ---
                    with gr.Row():
                        function = gr.Textbox(
                            label="Function (in terms of x1 and x2)",
                            value=default_function,
                            interactive=True,
                        )
                    with gr.Row():
                        x1_textbox = gr.Textbox(
                            label="x1 range",
                            value=default_x1_range,
                            interactive=True,
                        )
                        x2_textbox = gr.Textbox(
                            label="x2 range",
                            value=default_x2_range,
                            interactive=True,
                        )
                    with gr.Row():
                        x_selection_method = gr.Radio(
                            label="How to select x points",
                            choices=["Grid", "Random"],
                            value=default_x_selection_method,
                            interactive=True,
                        )
                    with gr.Row():
                        sigma = gr.Number(
                            label="Gaussian noise standard deviation",
                            value=default_sigma,
                            interactive=True,
                        )
                    with gr.Row():
                        nsample = gr.Slider(
                            label="Number of points",
                            value=default_num_points,
                            interactive=True,
                            minimum=2,  # TODO: set to 1 after fixing weird cases
                            maximum=100,
                            step=1,
                        )

                    # --- "CSV" controls (hidden until "CSV" is selected) ---
                    with gr.Row():
                        csv_file = gr.File(
                            label="Upload CSV file - must have columns: (x1, x2, y)",
                            file_types=[".csv"],
                            visible=False,
                        )
                    with gr.Row():
                        has_header = gr.Checkbox(
                            label="CSV has header row",
                            value=default_has_header,
                            visible=False,
                        )
                    # precision=0 makes Gradio return an int rather than a float,
                    # matching the integer column indices the handler expects.
                    with gr.Row():
                        x1_col = gr.Number(
                            label="x1 column index (0-based)",
                            value=default_x1_col,
                            precision=0,
                            visible=False,
                        )
                        x2_col = gr.Number(
                            label="x2 column index (0-based)",
                            value=default_x2_col,
                            precision=0,
                            visible=False,
                        )
                    with gr.Row():
                        y_col = gr.Number(
                            label="y column index (0-based)",
                            value=default_y_col,
                            precision=0,
                            visible=False,
                        )

                    # Output order must match handle_dataset_type_change's
                    # return tuple: 6 Generate controls, then 5 CSV controls.
                    dataset_type.change(
                        fn=handle_dataset_type_change,
                        inputs=[dataset_type],
                        outputs=[
                            function,
                            x1_textbox,
                            x2_textbox,
                            x_selection_method,
                            sigma,
                            nsample,
                            csv_file,
                            has_header,
                            x1_col,
                            x2_col,
                            y_col,
                        ],
                    )
                    regenerate_plots_button1 = gr.Button("Regenerate Plots")

                with gr.Tab("Regularization"):
                    with gr.Row():
                        loss_type_dropdown = gr.Dropdown(
                            label="Loss type",
                            choices=["l1", "l2"],
                            value=default_loss_type,
                            interactive=True,
                        )
                        regularizer_type_dropdown = gr.Dropdown(
                            label="Regularizer type",
                            choices=["l1", "l2"],
                            value=default_regularizer_type,
                            interactive=True,
                        )
                    resolution_slider = gr.Slider(
                        label="Grid resolution",
                        value=default_resolution,
                        minimum=100,
                        maximum=400,
                        step=1,
                        interactive=True,
                    )
                    regenerate_plots_button2 = gr.Button("Regenerate Plots")

                with gr.Tab("Usage"):
                    # Explicit UTF-8 so non-ASCII markdown doesn't depend on the
                    # platform's default encoding.
                    gr.Markdown((root_dir / "usage.md").read_text(encoding="utf-8"))

        # Positional order must match handle_generate_plots' parameters.
        plot_inputs = [
            manager_state,
            dataset_type,
            function,
            x1_textbox,
            x2_textbox,
            x_selection_method,
            sigma,
            nsample,
            csv_file,
            has_header,
            x1_col,
            x2_col,
            y_col,
            loss_type_dropdown,
            regularizer_type_dropdown,
            resolution_slider,
        ]
        plot_outputs = [manager_state, main_plot, data_plot, strength_plot]
        # Each control tab has its own button; both trigger the same regeneration.
        for regenerate_button in (regenerate_plots_button1, regenerate_plots_button2):
            regenerate_button.click(
                fn=handle_generate_plots,
                inputs=plot_inputs,
                outputs=plot_outputs,
            )

    # NOTE(review): passing `css` to launch() is the Gradio 6 API; older Gradio
    # versions take it as gr.Blocks(css=...). Confirm against the pinned version.
    demo.launch(css=CSS)
# Start the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    launch()