joel-woodfield commited on
Commit
b38e4c6
·
1 Parent(s): 41b207c

Add model training options and show model predictions on plot

Browse files
Files changed (2) hide show
  1. hyperparameters.py +161 -0
  2. mlp_visualizer.py +165 -12
hyperparameters.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, fields
2
+ import gradio as gr
3
+
4
+
5
+ @dataclass(frozen=True)
6
+ class SgdHyperparameters:
7
+ learning_rate: float = 0.01
8
+ momentum: float = 0.0
9
+ weight_decay: float = 0.0
10
+ batch_size: int = 32
11
+
12
+
13
+ @dataclass(frozen=True)
14
+ class AdamHyperparameters:
15
+ learning_rate: float = 0.001
16
+ beta1: float = 0.9
17
+ beta2: float = 0.999
18
+ weight_decay: float = 0.0
19
+ batch_size: int = 32
20
+
21
+
22
+ class Hyperparameters:
23
+ def __init__(
24
+ self,
25
+ optimizer: str = "SGD",
26
+ sgd_params: SgdHyperparameters = SgdHyperparameters(),
27
+ adam_params: AdamHyperparameters = AdamHyperparameters(),
28
+ ):
29
+ self.optimizer = optimizer
30
+ self.sgd_params = sgd_params
31
+ self.adam_params = adam_params
32
+
33
+ def update(self, **kwargs):
34
+ return Hyperparameters(
35
+ optimizer=kwargs.get("optimizer", self.optimizer),
36
+ sgd_params=kwargs.get("sgd_params", self.sgd_params),
37
+ adam_params=kwargs.get("adam_params", self.adam_params),
38
+ )
39
+
40
+ def __hash__(self):
41
+ return hash((self.optimizer, self.sgd_params, self.adam_params))
42
+
43
+ @property
44
+ def batch_size(self):
45
+ if self.optimizer == "SGD":
46
+ return self.sgd_params.batch_size
47
+ elif self.optimizer == "Adam":
48
+ return self.adam_params.batch_size
49
+ else:
50
+ raise ValueError(f"Unknown optimizer: {self.optimizer}")
51
+
52
+
53
+ class HyperparametersView:
54
+ def update_optimizer_type(self, state: Hyperparameters, optimizer: str):
55
+ state = state.update(optimizer=optimizer)
56
+ return (
57
+ state,
58
+ gr.update(visible=(optimizer == "SGD")),
59
+ gr.update(visible=(optimizer == "Adam")),
60
+ )
61
+
62
+ def update_sgd_hyperparameters(
63
+ self,
64
+ state: Hyperparameters,
65
+ sgd_learning_rate: float,
66
+ sgd_momentum: float,
67
+ sgd_weight_decay: float,
68
+ sgd_batch_size: int,
69
+ ):
70
+ sgd_params = SgdHyperparameters(
71
+ learning_rate=sgd_learning_rate,
72
+ momentum=sgd_momentum,
73
+ weight_decay=sgd_weight_decay,
74
+ batch_size=sgd_batch_size,
75
+ )
76
+ state = state.update(sgd_params=sgd_params)
77
+ return state
78
+
79
+ def update_adam_hyperparameters(
80
+ self,
81
+ state: Hyperparameters,
82
+ adam_learning_rate: float,
83
+ adam_beta1: float,
84
+ adam_beta2: float,
85
+ adam_weight_decay: float,
86
+ adam_batch_size: int,
87
+ ):
88
+ adam_params = AdamHyperparameters(
89
+ learning_rate=adam_learning_rate,
90
+ beta1=adam_beta1,
91
+ beta2=adam_beta2,
92
+ weight_decay=adam_weight_decay,
93
+ batch_size=adam_batch_size,
94
+ )
95
+ state = state.update(adam_params=adam_params)
96
+ return state
97
+
98
+ def build(self, state: gr.State):
99
+ hyper = state.value
100
+ with gr.Column():
101
+ optimizer_select = gr.Dropdown(
102
+ choices=["SGD", "Adam"],
103
+ value=hyper.optimizer,
104
+ label="Optimizer",
105
+ interactive=True,
106
+ )
107
+
108
+ with gr.Group(visible=(hyper.optimizer == "SGD")) as sgd_box:
109
+ sgd_components = {}
110
+ with gr.Row():
111
+ for f in fields(hyper.sgd_params):
112
+ sgd_components[f.name] = gr.Number(
113
+ value=getattr(hyper.sgd_params, f.name),
114
+ label=f.name.replace("_", " ").title(),
115
+ interactive=True,
116
+ )
117
+
118
+ with gr.Group(visible=(hyper.optimizer == "Adam")) as adam_box:
119
+ adam_components = {}
120
+ with gr.Row():
121
+ for f in fields(hyper.adam_params):
122
+ adam_components[f.name] = gr.Number(
123
+ value=getattr(hyper.adam_params, f.name),
124
+ label=f.name.replace("_", " ").title(),
125
+ interactive=True,
126
+ )
127
+
128
+ optimizer_select.change(
129
+ fn=self.update_optimizer_type,
130
+ inputs=[state, optimizer_select],
131
+ outputs=[state, sgd_box, adam_box],
132
+ )
133
+
134
+ for name, component in sgd_components.items():
135
+ component.submit(
136
+ fn=self.update_sgd_hyperparameters,
137
+ inputs=[
138
+ state,
139
+ sgd_components["learning_rate"],
140
+ sgd_components["momentum"],
141
+ sgd_components["weight_decay"],
142
+ sgd_components["batch_size"],
143
+ ],
144
+ outputs=[state],
145
+ )
146
+
147
+ for name, component in adam_components.items():
148
+ component.submit(
149
+ fn=self.update_adam_hyperparameters,
150
+ inputs=[
151
+ state,
152
+ adam_components["learning_rate"],
153
+ adam_components["beta1"],
154
+ adam_components["beta2"],
155
+ adam_components["weight_decay"],
156
+ adam_components["batch_size"],
157
+ ],
158
+ outputs=[state],
159
+ )
160
+
161
+
mlp_visualizer.py CHANGED
@@ -1,5 +1,5 @@
1
  from collections import deque
2
- from dataclasses import dataclass, replace
3
  import functools
4
  from pathlib import Path
5
  import pickle
@@ -31,6 +31,68 @@ logger = logging.getLogger("ELVIS")
31
 
32
  from architecture import Architecture, ArchitectureView
33
  from dataset import Dataset, DatasetView, get_function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
 
36
  class MlpVisualizer:
@@ -45,7 +107,7 @@ class MlpVisualizer:
45
  display: none;
46
  }"""
47
 
48
- def plot(self, dataset: Dataset, architecture: Architecture) -> Image.Image:
49
  print("Plotting")
50
  t1 = time.time()
51
  fig = plt.figure(figsize=(self.canvas_width / 100., self.canvas_height / 100.0), dpi=100)
@@ -57,7 +119,7 @@ class MlpVisualizer:
57
  if dataset.mode == "generate":
58
  x_test, y_test = get_function(dataset.function, xlim=(-2, 2), nsample=100)
59
 
60
- # y_pred = self.model(torch.from_numpy(x_test).float()).detach().numpy()
61
 
62
  # plot
63
  fig, ax = plt.subplots(figsize=(8, 8))
@@ -76,7 +138,7 @@ class MlpVisualizer:
76
  if dataset.mode == "generate":
77
  plt.plot(x_test.flatten(), y_test, label='true function', color=self.plot_cmap(1))
78
 
79
- if False:
80
  plt.plot(x_test.flatten(), y_pred, linestyle="--", label='prediction', color=self.plot_cmap(2))
81
 
82
  plt.legend()
@@ -92,6 +154,77 @@ class MlpVisualizer:
92
 
93
  return img
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  def launch(self):
96
  # build the Gradio interface
97
  with gr.Blocks(css=self.css) as demo:
@@ -101,12 +234,17 @@ class MlpVisualizer:
101
  # states
102
  dataset = gr.State(Dataset())
103
  architecture = gr.State(Architecture())
 
 
 
 
 
104
 
105
  # GUI elements and layout
106
  with gr.Row():
107
  with gr.Column(scale=2):
108
  canvas = gr.Image(
109
- value=self.plot(dataset.value, architecture.value),
110
  show_download_button=False,
111
  container=True,
112
  )
@@ -116,22 +254,37 @@ class MlpVisualizer:
116
  dataset_view = DatasetView()
117
  dataset_view.build(state=dataset)
118
  dataset.change(
119
- fn=self.plot,
120
- inputs=[dataset],
121
- outputs=[canvas],
122
  )
123
 
124
  with gr.Tab("Architecture"):
125
  architecture_view = ArchitectureView()
126
  architecture_view.build(state=architecture)
127
  architecture.change(
128
- fn=self.plot,
129
- inputs=[dataset, architecture],
130
- outputs=[canvas],
131
  )
132
 
133
  with gr.Tab("Train"):
134
- gr.Markdown("HI")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  with gr.Tab("Plot"):
136
  gr.Markdown("HI")
137
  with gr.Tab("Export"):
 
1
  from collections import deque
2
+ from dataclasses import dataclass, fields
3
  import functools
4
  from pathlib import Path
5
  import pickle
 
31
 
32
  from architecture import Architecture, ArchitectureView
33
  from dataset import Dataset, DatasetView, get_function
34
+ from hyperparameters import Hyperparameters, HyperparametersView
35
+
36
+
37
+ @dataclass
38
+ class TrainState:
39
+ model: nn.Module
40
+ optimizer: torch.optim.Optimizer
41
+
42
+
43
+ def init_model(architecture: Architecture) -> nn.Module:
44
+ input_size = 1
45
+ output_size = 1
46
+
47
+ layers = []
48
+ for hidden_units, activation in zip(architecture.hidden_units, architecture.activations):
49
+ layers.append(nn.Linear(input_size, hidden_units))
50
+
51
+ if activation == "ReLU":
52
+ layers.append(nn.ReLU())
53
+ elif activation == "Sigmoid":
54
+ layers.append(nn.Sigmoid())
55
+ elif activation == "Tanh":
56
+ layers.append(nn.Tanh())
57
+ elif activation == "LeakyReLU":
58
+ layers.append(nn.LeakyReLU())
59
+ elif activation == "ELU":
60
+ layers.append(nn.ELU())
61
+ elif activation == "GELU":
62
+ layers.append(nn.GELU())
63
+ elif activation == "Identity":
64
+ layers.append(nn.Identity())
65
+ else:
66
+ raise ValueError(f"Unknown activation: {activation}")
67
+
68
+ input_size = hidden_units
69
+
70
+ layers.append(nn.Linear(input_size, output_size))
71
+ model = nn.Sequential(*layers)
72
+ return model
73
+
74
+
75
+ def init_optimizer(
76
+ model: nn.Module,
77
+ hyperparameters: Hyperparameters,
78
+ ) -> torch.optim.Optimizer:
79
+ if hyperparameters.optimizer == "SGD":
80
+ opt = torch.optim.SGD(
81
+ model.parameters(),
82
+ lr=hyperparameters.sgd_params.learning_rate,
83
+ momentum=hyperparameters.sgd_params.momentum,
84
+ weight_decay=hyperparameters.sgd_params.weight_decay,
85
+ )
86
+ elif hyperparameters.optimizer == "Adam":
87
+ opt = torch.optim.Adam(
88
+ model.parameters(),
89
+ lr=hyperparameters.adam_params.learning_rate,
90
+ betas=(hyperparameters.adam_params.beta1, hyperparameters.adam_params.beta2),
91
+ weight_decay=hyperparameters.adam_params.weight_decay,
92
+ )
93
+ else:
94
+ raise ValueError(f"Unknown optimizer: {hyperparameters.optimizer}")
95
+ return opt
96
 
97
 
98
  class MlpVisualizer:
 
107
  display: none;
108
  }"""
109
 
110
+ def plot(self, dataset: Dataset, train_state: TrainState) -> Image.Image:
111
  print("Plotting")
112
  t1 = time.time()
113
  fig = plt.figure(figsize=(self.canvas_width / 100., self.canvas_height / 100.0), dpi=100)
 
119
  if dataset.mode == "generate":
120
  x_test, y_test = get_function(dataset.function, xlim=(-2, 2), nsample=100)
121
 
122
+ y_pred = train_state.model(torch.from_numpy(x_test).float()).detach().numpy()
123
 
124
  # plot
125
  fig, ax = plt.subplots(figsize=(8, 8))
 
138
  if dataset.mode == "generate":
139
  plt.plot(x_test.flatten(), y_test, label='true function', color=self.plot_cmap(1))
140
 
141
+ if True:
142
  plt.plot(x_test.flatten(), y_pred, linestyle="--", label='prediction', color=self.plot_cmap(2))
143
 
144
  plt.legend()
 
154
 
155
  return img
156
 
157
+ def update_dataset(
158
+ self,
159
+ dataset: Dataset,
160
+ architecture: Architecture,
161
+ hyperparameters: Hyperparameters,
162
+ ):
163
+ print("Updating dataset")
164
+ new_model = init_model(architecture)
165
+ new_optimizer = init_optimizer(new_model, hyperparameters)
166
+ new_train_state = TrainState(new_model, new_optimizer)
167
+ new_canvas = self.plot(dataset, new_train_state)
168
+ return new_canvas, new_train_state
169
+
170
+ def update_architecture(
171
+ self,
172
+ dataset: Dataset,
173
+ architecture: Architecture,
174
+ hyperparameters: Hyperparameters,
175
+ ):
176
+ print("Updating architecture")
177
+ new_model = init_model(architecture)
178
+ new_optimizer = init_optimizer(new_model, hyperparameters)
179
+ new_train_state = TrainState(new_model, new_optimizer)
180
+ new_canvas = self.plot(dataset, new_train_state)
181
+ return new_canvas, new_train_state
182
+
183
+ def update_hyperparameters(
184
+ self,
185
+ dataset: Dataset,
186
+ architecture: Architecture,
187
+ hyperparameters: Hyperparameters,
188
+ ):
189
+ print("Updating hyperparameters")
190
+ new_model = init_model(architecture)
191
+ new_optimizer = init_optimizer(new_model, hyperparameters)
192
+ new_train_state = TrainState(new_model, new_optimizer)
193
+ new_canvas = self.plot(dataset, new_train_state)
194
+ return new_canvas, new_train_state
195
+
196
+ def train_step(
197
+ self,
198
+ dataset: Dataset,
199
+ hyperparameters: Hyperparameters,
200
+ train_state: TrainState,
201
+ ):
202
+ print("Training step")
203
+ model = train_state.model
204
+ optimizer = train_state.optimizer
205
+ batch_size = hyperparameters.batch_size
206
+
207
+ model.train()
208
+ x_train = torch.from_numpy(dataset.x).float()
209
+ y_train = torch.from_numpy(dataset.y).float()
210
+
211
+ if batch_size < x_train.shape[0]:
212
+ indices = torch.randperm(x_train.shape[0])[:batch_size]
213
+ x_train = x_train[indices]
214
+ y_train = y_train[indices]
215
+
216
+ y_pred = model(x_train)
217
+ loss = nn.MSELoss()(y_pred.flatten(), y_train)
218
+ optimizer.zero_grad()
219
+ loss.backward()
220
+ optimizer.step()
221
+
222
+ print(f"Training loss: {loss.item():.4f}")
223
+
224
+ new_canvas = self.plot(dataset, train_state)
225
+
226
+ return new_canvas, train_state
227
+
228
  def launch(self):
229
  # build the Gradio interface
230
  with gr.Blocks(css=self.css) as demo:
 
234
  # states
235
  dataset = gr.State(Dataset())
236
  architecture = gr.State(Architecture())
237
+ hyperparameters = gr.State(Hyperparameters())
238
+
239
+ model = init_model(architecture.value)
240
+ optimizer = init_optimizer(model, hyperparameters.value)
241
+ train_state = gr.State(TrainState(model, optimizer))
242
 
243
  # GUI elements and layout
244
  with gr.Row():
245
  with gr.Column(scale=2):
246
  canvas = gr.Image(
247
+ value=self.plot(dataset.value, train_state.value),
248
  show_download_button=False,
249
  container=True,
250
  )
 
254
  dataset_view = DatasetView()
255
  dataset_view.build(state=dataset)
256
  dataset.change(
257
+ fn=self.update_dataset,
258
+ inputs=[dataset, architecture, hyperparameters],
259
+ outputs=[canvas, train_state],
260
  )
261
 
262
  with gr.Tab("Architecture"):
263
  architecture_view = ArchitectureView()
264
  architecture_view.build(state=architecture)
265
  architecture.change(
266
+ fn=self.update_architecture,
267
+ inputs=[dataset, architecture, hyperparameters],
268
+ outputs=[canvas, train_state],
269
  )
270
 
271
  with gr.Tab("Train"):
272
+ hyperparameters_view = HyperparametersView()
273
+ hyperparameters_view.build(state=hyperparameters)
274
+ hyperparameters.change(
275
+ fn=self.update_hyperparameters,
276
+ inputs=[dataset, architecture, hyperparameters],
277
+ outputs=[canvas, train_state],
278
+ )
279
+
280
+ train_button = gr.Button("Train 1 step")
281
+ train_button.click(
282
+ fn=self.train_step,
283
+ inputs=[dataset, hyperparameters, train_state],
284
+ outputs=[canvas, train_state],
285
+ )
286
+
287
+
288
  with gr.Tab("Plot"):
289
  gr.Markdown("HI")
290
  with gr.Tab("Export"):