Spaces:
Build error
Build error
reconstruc baseline, mask
Browse files
app.py
CHANGED
|
@@ -7,6 +7,8 @@ import gradio as gr
|
|
| 7 |
import spaces
|
| 8 |
from pnpxai.core.experiment.auto_explanation import AutoExplanationForImageClassification
|
| 9 |
from pnpxai.core.detector.detector import extract_graph_data, symbolic_trace
|
|
|
|
|
|
|
| 10 |
import matplotlib.pyplot as plt
|
| 11 |
import plotly.graph_objects as go
|
| 12 |
import plotly.express as px
|
|
@@ -541,10 +543,9 @@ class ExplainerCheckbox(Component):
|
|
| 541 |
break
|
| 542 |
|
| 543 |
opt_exp_id = max([x['id'] for x in checkbox_group_info]) + 1
|
| 544 |
-
# opt_output.explainer.model = self.experiment.model
|
| 545 |
-
# self.experiment.manager._explainers.append(opt_output.explainer)
|
| 546 |
-
# self.experiment.manager._explainer_ids.append(opt_exp_id)
|
| 547 |
|
|
|
|
|
|
|
| 548 |
opt_res = {
|
| 549 |
'id': opt_exp_id,
|
| 550 |
'class': opt_output.explainer.__class__,
|
|
@@ -558,15 +559,56 @@ class ExplainerCheckbox(Component):
|
|
| 558 |
return [opt_res, checkbox_group_info, checkbox, bttn]
|
| 559 |
|
| 560 |
def update_exp(exp_res):
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 570 |
self.experiment.manager._explainers.append(explainer)
|
| 571 |
self.experiment.manager._explainer_ids.append(_id)
|
| 572 |
|
|
|
|
| 7 |
import spaces
|
| 8 |
from pnpxai.core.experiment.auto_explanation import AutoExplanationForImageClassification
|
| 9 |
from pnpxai.core.detector.detector import extract_graph_data, symbolic_trace
|
| 10 |
+
from pnpxai.explainers.utils.baselines import BASELINE_FUNCTIONS_FOR_IMAGE
|
| 11 |
+
from pnpxai.explainers.utils.feature_masks import FEATURE_MASK_FUNCTIONS_FOR_IMAGE
|
| 12 |
import matplotlib.pyplot as plt
|
| 13 |
import plotly.graph_objects as go
|
| 14 |
import plotly.express as px
|
|
|
|
| 543 |
break
|
| 544 |
|
| 545 |
opt_exp_id = max([x['id'] for x in checkbox_group_info]) + 1
|
|
|
|
|
|
|
|
|
|
| 546 |
|
| 547 |
+
# Deliver the parameter and class and reconstruct
|
| 548 |
+
# It should be done because spaces.GPU cannot pickle the class object
|
| 549 |
opt_res = {
|
| 550 |
'id': opt_exp_id,
|
| 551 |
'class': opt_output.explainer.__class__,
|
|
|
|
| 559 |
return [opt_res, checkbox_group_info, checkbox, bttn]
|
| 560 |
|
| 561 |
def update_exp(exp_res):
|
| 562 |
+
try:
|
| 563 |
+
kwargs = {}
|
| 564 |
+
has_baseline = False
|
| 565 |
+
has_feature_mask = False
|
| 566 |
+
for k,v in exp_res['params'].items():
|
| 567 |
+
if "explainer" in k:
|
| 568 |
+
_key = k.split("explainer.")[1]
|
| 569 |
+
kwargs[_key] = v
|
| 570 |
+
if "baseline_fn" in _key:
|
| 571 |
+
has_baseline = True
|
| 572 |
+
if "feature_mask_fn" in _key:
|
| 573 |
+
has_feature_mask = True
|
| 574 |
+
|
| 575 |
+
# Reconstruct baseline object
|
| 576 |
+
if has_baseline:
|
| 577 |
+
method = kwargs['baseline_fn.method']
|
| 578 |
+
del kwargs['baseline_fn.method']
|
| 579 |
+
baseline_kwargs = {}
|
| 580 |
+
keys = list(kwargs.keys())
|
| 581 |
+
for k in keys:
|
| 582 |
+
v = kwargs[k]
|
| 583 |
+
if "baseline_fn" in k:
|
| 584 |
+
baseline_kwargs[k.split("baseline_fn.")[1]] = v
|
| 585 |
+
del kwargs[k]
|
| 586 |
+
if method == "mean":
|
| 587 |
+
baseline_kwargs['dim'] = 1 # Set arbitrary value
|
| 588 |
+
baseline_fn = BASELINE_FUNCTIONS_FOR_IMAGE[method](**baseline_kwargs)
|
| 589 |
+
kwargs['baseline_fn'] = baseline_fn
|
| 590 |
+
|
| 591 |
+
# Reconstruct feature_mask object
|
| 592 |
+
if has_feature_mask:
|
| 593 |
+
method = kwargs['feature_mask_fn.method']
|
| 594 |
+
del kwargs['feature_mask_fn.method']
|
| 595 |
+
mask_kwargs = {}
|
| 596 |
+
keys = list(kwargs.keys())
|
| 597 |
+
for k in keys:
|
| 598 |
+
v = kwargs[k]
|
| 599 |
+
if "feature_mask_fn" in k:
|
| 600 |
+
mask_kwargs[k.split("feature_mask_fn.")[1]] = v
|
| 601 |
+
del kwargs[k]
|
| 602 |
+
mask_fn = FEATURE_MASK_FUNCTIONS_FOR_IMAGE[method](**mask_kwargs)
|
| 603 |
+
kwargs['feature_mask_fn'] = mask_fn
|
| 604 |
+
|
| 605 |
+
kwargs['model'] = self.experiment.model
|
| 606 |
+
explainer = exp_res['class'](**kwargs)
|
| 607 |
+
_id = exp_res['id']
|
| 608 |
+
except Exception as e:
|
| 609 |
+
# If the optimization is failed, use the default parameter explainer as optimal
|
| 610 |
+
explainer = self.experiment.manager._explainers[self.default_exp_id]
|
| 611 |
+
|
| 612 |
self.experiment.manager._explainers.append(explainer)
|
| 613 |
self.experiment.manager._explainer_ids.append(_id)
|
| 614 |
|