Spaces:
Build error
Build error
bug fix
Browse files
app.py
CHANGED
|
@@ -483,18 +483,29 @@ class ExplainerCheckbox(Component):
|
|
| 483 |
|
| 484 |
data_id = self.gallery.selected_index
|
| 485 |
|
| 486 |
-
|
| 487 |
-
|
| 488 |
explainer_id=self.default_exp_id,
|
| 489 |
metric_id=self.obj_metric,
|
| 490 |
direction='maximize',
|
| 491 |
sampler=SAMPLE_METHOD,
|
| 492 |
n_trials=OPT_N_TRIALS,
|
| 493 |
)
|
|
|
|
| 494 |
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 498 |
self.groups.insert_check(self.explainer_name, opt_explainer_id, opt_postprocessor_id)
|
| 499 |
self.optimal_exp_id = opt_explainer_id
|
| 500 |
checkbox = gr.update(label="Optimized Parameter (Optimal)", interactive=True)
|
|
@@ -628,6 +639,8 @@ from torch.utils.data import DataLoader
|
|
| 628 |
from helpers import get_imagenet_dataset, get_torchvision_model, denormalize_image
|
| 629 |
|
| 630 |
os.environ['GRADIO_TEMP_DIR'] = '.tmp'
|
|
|
|
|
|
|
| 631 |
|
| 632 |
def target_visualizer(x): return dataset.dataset.idx_to_label(x.item())
|
| 633 |
|
|
@@ -637,7 +650,7 @@ model, transform = get_torchvision_model('resnet18')
|
|
| 637 |
dataset = get_imagenet_dataset(transform)
|
| 638 |
loader = DataLoader(dataset, batch_size=4, shuffle=False)
|
| 639 |
experiment1 = AutoExplanationForImageClassification(
|
| 640 |
-
model=model,
|
| 641 |
data=loader,
|
| 642 |
input_extractor=lambda batch: batch[0],
|
| 643 |
label_extractor=lambda batch: batch[-1],
|
|
@@ -657,7 +670,7 @@ model, transform = get_torchvision_model('vit_b_16')
|
|
| 657 |
dataset = get_imagenet_dataset(transform)
|
| 658 |
loader = DataLoader(dataset, batch_size=4, shuffle=False)
|
| 659 |
experiment2 = AutoExplanationForImageClassification(
|
| 660 |
-
model=model,
|
| 661 |
data=loader,
|
| 662 |
input_extractor=lambda batch: batch[0],
|
| 663 |
label_extractor=lambda batch: batch[-1],
|
|
|
|
| 483 |
|
| 484 |
data_id = self.gallery.selected_index
|
| 485 |
|
| 486 |
+
opt_output = self.experiment.optimize(
|
| 487 |
+
data_ids=data_id.value,
|
| 488 |
explainer_id=self.default_exp_id,
|
| 489 |
metric_id=self.obj_metric,
|
| 490 |
direction='maximize',
|
| 491 |
sampler=SAMPLE_METHOD,
|
| 492 |
n_trials=OPT_N_TRIALS,
|
| 493 |
)
|
| 494 |
+
|
| 495 |
|
| 496 |
+
def get_str_ppid(pp_obj):
|
| 497 |
+
return pp_obj.pooling_fn.__class__.__name__ + pp_obj.normalization_fn.__class__.__name__
|
| 498 |
+
|
| 499 |
+
str_id = get_str_ppid(opt_output.postprocessor)
|
| 500 |
+
for pp_obj, pp_id in zip(*self.experiment.manager.get_postprocessors()):
|
| 501 |
+
if get_str_ppid(pp_obj) == str_id:
|
| 502 |
+
opt_postprocessor_id = pp_id
|
| 503 |
+
break
|
| 504 |
+
|
| 505 |
+
opt_explainer_id = max([x['id'] for x in self.groups.info]) + 1
|
| 506 |
+
opt_output.explainer.model = self.experiment.model
|
| 507 |
+
self.experiment.manager._explainers.append(opt_output.explainer)
|
| 508 |
+
self.experiment.manager._explainer_ids.append(opt_explainer_id)
|
| 509 |
self.groups.insert_check(self.explainer_name, opt_explainer_id, opt_postprocessor_id)
|
| 510 |
self.optimal_exp_id = opt_explainer_id
|
| 511 |
checkbox = gr.update(label="Optimized Parameter (Optimal)", interactive=True)
|
|
|
|
| 639 |
from helpers import get_imagenet_dataset, get_torchvision_model, denormalize_image
|
| 640 |
|
| 641 |
os.environ['GRADIO_TEMP_DIR'] = '.tmp'
|
| 642 |
+
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 643 |
+
device = torch.device("cpu")
|
| 644 |
|
| 645 |
def target_visualizer(x): return dataset.dataset.idx_to_label(x.item())
|
| 646 |
|
|
|
|
| 650 |
dataset = get_imagenet_dataset(transform)
|
| 651 |
loader = DataLoader(dataset, batch_size=4, shuffle=False)
|
| 652 |
experiment1 = AutoExplanationForImageClassification(
|
| 653 |
+
model=model.to(device),
|
| 654 |
data=loader,
|
| 655 |
input_extractor=lambda batch: batch[0],
|
| 656 |
label_extractor=lambda batch: batch[-1],
|
|
|
|
| 670 |
dataset = get_imagenet_dataset(transform)
|
| 671 |
loader = DataLoader(dataset, batch_size=4, shuffle=False)
|
| 672 |
experiment2 = AutoExplanationForImageClassification(
|
| 673 |
+
model=model.to(device),
|
| 674 |
data=loader,
|
| 675 |
input_extractor=lambda batch: batch[0],
|
| 676 |
label_extractor=lambda batch: batch[-1],
|