Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| import typing | |
| from inspect import signature | |
| from typing import Any, Callable, List, Tuple, TYPE_CHECKING, Union | |
| import torch | |
| from captum._utils.common import ( | |
| _format_baseline, | |
| _format_output, | |
| _format_tensor_into_tuples, | |
| _validate_input as _validate_input_basic, | |
| ) | |
| from captum._utils.typing import ( | |
| BaselineType, | |
| Literal, | |
| TargetType, | |
| TensorOrTupleOfTensorsGeneric, | |
| ) | |
| from captum.attr._utils.approximation_methods import SUPPORTED_METHODS | |
| from torch import Tensor | |
| if TYPE_CHECKING: | |
| from captum.attr._utils.attribution import GradientAttribution | |
| def _sum_rows(input: Tensor) -> Tensor: | |
| return input.reshape(input.shape[0], -1).sum(1) | |
| def _validate_target(num_samples: int, target: TargetType) -> None: | |
| if isinstance(target, list) or ( | |
| isinstance(target, torch.Tensor) and torch.numel(target) > 1 | |
| ): | |
| assert num_samples == len(target), ( | |
| "The number of samples provied in the" | |
| "input {} does not match with the number of targets. {}".format( | |
| num_samples, len(target) | |
| ) | |
| ) | |
def _validate_input(
    inputs: Tuple[Tensor, ...],
    baselines: Tuple[Union[Tensor, int, float], ...],
    n_steps: int = 50,
    method: str = "riemann_trapezoid",
    draw_baseline_from_distrib: bool = False,
) -> None:
    """Validate inputs/baselines and the integration parameters.

    Delegates the input/baseline shape checks to the basic validator, then
    verifies that ``n_steps`` is non-negative and ``method`` is one of the
    supported approximation methods.

    Raises:
        AssertionError: if any check fails.
    """
    _validate_input_basic(inputs, baselines, draw_baseline_from_distrib)
    # The check permits n_steps == 0, so the message says "non-negative"
    # (previous message claimed "positive", contradicting the condition).
    assert (
        n_steps >= 0
    ), "The number of steps must be a non-negative integer. Given: {}".format(n_steps)
    # Fixed message typo: "one for the following" -> "one of the following".
    assert method in SUPPORTED_METHODS, (
        "Approximation method must be one of the following {}. "
        "Given {}".format(SUPPORTED_METHODS, method)
    )
| def _validate_noise_tunnel_type( | |
| nt_type: str, supported_noise_tunnel_types: List[str] | |
| ) -> None: | |
| assert nt_type in supported_noise_tunnel_types, ( | |
| "Noise types must be either `smoothgrad`, `smoothgrad_sq` or `vargrad`. " | |
| "Given {}".format(nt_type) | |
| ) | |
# The first two definitions are typing overloads; without @typing.overload
# they were dead redefinitions, silently discarded at runtime and invisible
# to type-checkers. The decorator restores the intended overloaded signature.
@typing.overload
def _format_input_baseline(
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    baselines: Union[Tensor, Tuple[Tensor, ...]],
) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
    ...


@typing.overload
def _format_input_baseline(
    inputs: Union[Tensor, Tuple[Tensor, ...]], baselines: BaselineType
) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]:
    ...


def _format_input_baseline(
    inputs: Union[Tensor, Tuple[Tensor, ...]], baselines: BaselineType
) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]:
    """Normalize `inputs` and `baselines` into matching tuples.

    A bare tensor input becomes a one-element tuple; baselines are expanded
    or defaulted against the formatted inputs by `_format_baseline`.
    """
    inputs = _format_tensor_into_tuples(inputs)
    baselines = _format_baseline(baselines, inputs)
    return inputs, baselines
| # This function can potentially be merged with the `format_baseline` function | |
| # however, since currently not all algorithms support baselines of type | |
| # callable this will be kept in a separate function. | |
# The first two definitions are typing overloads; without @typing.overload
# they were dead redefinitions, silently discarded at runtime and invisible
# to type-checkers. The decorator restores the intended overloaded signature.
@typing.overload
def _format_callable_baseline(
    baselines: Union[
        None,
        Callable[..., Union[Tensor, Tuple[Tensor, ...]]],
        Tensor,
        Tuple[Tensor, ...],
    ],
    inputs: Union[Tensor, Tuple[Tensor, ...]],
) -> Tuple[Tensor, ...]:
    ...


@typing.overload
def _format_callable_baseline(
    baselines: Union[
        None,
        Callable[..., Union[Tensor, Tuple[Tensor, ...]]],
        Tensor,
        int,
        float,
        Tuple[Union[Tensor, int, float], ...],
    ],
    inputs: Union[Tensor, Tuple[Tensor, ...]],
) -> Tuple[Union[Tensor, int, float], ...]:
    ...


def _format_callable_baseline(
    baselines: Union[
        None,
        Callable[..., Union[Tensor, Tuple[Tensor, ...]]],
        Tensor,
        int,
        float,
        Tuple[Union[Tensor, int, float], ...],
    ],
    inputs: Union[Tensor, Tuple[Tensor, ...]],
) -> Tuple[Union[Tensor, int, float], ...]:
    """Resolve a possibly-callable `baselines` argument into a baseline tuple.

    If `baselines` is callable it is invoked first — with no arguments when
    its signature takes none, otherwise with `inputs` as the single (first)
    argument — and the result is formatted with `_format_baseline`.
    """
    if callable(baselines):
        # Note: this assumes that if baselines is a function and if it takes
        # arguments, then the first argument is the `inputs`.
        # This can be expanded in the future with better type checks
        baseline_parameters = signature(baselines).parameters
        if len(baseline_parameters) == 0:
            baselines = baselines()
        else:
            baselines = baselines(inputs)
    return _format_baseline(baselines, _format_tensor_into_tuples(inputs))
| def _format_and_verify_strides( | |
| strides: Union[None, int, Tuple[int, ...], Tuple[Union[int, Tuple[int, ...]], ...]], | |
| inputs: Tuple[Tensor, ...], | |
| ) -> Tuple[Union[int, Tuple[int, ...]], ...]: | |
| # Formats strides, which are necessary for occlusion | |
| # Assumes inputs are already formatted (in tuple) | |
| if strides is None: | |
| strides = tuple(1 for input in inputs) | |
| if len(inputs) == 1 and not (isinstance(strides, tuple) and len(strides) == 1): | |
| strides = (strides,) # type: ignore | |
| assert isinstance(strides, tuple) and len(strides) == len( | |
| inputs | |
| ), "Strides must be provided for each input tensor." | |
| for i in range(len(inputs)): | |
| assert isinstance(strides[i], int) or ( | |
| isinstance(strides[i], tuple) | |
| and len(strides[i]) == len(inputs[i].shape) - 1 # type: ignore | |
| ), ( | |
| "Stride for input index {} is {}, which is invalid for input with " | |
| "shape {}. It must be either an int or a tuple with length equal to " | |
| "len(input_shape) - 1." | |
| ).format( | |
| i, strides[i], inputs[i].shape | |
| ) | |
| return strides | |
| def _format_and_verify_sliding_window_shapes( | |
| sliding_window_shapes: Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...]], | |
| inputs: Tuple[Tensor, ...], | |
| ) -> Tuple[Tuple[int, ...], ...]: | |
| # Formats shapes of sliding windows, which is necessary for occlusion | |
| # Assumes inputs is already formatted (in tuple) | |
| if isinstance(sliding_window_shapes[0], int): | |
| sliding_window_shapes = (sliding_window_shapes,) # type: ignore | |
| sliding_window_shapes: Tuple[Tuple[int, ...], ...] | |
| assert len(sliding_window_shapes) == len( | |
| inputs | |
| ), "Must provide sliding window dimensions for each input tensor." | |
| for i in range(len(inputs)): | |
| assert ( | |
| isinstance(sliding_window_shapes[i], tuple) | |
| and len(sliding_window_shapes[i]) == len(inputs[i].shape) - 1 | |
| ), ( | |
| "Occlusion shape for input index {} is {} but should be a tuple with " | |
| "{} dimensions." | |
| ).format( | |
| i, sliding_window_shapes[i], len(inputs[i].shape) - 1 | |
| ) | |
| return sliding_window_shapes | |
# The first two definitions are typing overloads; without @typing.overload
# they were dead redefinitions, silently discarded at runtime and invisible
# to type-checkers. The decorator restores the intended overloaded signature.
@typing.overload
def _compute_conv_delta_and_format_attrs(
    attr_algo: "GradientAttribution",
    return_convergence_delta: bool,
    attributions: Tuple[Tensor, ...],
    start_point: Union[int, float, Tensor, Tuple[Union[int, float, Tensor], ...]],
    end_point: Union[Tensor, Tuple[Tensor, ...]],
    additional_forward_args: Any,
    target: TargetType,
    is_inputs_tuple: Literal[False] = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    ...


@typing.overload
def _compute_conv_delta_and_format_attrs(
    attr_algo: "GradientAttribution",
    return_convergence_delta: bool,
    attributions: Tuple[Tensor, ...],
    start_point: Union[int, float, Tensor, Tuple[Union[int, float, Tensor], ...]],
    end_point: Union[Tensor, Tuple[Tensor, ...]],
    additional_forward_args: Any,
    target: TargetType,
    is_inputs_tuple: Literal[True],
) -> Union[Tuple[Tensor, ...], Tuple[Tuple[Tensor, ...], Tensor]]:
    ...


# FIXME: GradientAttribution is provided as a string due to a circular import.
# This should be fixed when common is refactored into separate files.
def _compute_conv_delta_and_format_attrs(
    attr_algo: "GradientAttribution",
    return_convergence_delta: bool,
    attributions: Tuple[Tensor, ...],
    start_point: Union[int, float, Tensor, Tuple[Union[int, float, Tensor], ...]],
    end_point: Union[Tensor, Tuple[Tensor, ...]],
    additional_forward_args: Any,
    target: TargetType,
    is_inputs_tuple: bool = False,
) -> Union[
    Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]
]:
    """Optionally compute the convergence delta and format attributions.

    When `return_convergence_delta` is True, asks `attr_algo` to compute the
    convergence delta between `start_point` and `end_point` and returns
    (formatted_attributions, delta); otherwise returns only the formatted
    attributions. `is_inputs_tuple` controls whether `_format_output` unpacks
    a single-element attribution tuple.
    """
    formatted_attrs = _format_output(is_inputs_tuple, attributions)
    if not return_convergence_delta:
        return formatted_attrs
    # Convergence error of the attributions w.r.t. the model output change.
    delta = attr_algo.compute_convergence_delta(
        attributions,
        start_point,
        end_point,
        additional_forward_args=additional_forward_args,
        target=target,
    )
    return formatted_attrs, delta
| def _tensorize_baseline( | |
| inputs: Tuple[Tensor, ...], baselines: Tuple[Union[int, float, Tensor], ...] | |
| ) -> Tuple[Tensor, ...]: | |
| def _tensorize_single_baseline(baseline, input): | |
| if isinstance(baseline, (int, float)): | |
| return torch.full_like(input, baseline) | |
| if input.shape[0] > baseline.shape[0] and baseline.shape[0] == 1: | |
| return torch.cat([baseline] * input.shape[0]) | |
| return baseline | |
| assert isinstance(inputs, tuple) and isinstance(baselines, tuple), ( | |
| "inputs and baselines must" | |
| "have tuple type but found baselines: {} and inputs: {}".format( | |
| type(baselines), type(inputs) | |
| ) | |
| ) | |
| return tuple( | |
| _tensorize_single_baseline(baseline, input) | |
| for baseline, input in zip(baselines, inputs) | |
| ) | |
| def _reshape_and_sum( | |
| tensor_input: Tensor, num_steps: int, num_examples: int, layer_size: Tuple[int, ...] | |
| ) -> Tensor: | |
| # Used for attribution methods which perform integration | |
| # Sums across integration steps by reshaping tensor to | |
| # (num_steps, num_examples, (layer_size)) and summing over | |
| # dimension 0. Returns a tensor of size (num_examples, (layer_size)) | |
| return torch.sum( | |
| tensor_input.reshape((num_steps, num_examples) + layer_size), dim=0 | |
| ) | |
| def _call_custom_attribution_func( | |
| custom_attribution_func: Callable[..., Tuple[Tensor, ...]], | |
| multipliers: Tuple[Tensor, ...], | |
| inputs: Tuple[Tensor, ...], | |
| baselines: Tuple[Tensor, ...], | |
| ) -> Tuple[Tensor, ...]: | |
| assert callable(custom_attribution_func), ( | |
| "`custom_attribution_func`" | |
| " must be a callable function but {} provided".format( | |
| type(custom_attribution_func) | |
| ) | |
| ) | |
| custom_attr_func_params = signature(custom_attribution_func).parameters | |
| if len(custom_attr_func_params) == 1: | |
| return custom_attribution_func(multipliers) | |
| elif len(custom_attr_func_params) == 2: | |
| return custom_attribution_func(multipliers, inputs) | |
| elif len(custom_attr_func_params) == 3: | |
| return custom_attribution_func(multipliers, inputs, baselines) | |
| else: | |
| raise AssertionError( | |
| "`custom_attribution_func` must take at least one and at most 3 arguments." | |
| ) | |
def _find_output_mode_and_verify(
    initial_eval: Union[int, float, Tensor],
    num_examples: int,
    perturbations_per_eval: int,
    feature_mask: Union[None, TensorOrTupleOfTensorsGeneric],
) -> bool:
    """
    Determine whether the model aggregates its output over the whole batch.

    Returns True (aggregate mode) when `initial_eval` is a python scalar or a
    tensor holding a single value despite multiple examples, otherwise False
    (per-example mode). In aggregate mode, `perturbations_per_eval` must be 1
    and every feature mask must have a batch dimension of 1; in per-example
    mode, `initial_eval` must contain exactly one value per example.
    """
    scalar_tensor = isinstance(initial_eval, torch.Tensor) and (
        len(initial_eval.shape) == 0
        or (num_examples > 1 and initial_eval.numel() == 1)
    )
    agg_output_mode = isinstance(initial_eval, (int, float)) or scalar_tensor
    if agg_output_mode:
        assert (
            perturbations_per_eval == 1
        ), "Cannot have perturbations_per_eval > 1 when function returns scalar."
        if feature_mask is not None:
            for single_mask in feature_mask:
                assert single_mask.shape[0] == 1, (
                    "Cannot provide different masks for each example when function "
                    "returns a scalar."
                )
    else:
        assert (
            isinstance(initial_eval, torch.Tensor) and initial_eval[0].numel() == 1
        ), "Target should identify a single element in the model output."
    return agg_output_mode
| def _construct_default_feature_mask( | |
| inputs: Tuple[Tensor, ...] | |
| ) -> Tuple[Tuple[Tensor, ...], int]: | |
| feature_mask = [] | |
| current_num_features = 0 | |
| for i in range(len(inputs)): | |
| num_features = torch.numel(inputs[i][0]) | |
| feature_mask.append( | |
| current_num_features | |
| + torch.reshape( | |
| torch.arange(num_features, device=inputs[i].device), | |
| inputs[i][0:1].shape, | |
| ) | |
| ) | |
| current_num_features += num_features | |
| total_features = current_num_features | |
| feature_mask = tuple(feature_mask) | |
| return feature_mask, total_features | |