# mlp_visualizer/old/hyperparameters.py
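"""Hyperparameter state and Gradio controls for the MLP visualizer.

Optimizer settings live in frozen dataclasses (SgdHyperparameters,
AdamHyperparameters); Hyperparameters bundles them with the selected
optimizer name, and HyperparametersView renders and wires the editing UI.
"""
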
from dataclasses import dataclass, fields

import gradio as gr


@dataclass(frozen=True)
class SgdHyperparameters:
    learning_rate: float = 0.01
    momentum: float = 0.0
    weight_decay: float = 0.0
    batch_size: int = 32


@dataclass(frozen=True)
class AdamHyperparameters:
    learning_rate: float = 0.001
    beta1: float = 0.9
    beta2: float = 0.999
    weight_decay: float = 0.0
    batch_size: int = 32


class Hyperparameters:
    """Immutable record of the chosen optimizer and its settings.

    update() returns a modified copy rather than mutating in place, which
    keeps instances safe to share and to use as hashable keys.
    """

    def __init__(
        self,
        optimizer: str = "SGD",
        sgd_params: SgdHyperparameters = SgdHyperparameters(),
        adam_params: AdamHyperparameters = AdamHyperparameters(),
    ):
        self.optimizer = optimizer
        self.sgd_params = sgd_params
        self.adam_params = adam_params

    def update(self, **kwargs):
        return Hyperparameters(
            optimizer=kwargs.get("optimizer", self.optimizer),
            sgd_params=kwargs.get("sgd_params", self.sgd_params),
            adam_params=kwargs.get("adam_params", self.adam_params),
        )

    def __hash__(self):
        return hash((self.optimizer, self.sgd_params, self.adam_params))

    def __eq__(self, other):
        # Defining __hash__ without __eq__ leaves identity-based equality,
        # so equal-valued instances would not match as dict or cache keys.
        if not isinstance(other, Hyperparameters):
            return NotImplemented
        return (self.optimizer, self.sgd_params, self.adam_params) == (
            other.optimizer,
            other.sgd_params,
            other.adam_params,
        )

    @property
    def batch_size(self):
        if self.optimizer == "SGD":
            return self.sgd_params.batch_size
        elif self.optimizer == "Adam":
            return self.adam_params.batch_size
        else:
            raise ValueError(f"Unknown optimizer: {self.optimizer}")
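
# Example of the copy-on-update flow, using the defaults above:
#
#     hp = Hyperparameters()             # optimizer="SGD"
#     hp2 = hp.update(optimizer="Adam")  # hp itself is unchanged
#     hp2.batch_size                     # 32, read from adam_params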


class HyperparametersView:
    def update_optimizer_type(self, state: Hyperparameters, optimizer: str):
        state = state.update(optimizer=optimizer)
        # Show only the settings group that matches the selected optimizer.
        return (
            state,
            gr.update(visible=(optimizer == "SGD")),
            gr.update(visible=(optimizer == "Adam")),
        )

    def update_sgd_hyperparameters(
        self,
        state: Hyperparameters,
        sgd_learning_rate: float,
        sgd_momentum: float,
        sgd_weight_decay: float,
        sgd_batch_size: int,
    ):
        sgd_params = SgdHyperparameters(
            learning_rate=sgd_learning_rate,
            momentum=sgd_momentum,
            weight_decay=sgd_weight_decay,
            batch_size=sgd_batch_size,
        )
        return state.update(sgd_params=sgd_params)

    def update_adam_hyperparameters(
        self,
        state: Hyperparameters,
        adam_learning_rate: float,
        adam_beta1: float,
        adam_beta2: float,
        adam_weight_decay: float,
        adam_batch_size: int,
    ):
        adam_params = AdamHyperparameters(
            learning_rate=adam_learning_rate,
            beta1=adam_beta1,
            beta2=adam_beta2,
            weight_decay=adam_weight_decay,
            batch_size=adam_batch_size,
        )
        return state.update(adam_params=adam_params)

    def build(self, state: gr.State):
        hyper = state.value
        with gr.Column():
            optimizer_select = gr.Dropdown(
                choices=["SGD", "Adam"],
                value=hyper.optimizer,
                label="Optimizer",
                interactive=True,
            )
            # One gr.Number per dataclass field, generated from the dataclass
            # definitions so the UI stays in sync with them.
            with gr.Group(visible=(hyper.optimizer == "SGD")) as sgd_box:
                sgd_components = {}
                with gr.Row():
                    for f in fields(hyper.sgd_params):
                        sgd_components[f.name] = gr.Number(
                            value=getattr(hyper.sgd_params, f.name),
                            label=f.name.replace("_", " ").title(),
                            interactive=True,
                        )
            with gr.Group(visible=(hyper.optimizer == "Adam")) as adam_box:
                adam_components = {}
                with gr.Row():
                    for f in fields(hyper.adam_params):
                        adam_components[f.name] = gr.Number(
                            value=getattr(hyper.adam_params, f.name),
                            label=f.name.replace("_", " ").title(),
                            interactive=True,
                        )

        optimizer_select.change(
            fn=self.update_optimizer_type,
            inputs=[state, optimizer_select],
            outputs=[state, sgd_box, adam_box],
        )
        # Submitting any one field re-reads the whole group, so the state
        # update always carries a complete set of values.
        for component in sgd_components.values():
            component.submit(
                fn=self.update_sgd_hyperparameters,
                inputs=[
                    state,
                    sgd_components["learning_rate"],
                    sgd_components["momentum"],
                    sgd_components["weight_decay"],
                    sgd_components["batch_size"],
                ],
                outputs=[state],
            )
        for component in adam_components.values():
            component.submit(
                fn=self.update_adam_hyperparameters,
                inputs=[
                    state,
                    adam_components["learning_rate"],
                    adam_components["beta1"],
                    adam_components["beta2"],
                    adam_components["weight_decay"],
                    adam_components["batch_size"],
                ],
                outputs=[state],
            )
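

# A minimal usage sketch, assuming this module is run standalone; the
# Blocks/launch wiring below is illustrative rather than taken from the rest
# of the project, but it uses only standard Gradio calls.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        state = gr.State(Hyperparameters())
        HyperparametersView().build(state)
    demo.launch()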