Joel Woodfield commited on
Commit
5d4093b
·
1 Parent(s): 9e5bc2a

Fix bug with hiding regularization path

Browse files
Files changed (1) hide show
  1. regularization.py +18 -14
regularization.py CHANGED
@@ -133,19 +133,22 @@ class Regularization:
133
  if loss_type == "l2":
134
  path_w = l2_loss_regularization_path(y, X, regularization_type=reg_type)
135
  else:
136
- min_loss_reg = reg_values.ravel()[np.argmin(losses)]
137
- path_reg_levels = np.linspace(0, min_loss_reg, 20)
138
- path_w = []
139
- for reg_level in path_reg_levels:
140
- mask = reg_values <= reg_level
141
- if np.sum(mask) == 0:
142
- continue
143
- idx = np.argmin(losses[mask])
144
- path_w.append(
145
- np.stack((W1, W2), axis=-1)[mask][idx]
146
- )
147
-
148
- path_w = np.array(path_w)
 
 
 
149
  else:
150
  path_w = None
151
 
@@ -289,7 +292,8 @@ class Regularization:
289
  return loss_type
290
 
291
  def update_reg_path_visibility(self, loss_type: str):
292
- return gr.update(visible=(loss_type == "l2"))
 
293
 
294
  def update_regularizer(self, reg_type: str):
295
  if reg_type not in self.REGULARIZER_TYPES:
 
133
  if loss_type == "l2":
134
  path_w = l2_loss_regularization_path(y, X, regularization_type=reg_type)
135
  else:
136
+ # one possible way that works but its rough
137
+ # min_loss_reg = reg_values.ravel()[np.argmin(losses)]
138
+ # path_reg_levels = np.linspace(0, min_loss_reg, 20)
139
+ # path_w = []
140
+ # for reg_level in path_reg_levels:
141
+ # mask = reg_values <= reg_level
142
+ # if np.sum(mask) == 0:
143
+ # continue
144
+ # idx = np.argmin(losses[mask])
145
+ # path_w.append(
146
+ # np.stack((W1, W2), axis=-1)[mask][idx]
147
+ # )
148
+ #
149
+ # path_w = np.array(path_w)
150
+
151
+ path_w = None
152
  else:
153
  path_w = None
154
 
 
292
  return loss_type
293
 
294
  def update_reg_path_visibility(self, loss_type: str):
295
+ visible = loss_type == "l2"
296
+ return gr.update(visible=visible)
297
 
298
  def update_regularizer(self, reg_type: str):
299
  if reg_type not in self.REGULARIZER_TYPES: