from dataclasses import dataclass, fields

import gradio as gr


@dataclass(frozen=True)
class SgdHyperparameters:
    """Immutable settings for the SGD optimizer."""

    learning_rate: float = 0.01
    momentum: float = 0.0
    weight_decay: float = 0.0
    batch_size: int = 32


@dataclass(frozen=True)
class AdamHyperparameters:
    """Immutable settings for the Adam optimizer."""

    learning_rate: float = 0.001
    beta1: float = 0.9
    beta2: float = 0.999
    weight_decay: float = 0.0
    batch_size: int = 32


class Hyperparameters:
    """Selected optimizer name plus a parameter set for each supported optimizer.

    Both parameter sets are kept, so switching the optimizer does not discard
    the values stored for the other one. ``update`` returns a new instance
    instead of mutating ``self``.
    """

    def __init__(
        self,
        optimizer: str = "SGD",
        sgd_params: SgdHyperparameters = SgdHyperparameters(),
        adam_params: AdamHyperparameters = AdamHyperparameters(),
    ):
        # Sharing the default instances between calls is safe because both
        # parameter dataclasses are frozen.
        self.optimizer = optimizer
        self.sgd_params = sgd_params
        self.adam_params = adam_params

    def update(self, **kwargs) -> "Hyperparameters":
        """Return a copy with any of ``optimizer``, ``sgd_params``, ``adam_params`` replaced.

        Keys not given keep their current value; any other keys are ignored.
        """
        return Hyperparameters(
            optimizer=kwargs.get("optimizer", self.optimizer),
            sgd_params=kwargs.get("sgd_params", self.sgd_params),
            adam_params=kwargs.get("adam_params", self.adam_params),
        )

    def _key(self) -> tuple:
        """Tuple of all fields; the single source of truth for __eq__ and __hash__."""
        return (self.optimizer, self.sgd_params, self.adam_params)

    def __eq__(self, other) -> bool:
        # Value equality consistent with __hash__ (previously equality fell
        # back to identity while the hash was value-based).
        if not isinstance(other, Hyperparameters):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self) -> int:
        return hash(self._key())

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(optimizer={self.optimizer!r}, "
            f"sgd_params={self.sgd_params!r}, adam_params={self.adam_params!r})"
        )

    @property
    def batch_size(self) -> int:
        """Batch size of the currently selected optimizer.

        Raises:
            ValueError: if ``optimizer`` is neither "SGD" nor "Adam".
        """
        if self.optimizer == "SGD":
            return self.sgd_params.batch_size
        elif self.optimizer == "Adam":
            return self.adam_params.batch_size
        else:
            raise ValueError(f"Unknown optimizer: {self.optimizer}")


def _keep_if_none(value, current):
    """Return *value*, or *current* when the UI delivered ``None`` (a cleared number box)."""
    return current if value is None else value


class HyperparametersView:
    """Gradio widgets and event handlers that edit a ``Hyperparameters`` state."""

    def update_optimizer_type(self, state: Hyperparameters, optimizer: str):
        """Switch the selected optimizer.

        Returns the new state plus visibility updates for the SGD and Adam
        groups, in that order.
        """
        state = state.update(optimizer=optimizer)
        return (
            state,
            gr.update(visible=(optimizer == "SGD")),
            gr.update(visible=(optimizer == "Adam")),
        )

    def update_sgd_hyperparameters(
        self,
        state: Hyperparameters,
        sgd_learning_rate: float,
        sgd_momentum: float,
        sgd_weight_decay: float,
        sgd_batch_size: int,
    ) -> Hyperparameters:
        """Return *state* with its SGD parameters replaced by the submitted UI values.

        A ``None`` value (cleared input) keeps the previously stored value. The
        batch size is coerced to ``int``, because gr.Number may deliver floats.
        """
        current = state.sgd_params
        sgd_params = SgdHyperparameters(
            learning_rate=_keep_if_none(sgd_learning_rate, current.learning_rate),
            momentum=_keep_if_none(sgd_momentum, current.momentum),
            weight_decay=_keep_if_none(sgd_weight_decay, current.weight_decay),
            batch_size=int(_keep_if_none(sgd_batch_size, current.batch_size)),
        )
        return state.update(sgd_params=sgd_params)

    def update_adam_hyperparameters(
        self,
        state: Hyperparameters,
        adam_learning_rate: float,
        adam_beta1: float,
        adam_beta2: float,
        adam_weight_decay: float,
        adam_batch_size: int,
    ) -> Hyperparameters:
        """Return *state* with its Adam parameters replaced by the submitted UI values.

        A ``None`` value (cleared input) keeps the previously stored value. The
        batch size is coerced to ``int``, because gr.Number may deliver floats.
        """
        current = state.adam_params
        adam_params = AdamHyperparameters(
            learning_rate=_keep_if_none(adam_learning_rate, current.learning_rate),
            beta1=_keep_if_none(adam_beta1, current.beta1),
            beta2=_keep_if_none(adam_beta2, current.beta2),
            weight_decay=_keep_if_none(adam_weight_decay, current.weight_decay),
            batch_size=int(_keep_if_none(adam_batch_size, current.batch_size)),
        )
        return state.update(adam_params=adam_params)

    @staticmethod
    def _number_inputs(params) -> dict:
        """Create one gr.Number per dataclass field of *params* in a row, keyed by field name."""
        components = {}
        with gr.Row():
            for f in fields(params):
                # f.type is the class itself, or its name if annotations are postponed.
                is_int = f.type in (int, "int")
                components[f.name] = gr.Number(
                    value=getattr(params, f.name),
                    label=f.name.replace("_", " ").title(),
                    interactive=True,
                    # precision=0 makes Gradio round and return an int.
                    precision=0 if is_int else None,
                )
        return components

    def build(self, state: gr.State):
        """Lay out the optimizer selector and per-optimizer inputs, and wire events to *state*.

        Must be called inside an active Gradio Blocks context.
        """
        hyper = state.value
        with gr.Column():
            optimizer_select = gr.Dropdown(
                choices=["SGD", "Adam"],
                value=hyper.optimizer,
                label="Optimizer",
                interactive=True,
            )
            with gr.Group(visible=(hyper.optimizer == "SGD")) as sgd_box:
                sgd_components = self._number_inputs(hyper.sgd_params)
            with gr.Group(visible=(hyper.optimizer == "Adam")) as adam_box:
                adam_components = self._number_inputs(hyper.adam_params)

        optimizer_select.change(
            fn=self.update_optimizer_type,
            inputs=[state, optimizer_select],
            outputs=[state, sgd_box, adam_box],
        )

        # Order must match the positional parameters of the update handlers.
        sgd_inputs = [
            state,
            sgd_components["learning_rate"],
            sgd_components["momentum"],
            sgd_components["weight_decay"],
            sgd_components["batch_size"],
        ]
        adam_inputs = [
            state,
            adam_components["learning_rate"],
            adam_components["beta1"],
            adam_components["beta2"],
            adam_components["weight_decay"],
            adam_components["batch_size"],
        ]

        # NOTE(review): values commit to state only on Enter (submit); an edit
        # left without pressing Enter is not stored. Consider also listening
        # to .blur if that is unintended.
        for component in sgd_components.values():
            component.submit(
                fn=self.update_sgd_hyperparameters,
                inputs=sgd_inputs,
                outputs=[state],
            )
        for component in adam_components.values():
            component.submit(
                fn=self.update_adam_hyperparameters,
                inputs=adam_inputs,
                outputs=[state],
            )