joel-woodfield commited on
Commit
1d54526
·
1 Parent(s): 2e354a0

Add option to change the embedder arguments

Browse files
Files changed (2) hide show
  1. decision_boundary.py +25 -5
  2. export_code_template.py.j2 +2 -1
decision_boundary.py CHANGED
@@ -148,6 +148,7 @@ class InteractiveDecisionBoundary:
148
 
149
  # data embedding and preprocessing values
150
  self.embedder_class = CoordinateProjection
 
151
  self.normalizer_class = None
152
  self.jitter_std = 0
153
 
@@ -210,7 +211,7 @@ class InteractiveDecisionBoundary:
210
  return features
211
 
212
  def _embed_features(self, features, return_embedder=False):
213
- embedder = self.embedder_class(n_components=2)
214
  features = embedder.fit_transform(features)
215
 
216
  if return_embedder:
@@ -222,6 +223,7 @@ class InteractiveDecisionBoundary:
222
  self.normalizer_class = None
223
  self.jitter_std = 0
224
  self.embedder_class = CoordinateProjection
 
225
 
226
  def plot(self, decision_boundary=False):
227
  '''
@@ -475,6 +477,9 @@ class InteractiveDecisionBoundary:
475
  else:
476
  model_params_text = "".join([f"\n\t\t{k}={repr(v)}," for k, v in model_params.items()]) + "\n\t"
477
 
 
 
 
478
  template = Template(Path(self.EXPORT_CODE_TEMPLATE).read_text())
479
  variables = {
480
  'model_import_statement': model_import_stmt,
@@ -482,6 +487,7 @@ class InteractiveDecisionBoundary:
482
  "normalizer_import_statement": normalizer_import_stmt,
483
  'dataset_file': self.DATASET_FILE,
484
  'embedder_class': embedder_class,
 
485
  'model_class': model_class,
486
  'model_params': model_params_text,
487
  'fig_width': self.canvas_width / 100,
@@ -512,6 +518,11 @@ class InteractiveDecisionBoundary:
512
  print('updated Embedder:', self.embedder_class)
513
  return self.plot()
514
 
 
 
 
 
 
515
  def update_normalizer(self, normalizer):
516
  self.normalizer_class = self.normalizers[normalizer]
517
  print('updated Normalizer:', self.normalizer_class)
@@ -533,15 +544,15 @@ class InteractiveDecisionBoundary:
533
  if type == 'Draw2D':
534
  self.custom_selected = True
535
  self.dataset_type = "Draw2D"
536
- new_fields = gr.File(visible=False), gr.Dropdown(visible=False, value="None"), gr.Dropdown(visible=False, value="None"), gr.Dropdown(visible=False, value=0), gr.Dropdown(visible=False, value="CoordinateProjection"), gr.Textbox(visible=True), gr.Button(visible=True), gr.Button(visible=False)
537
  elif type == 'Upload':
538
  self.dataset_type = "Upload"
539
  self.custom_selected = False
540
- new_fields = gr.File(visible=True), gr.Dropdown(visible=False, value="None"), gr.Dropdown(visible=True, value="None"), gr.Dropdown(visible=True, value=0), gr.Dropdown(visible=True, value="CoordinateProjection"), gr.Textbox(visible=False), gr.Button(visible=False), gr.Button(visible=False)
541
  elif type == 'sklearn':
542
  self.dataset_type = "sklearn"
543
  self.custom_selected = False
544
- new_fields = gr.File(visible=False), gr.Dropdown(visible=True, value="None"), gr.Dropdown(visible=True, value="None"), gr.Dropdown(visible=True, value=0), gr.Dropdown(visible=True, value="CoordinateProjection"), gr.Textbox(visible=False), gr.Button(visible=False), gr.Button(visible=False)
545
  else:
546
  # TODO: better error handling
547
  print('Error - unknown dataset type:', type)
@@ -659,6 +670,9 @@ class InteractiveDecisionBoundary:
659
  )
660
  self.embedder_selector = embedder_selector
661
 
 
 
 
662
  # custom data
663
  label = gr.Radio(["Red", "Green", "Blue"], value="Red", label="Choose point label", visible=True, elem_id="rowheight")
664
  self.label = label
@@ -725,7 +739,7 @@ class InteractiveDecisionBoundary:
725
  fn=self.handle_dataset_radio,
726
  inputs=dataset_radio,
727
  outputs=(
728
- self.data_image, data_table, file_chooser, sklearn_data_selector, normalizer_selector, jittering_slider, embedder_selector, label, btn_clear, btn_save
729
  ),
730
  )
731
 
@@ -747,6 +761,12 @@ class InteractiveDecisionBoundary:
747
  inputs=embedder_selector,
748
  outputs=self.data_image)
749
 
 
 
 
 
 
 
750
  normalizer_selector.change(
751
  fn=self.update_normalizer,
752
  inputs=normalizer_selector,
 
148
 
149
  # data embedding and preprocessing values
150
  self.embedder_class = CoordinateProjection
151
+ self.embedder_args = ""
152
  self.normalizer_class = None
153
  self.jitter_std = 0
154
 
 
211
  return features
212
 
213
  def _embed_features(self, features, return_embedder=False):
214
+ embedder = self.embedder_class(n_components=2, **parse_param_string(self.embedder_args))
215
  features = embedder.fit_transform(features)
216
 
217
  if return_embedder:
 
223
  self.normalizer_class = None
224
  self.jitter_std = 0
225
  self.embedder_class = CoordinateProjection
226
+ self.embedder_args = ""
227
 
228
  def plot(self, decision_boundary=False):
229
  '''
 
477
  else:
478
  model_params_text = "".join([f"\n\t\t{k}={repr(v)}," for k, v in model_params.items()]) + "\n\t"
479
 
480
+ embedder_args = {"n_components": 2, **parse_param_string(self.embedder_args)}
481
+ embedder_args_text = "".join([f"\n\t\t{k}={repr(v)}," for k, v in embedder_args.items()]) + "\n\t"
482
+
483
  template = Template(Path(self.EXPORT_CODE_TEMPLATE).read_text())
484
  variables = {
485
  'model_import_statement': model_import_stmt,
 
487
  "normalizer_import_statement": normalizer_import_stmt,
488
  'dataset_file': self.DATASET_FILE,
489
  'embedder_class': embedder_class,
490
+ 'embedder_args': embedder_args_text,
491
  'model_class': model_class,
492
  'model_params': model_params_text,
493
  'fig_width': self.canvas_width / 100,
 
518
  print('updated Embedder:', self.embedder_class)
519
  return self.plot()
520
 
521
+ def update_embedder_args(self, embedder_args):
522
+ self.embedder_args = embedder_args
523
+ print('updated Embedder args:', self.embedder_args)
524
+ return self.plot()
525
+
526
  def update_normalizer(self, normalizer):
527
  self.normalizer_class = self.normalizers[normalizer]
528
  print('updated Normalizer:', self.normalizer_class)
 
544
  if type == 'Draw2D':
545
  self.custom_selected = True
546
  self.dataset_type = "Draw2D"
547
+ new_fields = gr.File(visible=False), gr.Dropdown(visible=False, value="None"), gr.Dropdown(visible=False, value="None"), gr.Dropdown(visible=False, value=0), gr.Dropdown(visible=False, value="CoordinateProjection"), gr.Textbox(visible=False), gr.Textbox(visible=True), gr.Button(visible=True), gr.Button(visible=False)
548
  elif type == 'Upload':
549
  self.dataset_type = "Upload"
550
  self.custom_selected = False
551
+ new_fields = gr.File(visible=True), gr.Dropdown(visible=False, value="None"), gr.Dropdown(visible=True, value="None"), gr.Dropdown(visible=True, value=0), gr.Dropdown(visible=True, value="CoordinateProjection"), gr.Textbox(visible=True), gr.Textbox(visible=False), gr.Button(visible=False), gr.Button(visible=False)
552
  elif type == 'sklearn':
553
  self.dataset_type = "sklearn"
554
  self.custom_selected = False
555
+ new_fields = gr.File(visible=False), gr.Dropdown(visible=True, value="None"), gr.Dropdown(visible=True, value="None"), gr.Dropdown(visible=True, value=0), gr.Dropdown(visible=True, value="CoordinateProjection"), gr.Textbox(visible=True), gr.Textbox(visible=False), gr.Button(visible=False), gr.Button(visible=False)
556
  else:
557
  # TODO: better error handling
558
  print('Error - unknown dataset type:', type)
 
670
  )
671
  self.embedder_selector = embedder_selector
672
 
673
+ embedder_args_textbox = gr.Textbox(label="Embedder arguments", visible=False)
674
+ self.embedder_args_textbox = embedder_args_textbox
675
+
676
  # custom data
677
  label = gr.Radio(["Red", "Green", "Blue"], value="Red", label="Choose point label", visible=True, elem_id="rowheight")
678
  self.label = label
 
739
  fn=self.handle_dataset_radio,
740
  inputs=dataset_radio,
741
  outputs=(
742
+ self.data_image, data_table, file_chooser, sklearn_data_selector, normalizer_selector, jittering_slider, embedder_selector, embedder_args_textbox, label, btn_clear, btn_save
743
  ),
744
  )
745
 
 
761
  inputs=embedder_selector,
762
  outputs=self.data_image)
763
 
764
+ embedder_args_textbox.change(
765
+ fn=self.update_embedder_args,
766
+ inputs=embedder_args_textbox,
767
+ outputs=self.data_image,
768
+ )
769
+
770
  normalizer_selector.change(
771
  fn=self.update_normalizer,
772
  inputs=normalizer_selector,
export_code_template.py.j2 CHANGED
@@ -84,7 +84,8 @@ def main():
84
  # if you want to load a model
85
  # model = load_model("model.pkl")
86
 
87
- embedder = {{ embedder_class }}(n_components=2).fit(X)
 
88
  create_plot(X, y, model, embedder)
89
  plt.show()
90
  # if you want to save as image
 
84
  # if you want to load a model
85
  # model = load_model("model.pkl")
86
 
87
+ embedder = {{ embedder_class }}({{ embedder_args }})
88
+ embedder.fit(X)
89
  create_plot(X, y, model, embedder)
90
  plt.show()
91
  # if you want to save as image