Spaces:
Sleeping
Sleeping
File size: 5,175 Bytes
b38e4c6 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 | from dataclasses import dataclass, fields
import gradio as gr
@dataclass(frozen=True)
class SgdHyperparameters:
    """Read-only bundle of settings for the SGD optimizer.

    Because the dataclass is frozen, instances cannot be modified after
    construction and are hashable, so one default instance may be shared.
    Positional construction, like ``fields()``, follows the declaration
    order below: learning_rate, momentum, weight_decay, batch_size.
    """

    learning_rate: float = 1e-2
    momentum: float = 0.0
    weight_decay: float = 0.0
    batch_size: int = 32
@dataclass(frozen=True)
class AdamHyperparameters:
    """Read-only bundle of settings for the Adam optimizer.

    Frozen, hence immutable and hashable; a single default instance can be
    shared safely. Positional construction, like ``fields()``, follows the
    declaration order below: learning_rate, beta1, beta2, weight_decay,
    batch_size.
    """

    learning_rate: float = 1e-3
    beta1: float = 0.9
    beta2: float = 0.999
    weight_decay: float = 0.0
    batch_size: int = 32
class Hyperparameters:
    """The selected optimizer name plus one parameter set per optimizer.

    Both parameter sets are kept regardless of which optimizer is selected,
    so switching ``optimizer`` back and forth does not lose values.
    Treated as a value object: ``update`` returns a new instance instead of
    mutating this one, and equality/hashing are based on the three fields.
    """

    # Keyword names accepted by ``update``.
    _UPDATABLE = ("optimizer", "sgd_params", "adam_params")

    def __init__(
        self,
        optimizer: str = "SGD",
        sgd_params: SgdHyperparameters = SgdHyperparameters(),
        adam_params: AdamHyperparameters = AdamHyperparameters(),
    ):
        # Sharing the default parameter instances between calls is safe:
        # both dataclasses are frozen.
        self.optimizer = optimizer
        self.sgd_params = sgd_params
        self.adam_params = adam_params

    def update(self, **kwargs):
        """Return a new ``Hyperparameters`` with the given fields replaced.

        Accepted keywords are ``optimizer``, ``sgd_params`` and
        ``adam_params``; any field not passed is carried over from ``self``.

        Raises:
            TypeError: if a keyword other than the accepted ones is passed,
                so a misspelled field name is reported instead of ignored.
        """
        unknown = set(kwargs).difference(self._UPDATABLE)
        if unknown:
            raise TypeError(
                "update() got unexpected keyword argument(s): "
                + ", ".join(sorted(unknown))
            )
        return Hyperparameters(
            optimizer=kwargs.get("optimizer", self.optimizer),
            sgd_params=kwargs.get("sgd_params", self.sgd_params),
            adam_params=kwargs.get("adam_params", self.adam_params),
        )

    def _key(self):
        """Tuple of all fields; the single source for equality and hashing."""
        return (self.optimizer, self.sgd_params, self.adam_params)

    def __eq__(self, other):
        # Value equality, consistent with __hash__: without this, equality
        # fell back to identity and equal configurations never matched as
        # dict/cache keys because update() always creates a new instance.
        if not isinstance(other, Hyperparameters):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        return hash(self._key())

    def __repr__(self):
        return (
            f"{type(self).__name__}(optimizer={self.optimizer!r}, "
            f"sgd_params={self.sgd_params!r}, adam_params={self.adam_params!r})"
        )

    @property
    def batch_size(self):
        """Batch size taken from the currently selected optimizer's parameters.

        Raises:
            ValueError: if ``optimizer`` is neither "SGD" nor "Adam".
        """
        if self.optimizer == "SGD":
            return self.sgd_params.batch_size
        if self.optimizer == "Adam":
            return self.adam_params.batch_size
        raise ValueError(f"Unknown optimizer: {self.optimizer}")
class HyperparametersView:
    """Gradio controls for editing a ``Hyperparameters`` value held in ``gr.State``.

    Every handler returns a new ``Hyperparameters`` (via ``state.update``)
    rather than mutating the state object in place.
    """

    def update_optimizer_type(self, state: Hyperparameters, optimizer: str):
        """Store the chosen optimizer and show only its parameter group.

        Returns:
            ``(new_state, sgd_box_update, adam_box_update)``, matching the
            outputs wired to ``optimizer_select.change`` in ``build``.
        """
        state = state.update(optimizer=optimizer)
        return (
            state,
            gr.update(visible=(optimizer == "SGD")),
            gr.update(visible=(optimizer == "Adam")),
        )

    def update_sgd_hyperparameters(
        self,
        state: Hyperparameters,
        sgd_learning_rate: float,
        sgd_momentum: float,
        sgd_weight_decay: float,
        sgd_batch_size: int,
    ):
        """Replace the SGD parameter set with the values from the SGD inputs.

        Returns:
            The updated ``Hyperparameters``.
        """
        sgd_params = SgdHyperparameters(
            learning_rate=sgd_learning_rate,
            momentum=sgd_momentum,
            weight_decay=sgd_weight_decay,
            batch_size=sgd_batch_size,
        )
        state = state.update(sgd_params=sgd_params)
        return state

    def update_adam_hyperparameters(
        self,
        state: Hyperparameters,
        adam_learning_rate: float,
        adam_beta1: float,
        adam_beta2: float,
        adam_weight_decay: float,
        adam_batch_size: int,
    ):
        """Replace the Adam parameter set with the values from the Adam inputs.

        Returns:
            The updated ``Hyperparameters``.
        """
        adam_params = AdamHyperparameters(
            learning_rate=adam_learning_rate,
            beta1=adam_beta1,
            beta2=adam_beta2,
            weight_decay=adam_weight_decay,
            batch_size=adam_batch_size,
        )
        state = state.update(adam_params=adam_params)
        return state

    @staticmethod
    def _number_inputs(params):
        """Create one ``gr.Number`` per dataclass field of *params*, keyed by field name.

        Int-annotated fields get ``precision=0`` so Gradio converts the value
        to ``int``. With the default ``precision=None``, ``gr.Number`` returns
        floats, and e.g. ``batch_size`` would reach the state as ``32.0``.
        """
        components = {}
        for f in fields(params):
            # f.type is the class object, or the string "int" when the module
            # uses postponed annotations.
            is_int = f.type in (int, "int")
            components[f.name] = gr.Number(
                value=getattr(params, f.name),
                label=f.name.replace("_", " ").title(),
                precision=0 if is_int else None,
                interactive=True,
            )
        return components

    def build(self, state: gr.State):
        """Lay out the optimizer selector and per-optimizer inputs, then wire events.

        ``state`` must already hold a ``Hyperparameters``; its current value
        seeds the initial widget values and group visibility.
        NOTE(review): presumably called inside an active ``gr.Blocks`` context
        so the created components have somewhere to attach — confirm at caller.
        """
        hyper = state.value
        with gr.Column():
            optimizer_select = gr.Dropdown(
                choices=["SGD", "Adam"],
                value=hyper.optimizer,
                label="Optimizer",
                interactive=True,
            )
            with gr.Group(visible=(hyper.optimizer == "SGD")) as sgd_box:
                with gr.Row():
                    sgd_components = self._number_inputs(hyper.sgd_params)
            with gr.Group(visible=(hyper.optimizer == "Adam")) as adam_box:
                with gr.Row():
                    adam_components = self._number_inputs(hyper.adam_params)

        optimizer_select.change(
            fn=self.update_optimizer_type,
            inputs=[state, optimizer_select],
            outputs=[state, sgd_box, adam_box],
        )

        # Each handler takes its values positionally, so the inputs are listed
        # explicitly in the handler's parameter order. They are built once and
        # shared by every component's event.
        sgd_inputs = [
            state,
            sgd_components["learning_rate"],
            sgd_components["momentum"],
            sgd_components["weight_decay"],
            sgd_components["batch_size"],
        ]
        adam_inputs = [
            state,
            adam_components["learning_rate"],
            adam_components["beta1"],
            adam_components["beta2"],
            adam_components["weight_decay"],
            adam_components["batch_size"],
        ]

        # NOTE(review): `submit` fires when the user presses Enter in a field;
        # a value edited without pressing Enter is not written to the state.
        for component in sgd_components.values():
            component.submit(
                fn=self.update_sgd_hyperparameters,
                inputs=sgd_inputs,
                outputs=[state],
            )
        for component in adam_components.values():
            component.submit(
                fn=self.update_adam_hyperparameters,
                inputs=adam_inputs,
                outputs=[state],
            )
|