joel-woodfield commited on
Commit
ba3e474
·
1 Parent(s): 19d9a90

Distinguish loss and regularization contours and add legend

Browse files
Files changed (1) hide show
  1. regularization.py +7 -1
regularization.py CHANGED
@@ -4,6 +4,7 @@ import pickle
4
 
5
  import gradio as gr
6
  import matplotlib.pyplot as plt
 
7
  import inspect
8
  import numpy as np
9
  import pandas as pd
@@ -175,13 +176,18 @@ class Regularization:
175
  colors = [cmap(i / (N - 1)) for i in range(N)]
176
 
177
  # regularizer contours
178
- cs1 = ax.contour(W1, W2, regs, levels=reg_levels, colors=colors)
179
  ax.clabel(cs1, inline=True, fontsize=8) # show contour levels
180
 
181
  # loss contours
182
  cs2 = ax.contour(W1, W2, losses, levels=loss_levels, colors=colors[::-1])
183
  ax.clabel(cs2, inline=True, fontsize=8)
184
 
 
 
 
 
 
185
  # plot solutions
186
  #for alpha, w, norm, mse in solutions:
187
  #ax.plot(w[0], w[1], "ro")
 
4
 
5
  import gradio as gr
6
  import matplotlib.pyplot as plt
7
+ import matplotlib.lines as mlines
8
  import inspect
9
  import numpy as np
10
  import pandas as pd
 
176
  colors = [cmap(i / (N - 1)) for i in range(N)]
177
 
178
  # regularizer contours
179
+ cs1 = ax.contour(W1, W2, regs, levels=reg_levels, colors=colors, linestyles="dashed")
180
  ax.clabel(cs1, inline=True, fontsize=8) # show contour levels
181
 
182
  # loss contours
183
  cs2 = ax.contour(W1, W2, losses, levels=loss_levels, colors=colors[::-1])
184
  ax.clabel(cs2, inline=True, fontsize=8)
185
 
186
+ # custom legend
187
+ loss_line = mlines.Line2D([], [], color='black', linestyle='-', label='loss')
188
+ reg_line = mlines.Line2D([], [], color='black', linestyle='--', label='regularization')
189
+ ax.legend(handles=[loss_line, reg_line])
190
+
191
  # plot solutions
192
  #for alpha, w, norm, mse in solutions:
193
  #ax.plot(w[0], w[1], "ro")