Spaces:
Running
Running
Commit ·
1d54526
1
Parent(s): 2e354a0
Add option to change the embedder arguments
Browse files- decision_boundary.py +25 -5
- export_code_template.py.j2 +2 -1
decision_boundary.py
CHANGED
|
@@ -148,6 +148,7 @@ class InteractiveDecisionBoundary:
|
|
| 148 |
|
| 149 |
# data embedding and preprocessing values
|
| 150 |
self.embedder_class = CoordinateProjection
|
|
|
|
| 151 |
self.normalizer_class = None
|
| 152 |
self.jitter_std = 0
|
| 153 |
|
|
@@ -210,7 +211,7 @@ class InteractiveDecisionBoundary:
|
|
| 210 |
return features
|
| 211 |
|
| 212 |
def _embed_features(self, features, return_embedder=False):
|
| 213 |
-
embedder = self.embedder_class(n_components=2)
|
| 214 |
features = embedder.fit_transform(features)
|
| 215 |
|
| 216 |
if return_embedder:
|
|
@@ -222,6 +223,7 @@ class InteractiveDecisionBoundary:
|
|
| 222 |
self.normalizer_class = None
|
| 223 |
self.jitter_std = 0
|
| 224 |
self.embedder_class = CoordinateProjection
|
|
|
|
| 225 |
|
| 226 |
def plot(self, decision_boundary=False):
|
| 227 |
'''
|
|
@@ -475,6 +477,9 @@ class InteractiveDecisionBoundary:
|
|
| 475 |
else:
|
| 476 |
model_params_text = "".join([f"\n\t\t{k}={repr(v)}," for k, v in model_params.items()]) + "\n\t"
|
| 477 |
|
|
|
|
|
|
|
|
|
|
| 478 |
template = Template(Path(self.EXPORT_CODE_TEMPLATE).read_text())
|
| 479 |
variables = {
|
| 480 |
'model_import_statement': model_import_stmt,
|
|
@@ -482,6 +487,7 @@ class InteractiveDecisionBoundary:
|
|
| 482 |
"normalizer_import_statement": normalizer_import_stmt,
|
| 483 |
'dataset_file': self.DATASET_FILE,
|
| 484 |
'embedder_class': embedder_class,
|
|
|
|
| 485 |
'model_class': model_class,
|
| 486 |
'model_params': model_params_text,
|
| 487 |
'fig_width': self.canvas_width / 100,
|
|
@@ -512,6 +518,11 @@ class InteractiveDecisionBoundary:
|
|
| 512 |
print('updated Embedder:', self.embedder_class)
|
| 513 |
return self.plot()
|
| 514 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
def update_normalizer(self, normalizer):
|
| 516 |
self.normalizer_class = self.normalizers[normalizer]
|
| 517 |
print('updated Normalizer:', self.normalizer_class)
|
|
@@ -533,15 +544,15 @@ class InteractiveDecisionBoundary:
|
|
| 533 |
if type == 'Draw2D':
|
| 534 |
self.custom_selected = True
|
| 535 |
self.dataset_type = "Draw2D"
|
| 536 |
-
new_fields = gr.File(visible=False), gr.Dropdown(visible=False, value="None"), gr.Dropdown(visible=False, value="None"), gr.Dropdown(visible=False, value=0), gr.Dropdown(visible=False, value="CoordinateProjection"), gr.Textbox(visible=True), gr.Button(visible=True), gr.Button(visible=False)
|
| 537 |
elif type == 'Upload':
|
| 538 |
self.dataset_type = "Upload"
|
| 539 |
self.custom_selected = False
|
| 540 |
-
new_fields = gr.File(visible=True), gr.Dropdown(visible=False, value="None"), gr.Dropdown(visible=True, value="None"), gr.Dropdown(visible=True, value=0), gr.Dropdown(visible=True, value="CoordinateProjection"), gr.Textbox(visible=False), gr.Button(visible=False), gr.Button(visible=False)
|
| 541 |
elif type == 'sklearn':
|
| 542 |
self.dataset_type = "sklearn"
|
| 543 |
self.custom_selected = False
|
| 544 |
-
new_fields = gr.File(visible=False), gr.Dropdown(visible=True, value="None"), gr.Dropdown(visible=True, value="None"), gr.Dropdown(visible=True, value=0), gr.Dropdown(visible=True, value="CoordinateProjection"), gr.Textbox(visible=False), gr.Button(visible=False), gr.Button(visible=False)
|
| 545 |
else:
|
| 546 |
# TODO: better error handling
|
| 547 |
print('Error - unknown dataset type:', type)
|
|
@@ -659,6 +670,9 @@ class InteractiveDecisionBoundary:
|
|
| 659 |
)
|
| 660 |
self.embedder_selector = embedder_selector
|
| 661 |
|
|
|
|
|
|
|
|
|
|
| 662 |
# custom data
|
| 663 |
label = gr.Radio(["Red", "Green", "Blue"], value="Red", label="Choose point label", visible=True, elem_id="rowheight")
|
| 664 |
self.label = label
|
|
@@ -725,7 +739,7 @@ class InteractiveDecisionBoundary:
|
|
| 725 |
fn=self.handle_dataset_radio,
|
| 726 |
inputs=dataset_radio,
|
| 727 |
outputs=(
|
| 728 |
-
self.data_image, data_table, file_chooser, sklearn_data_selector, normalizer_selector, jittering_slider, embedder_selector, label, btn_clear, btn_save
|
| 729 |
),
|
| 730 |
)
|
| 731 |
|
|
@@ -747,6 +761,12 @@ class InteractiveDecisionBoundary:
|
|
| 747 |
inputs=embedder_selector,
|
| 748 |
outputs=self.data_image)
|
| 749 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 750 |
normalizer_selector.change(
|
| 751 |
fn=self.update_normalizer,
|
| 752 |
inputs=normalizer_selector,
|
|
|
|
| 148 |
|
| 149 |
# data embedding and preprocessing values
|
| 150 |
self.embedder_class = CoordinateProjection
|
| 151 |
+
self.embedder_args = ""
|
| 152 |
self.normalizer_class = None
|
| 153 |
self.jitter_std = 0
|
| 154 |
|
|
|
|
| 211 |
return features
|
| 212 |
|
| 213 |
def _embed_features(self, features, return_embedder=False):
|
| 214 |
+
embedder = self.embedder_class(n_components=2, **parse_param_string(self.embedder_args))
|
| 215 |
features = embedder.fit_transform(features)
|
| 216 |
|
| 217 |
if return_embedder:
|
|
|
|
| 223 |
self.normalizer_class = None
|
| 224 |
self.jitter_std = 0
|
| 225 |
self.embedder_class = CoordinateProjection
|
| 226 |
+
self.embedder_args = ""
|
| 227 |
|
| 228 |
def plot(self, decision_boundary=False):
|
| 229 |
'''
|
|
|
|
| 477 |
else:
|
| 478 |
model_params_text = "".join([f"\n\t\t{k}={repr(v)}," for k, v in model_params.items()]) + "\n\t"
|
| 479 |
|
| 480 |
+
embedder_args = {"n_components": 2, **parse_param_string(self.embedder_args)}
|
| 481 |
+
embedder_args_text = "".join([f"\n\t\t{k}={repr(v)}," for k, v in embedder_args.items()]) + "\n\t"
|
| 482 |
+
|
| 483 |
template = Template(Path(self.EXPORT_CODE_TEMPLATE).read_text())
|
| 484 |
variables = {
|
| 485 |
'model_import_statement': model_import_stmt,
|
|
|
|
| 487 |
"normalizer_import_statement": normalizer_import_stmt,
|
| 488 |
'dataset_file': self.DATASET_FILE,
|
| 489 |
'embedder_class': embedder_class,
|
| 490 |
+
'embedder_args': embedder_args_text,
|
| 491 |
'model_class': model_class,
|
| 492 |
'model_params': model_params_text,
|
| 493 |
'fig_width': self.canvas_width / 100,
|
|
|
|
| 518 |
print('updated Embedder:', self.embedder_class)
|
| 519 |
return self.plot()
|
| 520 |
|
| 521 |
+
def update_embedder_args(self, embedder_args):
|
| 522 |
+
self.embedder_args = embedder_args
|
| 523 |
+
print('updated Embedder args:', self.embedder_args)
|
| 524 |
+
return self.plot()
|
| 525 |
+
|
| 526 |
def update_normalizer(self, normalizer):
|
| 527 |
self.normalizer_class = self.normalizers[normalizer]
|
| 528 |
print('updated Normalizer:', self.normalizer_class)
|
|
|
|
| 544 |
if type == 'Draw2D':
|
| 545 |
self.custom_selected = True
|
| 546 |
self.dataset_type = "Draw2D"
|
| 547 |
+
new_fields = gr.File(visible=False), gr.Dropdown(visible=False, value="None"), gr.Dropdown(visible=False, value="None"), gr.Dropdown(visible=False, value=0), gr.Dropdown(visible=False, value="CoordinateProjection"), gr.Textbox(visible=False), gr.Textbox(visible=True), gr.Button(visible=True), gr.Button(visible=False)
|
| 548 |
elif type == 'Upload':
|
| 549 |
self.dataset_type = "Upload"
|
| 550 |
self.custom_selected = False
|
| 551 |
+
new_fields = gr.File(visible=True), gr.Dropdown(visible=False, value="None"), gr.Dropdown(visible=True, value="None"), gr.Dropdown(visible=True, value=0), gr.Dropdown(visible=True, value="CoordinateProjection"), gr.Textbox(visible=True), gr.Textbox(visible=False), gr.Button(visible=False), gr.Button(visible=False)
|
| 552 |
elif type == 'sklearn':
|
| 553 |
self.dataset_type = "sklearn"
|
| 554 |
self.custom_selected = False
|
| 555 |
+
new_fields = gr.File(visible=False), gr.Dropdown(visible=True, value="None"), gr.Dropdown(visible=True, value="None"), gr.Dropdown(visible=True, value=0), gr.Dropdown(visible=True, value="CoordinateProjection"), gr.Textbox(visible=True), gr.Textbox(visible=False), gr.Button(visible=False), gr.Button(visible=False)
|
| 556 |
else:
|
| 557 |
# TODO: better error handling
|
| 558 |
print('Error - unknown dataset type:', type)
|
|
|
|
| 670 |
)
|
| 671 |
self.embedder_selector = embedder_selector
|
| 672 |
|
| 673 |
+
embedder_args_textbox = gr.Textbox(label="Embedder arguments", visible=False)
|
| 674 |
+
self.embedder_args_textbox = embedder_args_textbox
|
| 675 |
+
|
| 676 |
# custom data
|
| 677 |
label = gr.Radio(["Red", "Green", "Blue"], value="Red", label="Choose point label", visible=True, elem_id="rowheight")
|
| 678 |
self.label = label
|
|
|
|
| 739 |
fn=self.handle_dataset_radio,
|
| 740 |
inputs=dataset_radio,
|
| 741 |
outputs=(
|
| 742 |
+
self.data_image, data_table, file_chooser, sklearn_data_selector, normalizer_selector, jittering_slider, embedder_selector, embedder_args_textbox, label, btn_clear, btn_save
|
| 743 |
),
|
| 744 |
)
|
| 745 |
|
|
|
|
| 761 |
inputs=embedder_selector,
|
| 762 |
outputs=self.data_image)
|
| 763 |
|
| 764 |
+
embedder_args_textbox.change(
|
| 765 |
+
fn=self.update_embedder_args,
|
| 766 |
+
inputs=embedder_args_textbox,
|
| 767 |
+
outputs=self.data_image,
|
| 768 |
+
)
|
| 769 |
+
|
| 770 |
normalizer_selector.change(
|
| 771 |
fn=self.update_normalizer,
|
| 772 |
inputs=normalizer_selector,
|
export_code_template.py.j2
CHANGED
|
@@ -84,7 +84,8 @@ def main():
|
|
| 84 |
# if you want to load a model
|
| 85 |
# model = load_model("model.pkl")
|
| 86 |
|
| 87 |
-
embedder = {{ embedder_class }}(
|
|
|
|
| 88 |
create_plot(X, y, model, embedder)
|
| 89 |
plt.show()
|
| 90 |
# if you want to save as image
|
|
|
|
| 84 |
# if you want to load a model
|
| 85 |
# model = load_model("model.pkl")
|
| 86 |
|
| 87 |
+
embedder = {{ embedder_class }}({{ embedder_args }})
|
| 88 |
+
embedder.fit(X)
|
| 89 |
create_plot(X, y, model, embedder)
|
| 90 |
plt.show()
|
| 91 |
# if you want to save as image
|