Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union | |
| from captum.attr import ( | |
| Deconvolution, | |
| DeepLift, | |
| FeatureAblation, | |
| GuidedBackprop, | |
| InputXGradient, | |
| IntegratedGradients, | |
| Occlusion, | |
| Saliency, | |
| ) | |
| from captum.attr._utils.approximation_methods import SUPPORTED_METHODS | |
class NumberConfig(NamedTuple):
    """Configuration entry for a numeric attribution-method parameter.

    This is a plain record; it performs no validation of ``value`` against
    ``limit``.

    Fields:
        value: Default numeric value of the parameter (1 if not given).
        limit: ``(lower, upper)`` bounds. ``None`` on a side presumably means
            "unbounded" on that side (e.g. ``(2, None)``) — confirm with the
            code that consumes this config.
        type: Tag identifying this config kind; ``"number"`` by default.
    """

    value: int = 1
    limit: Tuple[Optional[int], Optional[int]] = (None, None)
    type: str = "number"
class StrEnumConfig(NamedTuple):
    """Configuration entry for a string parameter whose choices are listed.

    This is a plain record; ``value`` is not checked against ``limit`` here
    (presumably the consumer does that — confirm).

    Fields:
        value: Default choice (required).
        limit: The allowed string choices (required).
        type: Tag identifying this config kind; ``"enum"`` by default.
    """

    value: str
    limit: List[str]
    type: str = "enum"
class StrConfig(NamedTuple):
    """Configuration entry for an unconstrained string parameter.

    Unlike ``NumberConfig`` and ``StrEnumConfig`` there is no ``limit`` field.

    Fields:
        value: Default string value (required).
        type: Tag identifying this config kind; ``"string"`` by default.
    """

    value: str
    type: str = "string"
# Any one of the per-parameter config records; this is the value type of
# ``ConfigParameters.params``.
Config = Union[NumberConfig, StrEnumConfig, StrConfig]
# Captum attribution classes supported by this module. The order below is
# carried over as the insertion order of ATTRIBUTION_NAMES_TO_METHODS, which
# is built by iterating over this list.
SUPPORTED_ATTRIBUTION_METHODS = [
    Deconvolution, DeepLift, GuidedBackprop, InputXGradient,
    IntegratedGradients, Saliency,
    FeatureAblation, Occlusion,
]
class ConfigParameters(NamedTuple):
    """Tunable parameters for a single attribution method.

    Fields:
        params: Maps a parameter name to its config record. The names are
            presumably the keyword arguments passed to the method's
            ``attribute`` call — confirm with the caller.
        help_info: Optional help text for the method (not yet filled in).
        post_process: Optional map from parameter name to a converter (e.g.
            ``int``). Presumably applied to the raw value before it is passed
            to the method; parameters without an entry are left as-is —
            confirm with the caller.
    """

    params: Dict[str, Config]
    help_info: Optional[str] = None  # TODO fill out help for each method
    post_process: Optional[Dict[str, Callable[[Any], Any]]] = None
# Reverse lookup from the name each class reports via ``get_name()`` to the
# class itself. If two classes ever reported the same name, the one appearing
# later in SUPPORTED_ATTRIBUTION_METHODS would win.
ATTRIBUTION_NAMES_TO_METHODS = {
    # mypy bug - treating it as a type instead of a class
    attr_cls.get_name(): attr_cls  # type: ignore
    for attr_cls in SUPPORTED_ATTRIBUTION_METHODS
}
def _str_to_tuple(s):
    """Convert a user-entered sequence of integers into a tuple of ints.

    Integers may be separated by whitespace and/or commas, and the whole
    string may be wrapped in parentheses or square brackets. For example,
    ``"3 15 15"``, ``"3,15,15"`` and ``"(3, 15, 15)"`` all yield
    ``(3, 15, 15)``. An empty or blank string yields ``()``.

    A value that is already a tuple is returned unchanged, so applying this
    function twice is harmless.

    Raises:
        ValueError: if any token is not an integer literal.
    """
    if isinstance(s, tuple):
        return s
    # Strip surrounding tuple/list punctuation and treat commas as separators
    # so Python-style input parses the same as the space-separated form.
    body = s.strip().strip("()[]").replace(",", " ")
    return tuple(int(token) for token in body.split())
# Tunable parameters per attribution method, keyed by the class's
# ``get_name()``. Methods absent from this table have no configurable
# parameters here.
ATTRIBUTION_METHOD_CONFIG: Dict[str, ConfigParameters] = {
    IntegratedGradients.get_name(): ConfigParameters(
        params={
            "n_steps": NumberConfig(value=25, limit=(2, None)),
            "method": StrEnumConfig(limit=SUPPORTED_METHODS, value="gausslegendre"),
        },
        post_process={"n_steps": int},
    ),
    FeatureAblation.get_name(): ConfigParameters(
        params={"perturbations_per_eval": NumberConfig(value=1, limit=(1, 100))},
        # Cast to int, as Occlusion does below: Captum's FeatureAblation
        # requires perturbations_per_eval to be an int, so a non-int numeric
        # value (e.g. 1.0) reaching attribute() would otherwise be rejected.
        post_process={"perturbations_per_eval": int},
    ),
    Occlusion.get_name(): ConfigParameters(
        params={
            # Shapes/strides are entered as strings of integers and converted
            # to tuples by _str_to_tuple in post_process.
            # NOTE(review): the "" default converts to (), which Captum may not
            # accept for these arguments — confirm a value is always supplied.
            "sliding_window_shapes": StrConfig(value=""),
            "strides": StrConfig(value=""),
            "perturbations_per_eval": NumberConfig(value=1, limit=(1, 100)),
        },
        post_process={
            "sliding_window_shapes": _str_to_tuple,
            "strides": _str_to_tuple,
            "perturbations_per_eval": int,
        },
    ),
}