Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| from typing import Any, Callable, Tuple, Union | |
| import torch | |
| from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric | |
| from captum.attr._core.feature_ablation import FeatureAblation | |
| from captum.log import log_usage | |
| from torch import Tensor | |
| def _permute_feature(x: Tensor, feature_mask: Tensor) -> Tensor: | |
| n = x.size(0) | |
| assert n > 1, "cannot permute features with batch_size = 1" | |
| perm = torch.randperm(n) | |
| no_perm = torch.arange(n) | |
| while (perm == no_perm).all(): | |
| perm = torch.randperm(n) | |
| return (x[perm] * feature_mask.to(dtype=x.dtype)) + ( | |
| x * feature_mask.bitwise_not().to(dtype=x.dtype) | |
| ) | |
class FeaturePermutation(FeatureAblation):
    r"""
    A perturbation based approach to compute attribution, which
    takes each input feature, permutes the feature values within a batch,
    and computes the difference between original and shuffled outputs for
    the given batch. This difference signifies the feature importance
    for the permuted feature.

    Example pseudocode for the algorithm is as follows::

        perm_feature_importance(batch):
            importance = dict()
            baseline_error = error_metric(model(batch), batch_labels)
            for each feature:
                permute this feature across the batch
                error = error_metric(model(permuted_batch), batch_labels)
                importance[feature] = baseline_error - error
                "un-permute" the feature across the batch

            return importance

    It should be noted that the `error_metric` must be called in the
    `forward_func`. You do not need to have an error metric, e.g. you
    could simply return the logits (the model output), but this may or may
    not provide a meaningful attribution.

    This method, unlike other attribution methods, requires a batch
    of examples to compute attributions and cannot be performed on a single
    example.

    By default, each scalar value within
    each input tensor is taken as a feature and shuffled independently. Passing
    a feature mask, allows grouping features to be shuffled together.
    Each input scalar in the group will be given the same attribution value
    equal to the change in target as a result of shuffling the entire feature
    group.

    The forward function can either return a scalar per example, or a single
    scalar for the full batch. If a single scalar is returned for the batch,
    `perturbations_per_eval` must be 1, and the returned attributions will have
    first dimension 1, corresponding to feature importance across all
    examples in the batch.

    More information can be found in the permutation feature
    importance algorithm description here:
    https://christophm.github.io/interpretable-ml-book/feature-importance.html
    """

    def __init__(
        self, forward_func: Callable, perm_func: Callable = _permute_feature
    ) -> None:
        r"""
        Args:

            forward_func (callable): The forward function of the model or
                any modification of it
            perm_func (callable, optional): A function that accepts a batch of
                inputs and a feature mask, and "permutes" the feature using
                feature mask across the batch. This defaults to a function
                which applies a random permutation, this argument only needs
                to be provided if a custom permutation behavior is desired.
                Default: `_permute_feature`
        """
        FeatureAblation.__init__(self, forward_func=forward_func)
        self.perm_func = perm_func

    # suppressing error caused by the child class not having a matching
    # signature to the parent
    # FIX: `log_usage` was imported at the top of the file but the decorator
    # was missing here; `attribute` is decorated so that usage logging works
    # consistently with the parent class (whose own decorated `attribute` is
    # accessed below via `__wrapped__` to skip double-logging).
    @log_usage()
    def attribute(  # type: ignore
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        target: TargetType = None,
        additional_forward_args: Any = None,
        feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
        perturbations_per_eval: int = 1,
        show_progress: bool = False,
        **kwargs: Any,
    ) -> TensorOrTupleOfTensorsGeneric:
        r"""
        This function is almost equivalent to `FeatureAblation.attribute`. The
        main difference is the way ablated examples are generated. Specifically
        they are generated through the `perm_func`, as we set the baselines for
        `FeatureAblation.attribute` to None.

        Args:
                inputs (tensor or tuple of tensors): Input for which
                            permutation attributions are computed. If
                            forward_func takes a single tensor as input, a
                            single input tensor should be provided. If
                            forward_func takes multiple tensors as input, a
                            tuple of the input tensors should be provided. It is
                            assumed that for all given input tensors, dimension
                            0 corresponds to the number of examples (aka batch
                            size), and if multiple input tensors are provided,
                            the examples must be aligned appropriately.
                target (int, tuple, tensor or list, optional): Output indices for
                            which difference is computed (for classification cases,
                            this is usually the target class).
                            If the network returns a scalar value per example,
                            no target index is necessary.
                            For general 2D outputs, targets can be either:

                            - a single integer or a tensor containing a single
                              integer, which is applied to all input examples

                            - a list of integers or a 1D tensor, with length matching
                              the number of examples in inputs (dim 0). Each integer
                              is applied as the target for the corresponding example.

                            For outputs with > 2 dimensions, targets can be either:

                            - A single tuple, which contains #output_dims - 1
                              elements. This target index is applied to all examples.

                            - A list of tuples with length equal to the number of
                              examples in inputs (dim 0), and each tuple containing
                              #output_dims - 1 elements. Each tuple is applied as the
                              target for the corresponding example.

                            Default: None
                additional_forward_args (any, optional): If the forward function
                            requires additional arguments other than the inputs for
                            which attributions should not be computed, this argument
                            can be provided. It must be either a single additional
                            argument of a Tensor or arbitrary (non-tuple) type or a
                            tuple containing multiple additional arguments including
                            tensors or any arbitrary python types. These arguments
                            are provided to forward_func in order following the
                            arguments in inputs.
                            For a tensor, the first dimension of the tensor must
                            correspond to the number of examples. For all other types,
                            the given argument is used for all forward evaluations.
                            Note that attributions are not computed with respect
                            to these arguments.
                            Default: None
                feature_mask (tensor or tuple of tensors, optional):
                            feature_mask defines a mask for the input, grouping
                            features which should be ablated together. feature_mask
                            should contain the same number of tensors as inputs.
                            Each tensor should be the same size as the
                            corresponding input or broadcastable to match the
                            input tensor. Each tensor should contain integers in
                            the range 0 to num_features - 1, and indices
                            corresponding to the same feature should have the
                            same value. Note that features within each input
                            tensor are ablated independently (not across
                            tensors).
                            The first dimension of each mask must be 1, as we require
                            to have the same group of features for each input sample.
                            If None, then a feature mask is constructed which assigns
                            each scalar within a tensor as a separate feature, which
                            is permuted independently.
                            Default: None
                perturbations_per_eval (int, optional): Allows permutations
                            of multiple features to be processed simultaneously
                            in one call to forward_fn. Each forward pass will
                            contain a maximum of perturbations_per_eval * #examples
                            samples. For DataParallel models, each batch is
                            split among the available devices, so evaluations on
                            each available device contain at most
                            (perturbations_per_eval * #examples) / num_devices
                            samples.
                            If the forward function returns a single scalar per batch,
                            perturbations_per_eval must be set to 1.
                            Default: 1
                show_progress (bool, optional): Displays the progress of computation.
                            It will try to use tqdm if available for advanced features
                            (e.g. time estimation). Otherwise, it will fallback to
                            a simple output of progress.
                            Default: False
                **kwargs (Any, optional): Any additional arguments used by child
                            classes of FeatureAblation (such as Occlusion) to construct
                            ablations. These arguments are ignored when using
                            FeatureAblation directly.
                            Default: None

        Returns:
                *tensor* or tuple of *tensors* of **attributions**:
                - **attributions** (*tensor* or tuple of *tensors*):
                            The attributions with respect to each input feature.
                            If the forward function returns
                            a scalar value per example, attributions will be
                            the same size as the provided inputs, with each value
                            providing the attribution of the corresponding input index.
                            If the forward function returns a scalar per batch, then
                            attribution tensor(s) will have first dimension 1 and
                            the remaining dimensions will match the input.
                            If a single tensor is provided as inputs, a single tensor
                            is returned. If a tuple of tensors is provided for inputs,
                            a tuple of corresponding sized tensors is returned.

        Examples::

            >>> # SimpleClassifier takes a single input tensor of size Nx4x4,
            >>> # and returns an Nx3 tensor of class probabilities.
            >>> net = SimpleClassifier()
            >>> # Generating random input with size 10 x 4 x 4
            >>> input = torch.randn(10, 4, 4)
            >>> # Defining FeaturePermutation interpreter
            >>> feature_perm = FeaturePermutation(net)
            >>> # Computes permutation attribution, shuffling each of the 16
            >>> # scalar input independently.
            >>> attr = feature_perm.attribute(input, target=1)

            >>> # Alternatively, we may want to permute features in groups, e.g.
            >>> # grouping each 2x2 square of the inputs and shuffling them together.
            >>> # This can be done by creating a feature mask as follows, which
            >>> # defines the feature groups, e.g.:
            >>> # +---+---+---+---+
            >>> # | 0 | 0 | 1 | 1 |
            >>> # +---+---+---+---+
            >>> # | 0 | 0 | 1 | 1 |
            >>> # +---+---+---+---+
            >>> # | 2 | 2 | 3 | 3 |
            >>> # +---+---+---+---+
            >>> # | 2 | 2 | 3 | 3 |
            >>> # +---+---+---+---+
            >>> # With this mask, all inputs with the same value are shuffled
            >>> # simultaneously, and the attribution for each input in the same
            >>> # group (0, 1, 2, and 3) per example are the same.
            >>> # The attributions can be calculated as follows:
            >>> # feature mask has dimensions 1 x 4 x 4
            >>> feature_mask = torch.tensor([[[0,0,1,1],[0,0,1,1],
            >>>                             [2,2,3,3],[2,2,3,3]]])
            >>> attr = feature_perm.attribute(input, target=1,
            >>>                               feature_mask=feature_mask)
        """
        # Delegate to the parent's undecorated implementation (`__wrapped__`
        # skips the parent's own log_usage wrapper); baselines=None makes
        # the parent route perturbation through `_construct_ablated_input`
        # below, which permutes instead of ablating.
        return FeatureAblation.attribute.__wrapped__(
            self,
            inputs,
            baselines=None,
            target=target,
            additional_forward_args=additional_forward_args,
            feature_mask=feature_mask,
            perturbations_per_eval=perturbations_per_eval,
            show_progress=show_progress,
            **kwargs,
        )

    def _construct_ablated_input(
        self,
        expanded_input: Tensor,
        input_mask: Tensor,
        baseline: Union[int, float, Tensor],
        start_feature: int,
        end_feature: int,
        **kwargs: Any,
    ) -> Tuple[Tensor, Tensor]:
        r"""
        This function permutes the features of `expanded_input` with a given
        feature mask and feature range. Permutation occurs via calling
        `self.perm_func` across each batch within `expanded_input`. As with
        `FeatureAblation._construct_ablated_input`:
        - `expanded_input.shape = (num_features, num_examples, ...)`
        - `num_features = end_feature - start_feature` (i.e. start and end is a
          half-closed interval)
        - `input_mask` is a tensor of the same shape as one input, which
          describes the locations of each feature via their "index"

        Since `baselines` is set to None for `FeatureAblation.attribute`, this
        will be the zero tensor, however, it is not used.
        """
        assert input_mask.shape[0] == 1, (
            "input_mask.shape[0] != 1: pass in one mask in order to permute"
            "the same features for each input"
        )
        # One boolean mask per feature index in [start_feature, end_feature).
        current_mask = torch.stack(
            [input_mask == j for j in range(start_feature, end_feature)], dim=0
        ).bool()
        # Permute each feature's batch independently via perm_func.
        output = torch.stack(
            [
                self.perm_func(x, mask.squeeze(0))
                for x, mask in zip(expanded_input, current_mask)
            ]
        )
        return output, current_mask