Spaces:
Sleeping
Sleeping
| from dataclasses import dataclass, fields | |
| import gradio as gr | |
@dataclass(frozen=True)
class SgdHyperparameters:
    """Settings for the SGD optimizer.

    A frozen dataclass: instances are hashable and immutable, so one instance
    can safely serve as a shared default. To change a value, build a new
    instance (e.g. with ``dataclasses.replace``). Field order is also the
    order in which ``dataclasses.fields`` yields them.
    """

    learning_rate: float = 0.01
    momentum: float = 0.0
    weight_decay: float = 0.0
    batch_size: int = 32
@dataclass(frozen=True)
class AdamHyperparameters:
    """Settings for the Adam optimizer.

    A frozen dataclass: instances are hashable and immutable, so one instance
    can safely serve as a shared default. To change a value, build a new
    instance (e.g. with ``dataclasses.replace``). Field order is also the
    order in which ``dataclasses.fields`` yields them.
    """

    learning_rate: float = 0.001
    beta1: float = 0.9
    beta2: float = 0.999
    weight_decay: float = 0.0
    batch_size: int = 32
class Hyperparameters:
    """The selected optimizer plus one parameter set per supported optimizer.

    Both parameter sets are kept, so switching optimizers does not discard
    the values entered for the other one. ``update`` returns a new instance
    instead of mutating this one. Equality and hashing are both based on
    ``(optimizer, sgd_params, adam_params)``.
    """

    def __init__(
        self,
        optimizer: str = "SGD",
        sgd_params: "SgdHyperparameters | None" = None,
        adam_params: "AdamHyperparameters | None" = None,
    ):
        """Create a hyperparameter set.

        Args:
            optimizer: Name of the active optimizer; ``batch_size`` accepts
                ``"SGD"`` or ``"Adam"``.
            sgd_params: SGD settings; a fresh ``SgdHyperparameters()`` when
                omitted or ``None``.
            adam_params: Adam settings; a fresh ``AdamHyperparameters()``
                when omitted or ``None``.
        """
        self.optimizer = optimizer
        # Defaults are built per call rather than in the signature, where a
        # single default object would be shared by every instance.
        self.sgd_params = SgdHyperparameters() if sgd_params is None else sgd_params
        self.adam_params = AdamHyperparameters() if adam_params is None else adam_params

    def update(self, **kwargs):
        """Return a new instance with the given fields replaced.

        Recognised keywords are ``optimizer``, ``sgd_params`` and
        ``adam_params``; omitted ones keep this instance's value. Any other
        keyword is ignored.
        """
        return Hyperparameters(
            optimizer=kwargs.get("optimizer", self.optimizer),
            sgd_params=kwargs.get("sgd_params", self.sgd_params),
            adam_params=kwargs.get("adam_params", self.adam_params),
        )

    def _key(self):
        """Tuple of all fields; the single source for ``__eq__`` and ``__hash__``."""
        return (self.optimizer, self.sgd_params, self.adam_params)

    def __eq__(self, other):
        # Must agree with __hash__: equal objects need equal hashes.
        if not isinstance(other, Hyperparameters):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        # Requires both parameter sets to be hashable.
        return hash(self._key())

    def __repr__(self):
        return (
            f"{type(self).__name__}(optimizer={self.optimizer!r}, "
            f"sgd_params={self.sgd_params!r}, adam_params={self.adam_params!r})"
        )

    def batch_size(self):
        """Return the batch size of the currently selected optimizer.

        Raises:
            ValueError: If ``optimizer`` is neither ``"SGD"`` nor ``"Adam"``.
        """
        if self.optimizer == "SGD":
            return self.sgd_params.batch_size
        elif self.optimizer == "Adam":
            return self.adam_params.batch_size
        else:
            raise ValueError(f"Unknown optimizer: {self.optimizer}")
class HyperparametersView:
    """Gradio controls for editing a ``Hyperparameters`` value held in ``gr.State``.

    ``build`` lays out an optimizer dropdown and one group of number inputs
    per optimizer, then wires their events. Every handler returns a new
    ``Hyperparameters`` object rather than mutating the current one.
    """

    def update_optimizer_type(self, state: Hyperparameters, optimizer: str):
        """Switch the active optimizer and toggle which parameter group is visible.

        Returns:
            ``(new_state, sgd_group_update, adam_group_update)``, matching the
            ``outputs=[state, sgd_box, adam_box]`` wiring in ``build``.
        """
        state = state.update(optimizer=optimizer)
        return (
            state,
            gr.update(visible=(optimizer == "SGD")),
            gr.update(visible=(optimizer == "Adam")),
        )

    @staticmethod
    def _as_int(value):
        """Coerce a number-input value to ``int``; ``None`` is passed through."""
        # gr.Number yields floats unless precision=0 (e.g. 32.0). Integer
        # fields such as batch_size should be stored as ints.
        return None if value is None else int(round(value))

    def update_sgd_hyperparameters(
        self,
        state: Hyperparameters,
        sgd_learning_rate: float,
        sgd_momentum: float,
        sgd_weight_decay: float,
        sgd_batch_size: int,
    ):
        """Return a new state whose SGD settings come from the input values."""
        sgd_params = SgdHyperparameters(
            learning_rate=sgd_learning_rate,
            momentum=sgd_momentum,
            weight_decay=sgd_weight_decay,
            batch_size=self._as_int(sgd_batch_size),
        )
        return state.update(sgd_params=sgd_params)

    def update_adam_hyperparameters(
        self,
        state: Hyperparameters,
        adam_learning_rate: float,
        adam_beta1: float,
        adam_beta2: float,
        adam_weight_decay: float,
        adam_batch_size: int,
    ):
        """Return a new state whose Adam settings come from the input values."""
        adam_params = AdamHyperparameters(
            learning_rate=adam_learning_rate,
            beta1=adam_beta1,
            beta2=adam_beta2,
            weight_decay=adam_weight_decay,
            batch_size=self._as_int(adam_batch_size),
        )
        return state.update(adam_params=adam_params)

    @staticmethod
    def _build_param_inputs(params):
        """Create a row of labelled ``gr.Number`` inputs, one per dataclass field.

        Returns:
            Dict mapping field name to its component, in field order.
        """
        components = {}
        with gr.Row():
            for f in fields(params):
                # f.type is the class `int`, or the string "int" when
                # annotations are postponed.
                is_int = f.type in (int, "int")
                components[f.name] = gr.Number(
                    value=getattr(params, f.name),
                    label=f.name.replace("_", " ").title(),
                    interactive=True,
                    # precision=0 makes Gradio round and return an int.
                    precision=0 if is_int else None,
                )
        return components

    def build(self, state: gr.State):
        """Lay out the controls and wire their events to *state*.

        The initial widget values come from ``state.value``. A number field's
        change is written to the state when the user submits that field
        (``submit`` event).
        """
        hyper = state.value
        with gr.Column():
            optimizer_select = gr.Dropdown(
                choices=["SGD", "Adam"],
                value=hyper.optimizer,
                label="Optimizer",
                interactive=True,
            )
            with gr.Group(visible=(hyper.optimizer == "SGD")) as sgd_box:
                sgd_components = self._build_param_inputs(hyper.sgd_params)
            with gr.Group(visible=(hyper.optimizer == "Adam")) as adam_box:
                adam_components = self._build_param_inputs(hyper.adam_params)

        optimizer_select.change(
            fn=self.update_optimizer_type,
            inputs=[state, optimizer_select],
            outputs=[state, sgd_box, adam_box],
        )

        # Inputs are listed explicitly so their order matches the handler's
        # positional parameters regardless of dataclass field order.
        sgd_inputs = [
            state,
            sgd_components["learning_rate"],
            sgd_components["momentum"],
            sgd_components["weight_decay"],
            sgd_components["batch_size"],
        ]
        for component in sgd_components.values():
            component.submit(
                fn=self.update_sgd_hyperparameters,
                inputs=sgd_inputs,
                outputs=[state],
            )

        adam_inputs = [
            state,
            adam_components["learning_rate"],
            adam_components["beta1"],
            adam_components["beta2"],
            adam_components["weight_decay"],
            adam_components["batch_size"],
        ]
        for component in adam_components.values():
            component.submit(
                fn=self.update_adam_hyperparameters,
                inputs=adam_inputs,
                outputs=[state],
            )