Joel Woodfield commited on
Commit
a13fdc8
·
1 Parent(s): 4c59025

Display current training loss

Browse files
Files changed (1) hide show
  1. mlp_visualizer.py +47 -29
mlp_visualizer.py CHANGED
@@ -195,10 +195,10 @@ class MlpVisualizer:
195
  # do not initialise here, otherwise gradio will make it not work
196
  # self.param_components = {}
197
 
198
- self.model, self.optimizer = self.init_model()
 
199
  self.num_steps_trained = 0
200
 
201
- self.criterion = nn.MSELoss()
202
 
203
  self.plot_options = {
204
  "show_training_data": True,
@@ -272,7 +272,15 @@ class MlpVisualizer:
272
 
273
  self.num_steps_trained = 0
274
 
275
- return model, optimizer
 
 
 
 
 
 
 
 
276
 
277
  def plot(self):
278
  '''
@@ -321,11 +329,11 @@ class MlpVisualizer:
321
  self.data_options["seed"] += 1
322
  self.x_train, self.y_train = self.generate_data()
323
  self.reset_model()
324
- return self.plot(), self.num_steps_trained
325
 
326
  def reset_model(self):
327
- self.model, self.optimizer = self.init_model()
328
- return self.plot(), self.num_steps_trained
329
 
330
  def update_data_options(self, **kwargs):
331
  for key, value in kwargs.items():
@@ -347,9 +355,9 @@ class MlpVisualizer:
347
 
348
  if "nsample" in kwargs:
349
  slider_update = gr.update(maximum=self.x_train.shape[0], value=min(self.basic_train_hparams["batch_size"], self.x_train.shape[0]))
350
- return self.plot(), slider_update, self.num_steps_trained
351
 
352
- return self.plot(), self.num_steps_trained
353
 
354
  def update_plot_options(self, **kwargs):
355
  for key, value in kwargs.items():
@@ -362,9 +370,9 @@ class MlpVisualizer:
362
  self.architecture_options["activations"] = activations
363
 
364
  # reset model
365
- self.model, self.optimizer = self.init_model()
366
 
367
- return self.plot(), self.num_steps_trained
368
 
369
  def update_basic_train_hparams(self, **kwargs):
370
  for key, value in kwargs.items():
@@ -372,9 +380,9 @@ class MlpVisualizer:
372
  self.basic_train_hparams[key] = value
373
 
374
  # reset model
375
- self.model, self.optimizer = self.init_model()
376
 
377
- return self.plot(), self.num_steps_trained
378
 
379
  def update_optimizer(self, optimizer_name):
380
  self.basic_train_hparams["optimizer"] = optimizer_name
@@ -388,9 +396,9 @@ class MlpVisualizer:
388
  updates.append(gr.update(visible=is_visible))
389
 
390
  # reset model
391
- self.model, self.optimizer = self.init_model()
392
 
393
- return updates + [self.plot(), self.num_steps_trained]
394
 
395
  def build_optimizer_components(self):
396
  self.param_components = {}
@@ -414,8 +422,8 @@ class MlpVisualizer:
414
  self.optimizer_hparams[optimizer_name][param_name] = value
415
 
416
  # reset model and plot
417
- self.model, self.optimizer = self.init_model()
418
- return self.plot(), self.num_steps_trained
419
 
420
  def train_step(self):
421
  self.model.train()
@@ -429,10 +437,15 @@ class MlpVisualizer:
429
  loss.backward()
430
  self.optimizer.step()
431
 
432
- print(f"Training loss: {loss.item():.4f}")
433
  self.num_steps_trained += 1
434
 
435
- return self.plot(), self.num_steps_trained
 
 
 
 
 
 
436
 
437
  def launch(self):
438
  # build the Gradio interface
@@ -525,6 +538,11 @@ class MlpVisualizer:
525
  value=0,
526
  interactive=False,
527
  )
 
 
 
 
 
528
 
529
  train_button = gr.Button("Train Step")
530
  reset_model_button = gr.Button("Reset Model")
@@ -560,59 +578,59 @@ class MlpVisualizer:
560
  function_box.submit(
561
  fn=lambda function: self.update_data_options(function=function),
562
  inputs=function_box,
563
- outputs=[self.canvas, train_step_counter],
564
  )
565
  x_min.submit(
566
  fn=lambda xmin: self.update_data_options(x_min=xmin),
567
  inputs=x_min,
568
- outputs=[self.canvas, train_step_counter],
569
  )
570
  x_max.submit(
571
  fn=lambda xmax: self.update_data_options(x_max=xmax),
572
  inputs=x_max,
573
- outputs=[self.canvas, train_step_counter],
574
  )
575
  num_points_slider.change(
576
  fn=lambda nsample: self.update_data_options(nsample=nsample),
577
  inputs=num_points_slider,
578
- outputs=[self.canvas, batch_size_slider, train_step_counter],
579
  )
580
  noise_value.submit(
581
  fn=lambda sigma: self.update_data_options(sigma=sigma),
582
  inputs=noise_value,
583
- outputs=[self.canvas, train_step_counter],
584
  )
585
  regenerate_button.click(
586
  fn=self._update_data_seed,
587
- outputs=[self.canvas, train_step_counter],
588
  )
589
 
590
  # train options
591
  optimizer_radio.change(
592
  fn=self.update_optimizer,
593
  inputs=optimizer_radio,
594
- outputs=[*all_param_components, self.canvas, train_step_counter],
595
  )
596
  batch_size_slider.change(
597
  fn=lambda batch_size: self.update_basic_train_hparams(batch_size=batch_size),
598
  inputs=batch_size_slider,
599
- outputs=[self.canvas, train_step_counter],
600
  )
601
  train_button.click(
602
  fn=self.train_step,
603
- outputs=[self.canvas, train_step_counter],
604
  show_progress="hidden",
605
  )
606
  reset_model_button.click(
607
  fn=self.reset_model,
608
- outputs=[self.canvas, train_step_counter],
609
  )
610
  for opt_name, params in self.param_components.items():
611
  for param_name, comp in params.items():
612
  comp.submit(
613
  fn=functools.partial(self.update_hparam, optimizer_name=opt_name, param_name=param_name),
614
  inputs=[comp],
615
- outputs=[self.canvas, train_step_counter],
616
  )
617
 
618
  # plot options
 
195
  # do not initialise here, otherwise gradio will make it not work
196
  # self.param_components = {}
197
 
198
+ self.criterion = nn.MSELoss()
199
+ self.model, self.optimizer, self.train_loss = self.init_model()
200
  self.num_steps_trained = 0
201
 
 
202
 
203
  self.plot_options = {
204
  "show_training_data": True,
 
272
 
273
  self.num_steps_trained = 0
274
 
275
+ # compute initial train loss
276
+ model.eval()
277
+ inputs = torch.from_numpy(self.x_train).float()
278
+ targets = torch.from_numpy(self.y_train).float().unsqueeze(1)
279
+ with torch.no_grad():
280
+ outputs = model(inputs)
281
+ train_loss = self.criterion(outputs, targets).item()
282
+
283
+ return model, optimizer, train_loss
284
 
285
  def plot(self):
286
  '''
 
329
  self.data_options["seed"] += 1
330
  self.x_train, self.y_train = self.generate_data()
331
  self.reset_model()
332
+ return self.plot(), self.num_steps_trained, self.train_loss
333
 
334
  def reset_model(self):
335
+ self.model, self.optimizer, self.train_loss = self.init_model()
336
+ return self.plot(), self.num_steps_trained, self.train_loss
337
 
338
  def update_data_options(self, **kwargs):
339
  for key, value in kwargs.items():
 
355
 
356
  if "nsample" in kwargs:
357
  slider_update = gr.update(maximum=self.x_train.shape[0], value=min(self.basic_train_hparams["batch_size"], self.x_train.shape[0]))
358
+ return self.plot(), slider_update, self.num_steps_trained, self.train_loss
359
 
360
+ return self.plot(), self.num_steps_trained, self.train_loss
361
 
362
  def update_plot_options(self, **kwargs):
363
  for key, value in kwargs.items():
 
370
  self.architecture_options["activations"] = activations
371
 
372
  # reset model
373
+ self.model, self.optimizer, self.train_loss = self.init_model()
374
 
375
+ return self.plot(), self.num_steps_trained, self.train_loss
376
 
377
  def update_basic_train_hparams(self, **kwargs):
378
  for key, value in kwargs.items():
 
380
  self.basic_train_hparams[key] = value
381
 
382
  # reset model
383
+ self.model, self.optimizer, self.train_loss = self.init_model()
384
 
385
+ return self.plot(), self.num_steps_trained, self.train_loss
386
 
387
  def update_optimizer(self, optimizer_name):
388
  self.basic_train_hparams["optimizer"] = optimizer_name
 
396
  updates.append(gr.update(visible=is_visible))
397
 
398
  # reset model
399
+ self.model, self.optimizer, self.train_loss = self.init_model()
400
 
401
+ return updates + [self.plot(), self.num_steps_trained, self.train_loss]
402
 
403
  def build_optimizer_components(self):
404
  self.param_components = {}
 
422
  self.optimizer_hparams[optimizer_name][param_name] = value
423
 
424
  # reset model and plot
425
+ self.model, self.optimizer, self.train_loss = self.init_model()
426
+ return self.plot(), self.num_steps_trained, self.train_loss
427
 
428
  def train_step(self):
429
  self.model.train()
 
437
  loss.backward()
438
  self.optimizer.step()
439
 
 
440
  self.num_steps_trained += 1
441
 
442
+ # update train loss
443
+ self.model.eval()
444
+ with torch.no_grad():
445
+ outputs = self.model(inputs)
446
+ self.train_loss = self.criterion(outputs, targets).item()
447
+
448
+ return self.plot(), self.num_steps_trained, self.train_loss
449
 
450
  def launch(self):
451
  # build the Gradio interface
 
538
  value=0,
539
  interactive=False,
540
  )
541
+ train_loss_display = gr.Number(
542
+ label="Train loss",
543
+ value=self.train_loss,
544
+ interactive=False,
545
+ )
546
 
547
  train_button = gr.Button("Train Step")
548
  reset_model_button = gr.Button("Reset Model")
 
578
  function_box.submit(
579
  fn=lambda function: self.update_data_options(function=function),
580
  inputs=function_box,
581
+ outputs=[self.canvas, train_step_counter, train_loss_display],
582
  )
583
  x_min.submit(
584
  fn=lambda xmin: self.update_data_options(x_min=xmin),
585
  inputs=x_min,
586
+ outputs=[self.canvas, train_step_counter, train_loss_display],
587
  )
588
  x_max.submit(
589
  fn=lambda xmax: self.update_data_options(x_max=xmax),
590
  inputs=x_max,
591
+ outputs=[self.canvas, train_step_counter, train_loss_display],
592
  )
593
  num_points_slider.change(
594
  fn=lambda nsample: self.update_data_options(nsample=nsample),
595
  inputs=num_points_slider,
596
+ outputs=[self.canvas, batch_size_slider, train_step_counter, train_loss_display],
597
  )
598
  noise_value.submit(
599
  fn=lambda sigma: self.update_data_options(sigma=sigma),
600
  inputs=noise_value,
601
+ outputs=[self.canvas, train_step_counter, train_loss_display],
602
  )
603
  regenerate_button.click(
604
  fn=self._update_data_seed,
605
+ outputs=[self.canvas, train_step_counter, train_loss_display],
606
  )
607
 
608
  # train options
609
  optimizer_radio.change(
610
  fn=self.update_optimizer,
611
  inputs=optimizer_radio,
612
+ outputs=[*all_param_components, self.canvas, train_step_counter, train_loss_display],
613
  )
614
  batch_size_slider.change(
615
  fn=lambda batch_size: self.update_basic_train_hparams(batch_size=batch_size),
616
  inputs=batch_size_slider,
617
+ outputs=[self.canvas, train_step_counter, train_loss_display],
618
  )
619
  train_button.click(
620
  fn=self.train_step,
621
+ outputs=[self.canvas, train_step_counter, train_loss_display],
622
  show_progress="hidden",
623
  )
624
  reset_model_button.click(
625
  fn=self.reset_model,
626
+ outputs=[self.canvas, train_step_counter, train_loss_display],
627
  )
628
  for opt_name, params in self.param_components.items():
629
  for param_name, comp in params.items():
630
  comp.submit(
631
  fn=functools.partial(self.update_hparam, optimizer_name=opt_name, param_name=param_name),
632
  inputs=[comp],
633
+ outputs=[self.canvas, train_step_counter, train_loss_display],
634
  )
635
 
636
  # plot options