Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| import inspect | |
| import math | |
| import typing | |
| import warnings | |
| from typing import Any, Callable, cast, List, Optional, Tuple, Union | |
| import torch | |
| from captum._utils.common import ( | |
| _expand_additional_forward_args, | |
| _expand_target, | |
| _flatten_tensor_or_tuple, | |
| _format_output, | |
| _format_tensor_into_tuples, | |
| _is_tuple, | |
| _reduce_list, | |
| _run_forward, | |
| ) | |
| from captum._utils.models.linear_model import SkLearnLasso | |
| from captum._utils.models.model import Model | |
| from captum._utils.progress import progress | |
| from captum._utils.typing import ( | |
| BaselineType, | |
| Literal, | |
| TargetType, | |
| TensorOrTupleOfTensorsGeneric, | |
| ) | |
| from captum.attr._utils.attribution import PerturbationAttribution | |
| from captum.attr._utils.batching import _batch_example_iterator | |
| from captum.attr._utils.common import ( | |
| _construct_default_feature_mask, | |
| _format_input_baseline, | |
| ) | |
| from captum.log import log_usage | |
| from torch import Tensor | |
| from torch.nn import CosineSimilarity | |
| from torch.utils.data import DataLoader, TensorDataset | |
class LimeBase(PerturbationAttribution):
    r"""
    Lime is an interpretability method that trains an interpretable surrogate model
    by sampling points around a specified input example and using model evaluations
    at these points to train a simpler interpretable 'surrogate' model, such as a
    linear model.

    LimeBase provides a generic framework to train a surrogate interpretable model.
    This differs from most other attribution methods, since the method returns a
    representation of the interpretable model (e.g. coefficients of the linear model).
    For a similar interface to other perturbation-based attribution methods, please use
    the Lime child class, which defines specific transformations for the interpretable
    model.

    LimeBase allows sampling points in either the interpretable space or the original
    input space to train the surrogate model. The interpretable space is a feature
    vector used to train the surrogate interpretable model; this feature space is often
    of smaller dimensionality than the original feature space in order for the surrogate
    model to be more interpretable.

    If sampling in the interpretable space, a transformation function must be provided
    to define how a vector sampled in the interpretable space can be transformed into
    an example in the original input space. If sampling in the original input space, a
    transformation function must be provided to define how the input can be transformed
    into its interpretable vector representation.

    More details regarding LIME can be found in the original paper:
    https://arxiv.org/abs/1602.04938
    """

    def __init__(
        self,
        forward_func: Callable,
        interpretable_model: Model,
        similarity_func: Callable,
        perturb_func: Callable,
        perturb_interpretable_space: bool,
        from_interp_rep_transform: Optional[Callable],
        to_interp_rep_transform: Optional[Callable],
    ) -> None:
        r"""
        Args:

            forward_func (callable): The forward function of the model or any
                        modification of it. If a batch is provided as input for
                        attribution, it is expected that forward_func returns a scalar
                        representing the entire batch.
            interpretable_model (Model): Model object to train interpretable model.
                        A Model object provides a `fit` method to train the model,
                        given a dataloader, with batches containing three tensors:

                        - interpretable_inputs: Tensor
                          [2D num_samples x num_interp_features],
                        - expected_outputs: Tensor [1D num_samples],
                        - weights: Tensor [1D num_samples]

                        The model object must also provide a `representation` method to
                        access the appropriate coefficients or representation of the
                        interpretable model after fitting.
                        Some predefined interpretable linear models are provided in
                        captum._utils.models.linear_model including wrappers around
                        SkLearn linear models as well as SGD-based PyTorch linear
                        models.

                        Note that calling fit multiple times should retrain the
                        interpretable model, since each attribution call reuses
                        the same given interpretable model object.
            similarity_func (callable): Function which takes a single sample
                        along with its corresponding interpretable representation
                        and returns the weight of the interpretable sample for
                        training the interpretable model. The weight is generally
                        determined based on similarity to the original input.
                        The original paper refers to this as a similarity kernel.

                        The expected signature of this callable is:

                        >>> similarity_func(
                        >>>    original_input: Tensor or tuple of Tensors,
                        >>>    perturbed_input: Tensor or tuple of Tensors,
                        >>>    perturbed_interpretable_input:
                        >>>        Tensor [2D 1 x num_interp_features],
                        >>>    **kwargs: Any
                        >>> ) -> float or Tensor containing float scalar

                        perturbed_input and original_input will be the same type and
                        contain tensors of the same shape (regardless of whether or not
                        the sampling function returns inputs in the interpretable
                        space). original_input is the same as the input provided
                        when calling attribute.

                        All kwargs passed to the attribute method are
                        provided as keyword arguments (kwargs) to this callable.
            perturb_func (callable): Function which returns a single
                        sampled input, generally a perturbation of the original
                        input, which is used to train the interpretable surrogate
                        model. The function can return samples in either
                        the original input space (matching type and tensor shapes
                        of original input) or in the interpretable input space,
                        which is a vector containing the interpretable features.
                        Alternatively, this function can be a generator function
                        yielding samples to train the interpretable surrogate
                        model; at most n_samples perturbations will be drawn
                        from the generator.

                        The expected signature of this callable is:

                        >>> perturb_func(
                        >>>    original_input: Tensor or tuple of Tensors,
                        >>>    **kwargs: Any
                        >>> ) -> Tensor or tuple of Tensors or
                        >>>    generator yielding tensor or tuple of Tensors

                        Returned sampled input should match the input type (Tensor
                        or Tuple of Tensor and corresponding shapes) if
                        perturb_interpretable_space = False. If
                        perturb_interpretable_space = True, the return type should
                        be a single tensor of shape 1 x num_interp_features,
                        corresponding to the representation of the
                        sample to train the interpretable model.

                        All kwargs passed to the attribute method are
                        provided as keyword arguments (kwargs) to this callable.
            perturb_interpretable_space (bool): Indicates whether
                        perturb_func returns a sample in the interpretable space
                        (tensor of shape 1 x num_interp_features) or a sample
                        in the original space, matching the format of the original
                        input. Once sampled, inputs can be converted to / from
                        the interpretable representation with either
                        to_interp_rep_transform or from_interp_rep_transform.
            from_interp_rep_transform (callable): Function which takes a
                        single sampled interpretable representation (tensor
                        of shape 1 x num_interp_features) and returns
                        the corresponding representation in the input space
                        (matching shapes of original input to attribute).

                        This argument is necessary if perturb_interpretable_space
                        is True, otherwise None can be provided for this argument.

                        The expected signature of this callable is:

                        >>> from_interp_rep_transform(
                        >>>    curr_sample: Tensor [2D 1 x num_interp_features]
                        >>>    original_input: Tensor or Tuple of Tensors,
                        >>>    **kwargs: Any
                        >>> ) -> Tensor or tuple of Tensors

                        Returned sampled input should match the type of original_input
                        and corresponding tensor shapes.

                        All kwargs passed to the attribute method are
                        provided as keyword arguments (kwargs) to this callable.
            to_interp_rep_transform (callable): Function which takes a
                        sample in the original input space and converts it to
                        its interpretable representation (tensor
                        of shape 1 x num_interp_features).

                        This argument is necessary if perturb_interpretable_space
                        is False, otherwise None can be provided for this argument.

                        The expected signature of this callable is:

                        >>> to_interp_rep_transform(
                        >>>    curr_sample: Tensor or Tuple of Tensors,
                        >>>    original_input: Tensor or Tuple of Tensors,
                        >>>    **kwargs: Any
                        >>> ) -> Tensor [2D 1 x num_interp_features]

                        curr_sample will match the type of original_input
                        and corresponding tensor shapes.

                        All kwargs passed to the attribute method are
                        provided as keyword arguments (kwargs) to this callable.

        Raises:
            AssertionError: if the transform required by the chosen sampling
                        space (from_interp_rep_transform when
                        perturb_interpretable_space is True, otherwise
                        to_interp_rep_transform) is None.
        """
        PerturbationAttribution.__init__(self, forward_func)
        self.interpretable_model = interpretable_model
        self.similarity_func = similarity_func
        self.perturb_func = perturb_func
        self.perturb_interpretable_space = perturb_interpretable_space
        self.from_interp_rep_transform = from_interp_rep_transform
        self.to_interp_rep_transform = to_interp_rep_transform

        if self.perturb_interpretable_space:
            # The message is parenthesized so both string pieces are concatenated
            # into one message; previously the second line was a separate no-op
            # statement and was silently dropped from the assertion message.
            assert self.from_interp_rep_transform is not None, (
                "Must provide transform from interpretable space to original input"
                " space when sampling from interpretable space."
            )
        else:
            assert (
                self.to_interp_rep_transform is not None
            ), "Must provide transform from original input space to interpretable space"

    def attribute(
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        target: TargetType = None,
        additional_forward_args: Any = None,
        n_samples: int = 50,
        perturbations_per_eval: int = 1,
        show_progress: bool = False,
        **kwargs,
    ) -> Tensor:
        r"""
        This method attributes the output of the model with given target index
        (in case it is provided, otherwise it assumes that output is a
        scalar) to the inputs of the model using the approach described above.
        It trains an interpretable model and returns a representation of the
        interpretable model.

        It is recommended to only provide a single example as input (tensors
        with first dimension or batch size = 1). This is because LIME is generally
        used for sample-based interpretability, training a separate interpretable
        model to explain a model's prediction on each individual example.

        A batch of inputs can be provided as inputs only if forward_func
        returns a single value per batch (e.g. loss).
        The interpretable feature representation should still have shape
        1 x num_interp_features, corresponding to the interpretable
        representation for the full batch, and perturbations_per_eval
        must be set to 1.

        Args:

            inputs (tensor or tuple of tensors): Input for which LIME
                        is computed. If forward_func takes a single
                        tensor as input, a single input tensor should be provided.
                        If forward_func takes multiple tensors as input, a tuple
                        of the input tensors should be provided. It is assumed
                        that for all given input tensors, dimension 0 corresponds
                        to the number of examples, and if multiple input tensors
                        are provided, the examples must be aligned appropriately.
            target (int, tuple, tensor or list, optional): Output indices for
                        which the surrogate model is trained
                        (for classification cases,
                        this is usually the target class).
                        If the network returns a scalar value per example,
                        no target index is necessary.
                        For general 2D outputs, targets can be either:

                        - a single integer or a tensor containing a single
                          integer, which is applied to all input examples

                        - a list of integers or a 1D tensor, with length matching
                          the number of examples in inputs (dim 0). Each integer
                          is applied as the target for the corresponding example.

                        For outputs with > 2 dimensions, targets can be either:

                        - A single tuple, which contains #output_dims - 1
                          elements. This target index is applied to all examples.

                        - A list of tuples with length equal to the number of
                          examples in inputs (dim 0), and each tuple containing
                          #output_dims - 1 elements. Each tuple is applied as the
                          target for the corresponding example.

                        Default: None
            additional_forward_args (any, optional): If the forward function
                        requires additional arguments other than the inputs for
                        which attributions should not be computed, this argument
                        can be provided. It must be either a single additional
                        argument of a Tensor or arbitrary (non-tuple) type or a
                        tuple containing multiple additional arguments including
                        tensors or any arbitrary python types. These arguments
                        are provided to forward_func in order following the
                        arguments in inputs.
                        For a tensor, the first dimension of the tensor must
                        correspond to the number of examples. For all other types,
                        the given argument is used for all forward evaluations.
                        Note that attributions are not computed with respect
                        to these arguments.
                        Default: None
            n_samples (int, optional): The number of samples of the original
                        model used to train the surrogate interpretable model.
                        Default: `50` if `n_samples` is not provided.
            perturbations_per_eval (int, optional): Allows multiple samples
                        to be processed simultaneously in one call to forward_fn.
                        Each forward pass will contain a maximum of
                        perturbations_per_eval * #examples samples.
                        For DataParallel models, each batch is split among the
                        available devices, so evaluations on each available
                        device contain at most
                        (perturbations_per_eval * #examples) / num_devices
                        samples.
                        If the forward function returns a single scalar per batch,
                        perturbations_per_eval must be set to 1.
                        Default: 1
            show_progress (bool, optional): Displays the progress of computation.
                        It will try to use tqdm if available for advanced features
                        (e.g. time estimation). Otherwise, it will fallback to
                        a simple output of progress.
                        Default: False
            **kwargs (Any, optional): Any additional arguments necessary for
                        sampling and transformation functions (provided to
                        constructor).
                        Default: None

        Returns:
            **interpretable model representation**:
            - **interpretable model representation** (*Any*):
                    A representation of the trained interpretable model, as
                    returned by ``self.interpretable_model.representation()``.
                    For example, this could contain coefficients of a
                    linear surrogate model.

        Raises:
            ValueError: if no perturbed samples were collected, i.e. n_samples
                    is not positive or a generator perturb_func yields nothing.

        Examples::

            >>> # SimpleClassifier takes a single input tensor of
            >>> # float features with size N x 5,
            >>> # and returns an Nx3 tensor of class probabilities.
            >>> net = SimpleClassifier()
            >>>
            >>> # We will train an interpretable model with the same
            >>> # features by simply sampling with added Gaussian noise
            >>> # to the inputs and training a model to predict the
            >>> # score of the target class.
            >>>
            >>> # For interpretable model training, we will use sklearn
            >>> # linear model in this example. We have provided wrappers
            >>> # around sklearn linear models to fit the Model interface.
            >>> # Any arguments provided to the sklearn constructor can also
            >>> # be provided to the wrapper, e.g.:
            >>> # SkLearnLinearModel("linear_model.Ridge", alpha=2.0)
            >>> from captum._utils.models.linear_model import SkLearnLinearModel
            >>>
            >>>
            >>> # Define similarity kernel (exponential kernel based on L2 norm)
            >>> def similarity_kernel(
            >>>     original_input: Tensor,
            >>>     perturbed_input: Tensor,
            >>>     perturbed_interpretable_input: Tensor,
            >>>     **kwargs)->Tensor:
            >>>         # kernel_width will be provided to attribute as a kwarg
            >>>         kernel_width = kwargs["kernel_width"]
            >>>         l2_dist = torch.norm(original_input - perturbed_input)
            >>>         return torch.exp(- (l2_dist**2) / (kernel_width**2))
            >>>
            >>>
            >>> # Define sampling function
            >>> # This function samples in original input space
            >>> def perturb_func(
            >>>     original_input: Tensor,
            >>>     **kwargs)->Tensor:
            >>>         return original_input + torch.randn_like(original_input)
            >>>
            >>> # For this example, we are setting the interpretable input to
            >>> # match the model input, so the to_interp_rep_transform
            >>> # function simply returns the input. In most cases, the interpretable
            >>> # input will be different and may have a smaller feature set, so
            >>> # an appropriate transformation function should be provided.
            >>>
            >>> def to_interp_transform(curr_sample, original_inp,
            >>>                                      **kwargs):
            >>>     return curr_sample
            >>>
            >>> # Generating random input with size 1 x 5
            >>> input = torch.randn(1, 5)
            >>> # Defining LimeBase interpreter
            >>> lime_attr = LimeBase(net,
            >>>                      SkLearnLinearModel("linear_model.Ridge"),
            >>>                      similarity_func=similarity_kernel,
            >>>                      perturb_func=perturb_func,
            >>>                      perturb_interpretable_space=False,
            >>>                      from_interp_rep_transform=None,
            >>>                      to_interp_rep_transform=to_interp_transform)
            >>> # Computes interpretable model, returning coefficients of linear
            >>> # model.
            >>> attr_coefs = lime_attr.attribute(input, target=1, kernel_width=1.1)
        """
        with torch.no_grad():
            # All sampled/derived tensors are created on the device of the
            # (first) input tensor.
            inp_tensor = (
                cast(Tensor, inputs) if isinstance(inputs, Tensor) else inputs[0]
            )
            device = inp_tensor.device

            interpretable_inps = []
            similarities = []
            outputs = []

            # Model-space samples waiting to be evaluated in one forward pass.
            curr_model_inputs = []
            # Expanded args/targets are cached for full batches, which always
            # have exactly perturbations_per_eval samples.
            expanded_additional_args = None
            expanded_target = None
            perturb_generator = None
            if inspect.isgeneratorfunction(self.perturb_func):
                perturb_generator = self.perturb_func(inputs, **kwargs)

            if show_progress:
                attr_progress = progress(
                    total=math.ceil(n_samples / perturbations_per_eval),
                    desc=f"{self.get_name()} attribution",
                )
                attr_progress.update(0)

            batch_count = 0
            for _ in range(n_samples):
                if perturb_generator is not None:
                    try:
                        curr_sample = next(perturb_generator)
                    except StopIteration:
                        warnings.warn(
                            "Generator completed prior to given n_samples iterations!"
                        )
                        break
                else:
                    curr_sample = self.perturb_func(inputs, **kwargs)
                batch_count += 1

                # Keep both representations of every sample: the interpretable
                # one trains the surrogate, the model-space one feeds forward_func.
                if self.perturb_interpretable_space:
                    interpretable_inps.append(curr_sample)
                    curr_model_inputs.append(
                        self.from_interp_rep_transform(  # type: ignore
                            curr_sample, inputs, **kwargs
                        )
                    )
                else:
                    curr_model_inputs.append(curr_sample)
                    interpretable_inps.append(
                        self.to_interp_rep_transform(  # type: ignore
                            curr_sample, inputs, **kwargs
                        )
                    )
                curr_sim = self.similarity_func(
                    inputs, curr_model_inputs[-1], interpretable_inps[-1], **kwargs
                )
                # Normalize the weight to a 1D tensor so all weights can be
                # concatenated below.
                similarities.append(
                    curr_sim.flatten()
                    if isinstance(curr_sim, Tensor)
                    else torch.tensor([curr_sim], device=device)
                )

                if len(curr_model_inputs) == perturbations_per_eval:
                    if expanded_additional_args is None:
                        expanded_additional_args = _expand_additional_forward_args(
                            additional_forward_args, len(curr_model_inputs)
                        )
                    if expanded_target is None:
                        expanded_target = _expand_target(target, len(curr_model_inputs))

                    model_out = self._evaluate_batch(
                        curr_model_inputs,
                        expanded_target,
                        expanded_additional_args,
                        device,
                    )

                    if show_progress:
                        attr_progress.update()

                    outputs.append(model_out)

                    curr_model_inputs = []

            # Flush a final partial batch; its size differs from
            # perturbations_per_eval, so args and target are re-expanded.
            if len(curr_model_inputs) > 0:
                expanded_additional_args = _expand_additional_forward_args(
                    additional_forward_args, len(curr_model_inputs)
                )
                expanded_target = _expand_target(target, len(curr_model_inputs))
                model_out = self._evaluate_batch(
                    curr_model_inputs,
                    expanded_target,
                    expanded_additional_args,
                    device,
                )
                if show_progress:
                    attr_progress.update()
                outputs.append(model_out)

            if show_progress:
                attr_progress.close()

            # Without any samples there is nothing to fit; fail with a clear
            # message instead of an opaque error from torch.cat on empty lists.
            if batch_count == 0:
                raise ValueError(
                    "No perturbed samples were generated; n_samples must be "
                    "positive and perturb_func must produce at least one sample."
                )

            combined_interp_inps = torch.cat(interpretable_inps).double()
            combined_outputs = (
                torch.cat(outputs)
                if len(outputs[0].shape) > 0
                else torch.stack(outputs)
            ).double()
            combined_sim = (
                torch.cat(similarities)
                if len(similarities[0].shape) > 0
                else torch.stack(similarities)
            ).double()
            dataset = TensorDataset(
                combined_interp_inps, combined_outputs, combined_sim
            )
            # A single batch holding every collected sample is handed to fit().
            self.interpretable_model.fit(DataLoader(dataset, batch_size=batch_count))
            return self.interpretable_model.representation()

    def _evaluate_batch(
        self,
        curr_model_inputs: List[TensorOrTupleOfTensorsGeneric],
        expanded_target: TargetType,
        expanded_additional_args: Any,
        device: torch.device,
    ):
        """
        Run forward_func on a list of perturbed inputs combined by _reduce_list.

        Returns:
            A 1D tensor of model outputs. Tensor outputs must contain exactly one
            element per perturbed input and are flattened; a non-tensor output
            is wrapped into a one-element tensor on ``device``.
        """
        model_out = _run_forward(
            self.forward_func,
            _reduce_list(curr_model_inputs),
            expanded_target,
            expanded_additional_args,
        )

        if isinstance(model_out, Tensor):
            assert model_out.numel() == len(curr_model_inputs), (
                "Number of outputs is not appropriate, must return "
                "one output per perturbed input"
            )
            return model_out.flatten()
        return torch.tensor([model_out], device=device)

    def has_convergence_delta(self) -> bool:
        """LimeBase does not compute a convergence delta."""
        return False

    # NOTE(review): this module imports log_usage and other attribution classes
    # often expose multiplies_by_inputs as a @property; if a decorator was lost
    # here, this plain method would shadow the base attribute — confirm.
    def multiplies_by_inputs(self):
        """Attributions from LimeBase are not multiplied by the inputs."""
        return False
| # Default transformations and methods | |
| # for Lime child implementation. | |
def default_from_interp_rep_transform(curr_sample, original_inputs, **kwargs):
    r"""
    Map a sampled binary interpretable vector back to the original input space.

    Every element of an input is assigned an interpretable feature index by
    ``feature_mask``; the value of that feature in the first row of
    ``curr_sample`` decides whether the element keeps its original value
    (non-zero) or is replaced by the baseline (zero).

    Args:
        curr_sample: Sampled interpretable representation; only
            ``curr_sample[0]`` is read, indexed by the feature mask(s).
        original_inputs: Tensor, or tuple of tensors, being attributed.
        **kwargs: Must contain ``feature_mask`` (a Tensor, or a sequence of
            tensors aligned with ``original_inputs``) and ``baselines`` (a
            value/tensor combined with the input, or a sequence indexed per
            input when ``feature_mask`` is not a Tensor).

    Returns:
        A tensor when ``feature_mask`` is a Tensor, otherwise a tuple with one
        tensor per entry of ``feature_mask``.

    Raises:
        AssertionError: if ``feature_mask`` or ``baselines`` is missing.
    """
    assert (
        "feature_mask" in kwargs
    ), "Must provide feature_mask to use default interpretable representation transform"
    assert (
        "baselines" in kwargs
    ), "Must provide baselines to use default interpretable representation transfrom"
    masks = kwargs["feature_mask"]

    def _blend(inp, mask_for_inp, base):
        # Gather the on/off state of each element's interpretable feature.
        keep = curr_sample[0][mask_for_inp].bool()
        dtype = inp.dtype
        return keep.to(dtype) * inp + (~keep).to(dtype) * base

    if not isinstance(masks, Tensor):
        return tuple(
            _blend(original_inputs[idx], masks[idx], kwargs["baselines"][idx])
            for idx in range(len(masks))
        )
    return _blend(original_inputs, masks, kwargs["baselines"])
def get_exp_kernel_similarity_function(
    distance_mode: str = "cosine", kernel_width: float = 1.0
) -> Callable:
    r"""
    This method constructs an appropriate similarity function to compute
    weights for perturbed sample in LIME. Distance between the original
    and perturbed inputs is computed based on the provided distance mode,
    and the distance is passed through an exponential kernel with given
    kernel width to convert to a range between 0 and 1.

    The callable returned can be provided as the similarity_fn for
    Lime or LimeBase.

    Args:

        distance_mode (str, optional): Distance mode can be either "cosine" or
                    "euclidean" corresponding to either cosine distance
                    or Euclidean distance respectively. Distance is computed
                    by flattening the original inputs and perturbed inputs
                    (concatenating tuples of inputs if necessary) and computing
                    distances between the resulting vectors.
                    Default: "cosine"
        kernel_width (float, optional):
                    Kernel width for exponential kernel applied to distance.
                    Default: 1.0

    Returns:

        *Callable*:
        - **similarity_fn** (*Callable*):
            Similarity function returning
            ``exp(-distance ** 2 / (2 * kernel_width ** 2))`` as a Python
            float. This callable can be provided as the similarity_fn for
            Lime or LimeBase.

    Raises:
        ValueError: if distance_mode is neither "cosine" nor "euclidean".
            This is checked when the similarity function is constructed, so a
            misconfiguration surfaces immediately instead of on the first
            sample drawn during attribution.
    """
    if distance_mode not in ("cosine", "euclidean"):
        raise ValueError("distance_mode must be either cosine or euclidean.")

    # Loop-invariant pieces hoisted out of the per-sample kernel.
    cos_sim = CosineSimilarity(dim=0)
    kernel_denominator = 2 * (kernel_width ** 2)

    def default_exp_kernel(original_inp, perturbed_inp, __, **kwargs):
        # The interpretable representation (third argument) is unused.
        flattened_original_inp = _flatten_tensor_or_tuple(original_inp).float()
        flattened_perturbed_inp = _flatten_tensor_or_tuple(perturbed_inp).float()
        if distance_mode == "cosine":
            distance = 1 - cos_sim(flattened_original_inp, flattened_perturbed_inp)
        else:
            distance = torch.norm(flattened_original_inp - flattened_perturbed_inp)
        return math.exp(-1 * (distance ** 2) / kernel_denominator)

    return default_exp_kernel
def default_perturb_func(original_inp, **kwargs):
    r"""
    Sample a binary interpretable vector with every feature on with probability 0.5.

    Args:
        original_inp: Tensor or tuple of tensors being attributed. It is only
            used to pick the device of the result (the first tensor of a tuple).
        **kwargs: Must contain ``num_interp_features``, the length of the
            sampled vector.

    Returns:
        A long tensor of shape ``1 x num_interp_features`` with entries in
        {0, 1}, on the device of the (first) input tensor.

    Raises:
        AssertionError: if ``num_interp_features`` is not provided.
    """
    assert (
        "num_interp_features" in kwargs
    ), "Must provide num_interp_features to use default interpretable sampling function"
    reference = original_inp if isinstance(original_inp, Tensor) else original_inp[0]
    target_device = reference.device
    # Probabilities are built with the default dtype/device; the draw is then
    # moved to the input's device and cast to long.
    half = torch.full((1, kwargs["num_interp_features"]), 0.5)
    sampled = torch.bernoulli(half)
    return sampled.to(device=target_device).long()
def construct_feature_mask(feature_mask, formatted_inputs):
    r"""
    Build (or normalize) the feature mask used by Lime and count its features.

    Args:
        feature_mask: None, a Tensor, or a tuple of tensors assigning an
            interpretable feature index to every input element.
        formatted_inputs: Inputs as a tuple of tensors; only used to build a
            default mask when ``feature_mask`` is None.

    Returns:
        A ``(feature_mask, num_interp_features)`` pair. A user-provided mask
        is converted to a tuple of tensors and, if its smallest index (over
        non-empty tensors) is not 0, shifted so indices start at 0 (with a
        warning). ``num_interp_features`` is the largest index plus one.
    """
    if feature_mask is None:
        default_mask, default_count = _construct_default_feature_mask(
            formatted_inputs
        )
        return default_mask, default_count

    masks = _format_tensor_into_tuples(feature_mask)
    # Empty tensors carry no indices, so they are ignored for min / max.
    lowest = int(min(torch.min(m).item() for m in masks if m.numel()))
    if lowest != 0:
        warnings.warn(
            "Minimum element in feature mask is not 0, shifting indices to"
            " start at 0."
        )
        masks = tuple(m - lowest for m in masks)

    highest = max(torch.max(m).item() for m in masks if m.numel())
    return masks, int(highest + 1)
| class Lime(LimeBase): | |
| r""" | |
| Lime is an interpretability method that trains an interpretable surrogate model | |
| by sampling points around a specified input example and using model evaluations | |
| at these points to train a simpler interpretable 'surrogate' model, such as a | |
| linear model. | |
| Lime provides a more specific implementation than LimeBase in order to expose | |
| a consistent API with other perturbation-based algorithms. For more general | |
| use of the LIME framework, consider using the LimeBase class directly and | |
| defining custom sampling and transformation to / from interpretable | |
| representation functions. | |
| Lime assumes that the interpretable representation is a binary vector, | |
| corresponding to some elements in the input being set to their baseline value | |
| if the corresponding binary interpretable feature value is 0 or being set | |
| to the original input value if the corresponding binary interpretable | |
| feature value is 1. Input values can be grouped to correspond to the same | |
| binary interpretable feature using a feature mask provided when calling | |
| attribute, similar to other perturbation-based attribution methods. | |
| One example of this setting is when applying Lime to an image classifier. | |
| Pixels in an image can be grouped into super-pixels or segments, which | |
| correspond to interpretable features, provided as a feature_mask when | |
| calling attribute. Sampled binary vectors convey whether a super-pixel | |
| is on (retains the original input values) or off (set to the corresponding | |
| baseline value, e.g. black image). An interpretable linear model is trained | |
| with input being the binary vectors and outputs as the corresponding scores | |
| of the image classifier with the appropriate super-pixels masked based on the | |
| binary vector. Coefficients of the trained surrogate | |
| linear model convey the importance of each super-pixel. | |
| More details regarding LIME can be found in the original paper: | |
| https://arxiv.org/abs/1602.04938 | |
| """ | |
| def __init__( | |
| self, | |
| forward_func: Callable, | |
| interpretable_model: Optional[Model] = None, | |
| similarity_func: Optional[Callable] = None, | |
| perturb_func: Optional[Callable] = None, | |
| ) -> None: | |
| r""" | |
| Args: | |
| forward_func (callable): The forward function of the model or any | |
| modification of it | |
| interpretable_model (optional, Model): Model object to train | |
| interpretable model. | |
| This argument is optional and defaults to SkLearnLasso(alpha=0.01), | |
| which is a wrapper around the Lasso linear model in SkLearn. | |
| This requires having sklearn version >= 0.23 available. | |
| Other predefined interpretable linear models are provided in | |
| captum._utils.models.linear_model. | |
| Alternatively, a custom model object must provide a `fit` method to | |
| train the model, given a dataloader, with batches containing | |
| three tensors: | |
| - interpretable_inputs: Tensor | |
| [2D num_samples x num_interp_features], | |
| - expected_outputs: Tensor [1D num_samples], | |
| - weights: Tensor [1D num_samples] | |
| The model object must also provide a `representation` method to | |
| access the appropriate coefficients or representation of the | |
| interpretable model after fitting. | |
| Note that calling fit multiple times should retrain the | |
| interpretable model, each attribution call reuses | |
| the same given interpretable model object. | |
| similarity_func (optional, callable): Function which takes a single sample | |
| along with its corresponding interpretable representation | |
| and returns the weight of the interpretable sample for | |
| training the interpretable model. | |
| This is often referred to as a similarity kernel. | |
| This argument is optional and defaults to a function which | |
| applies an exponential kernel to the consine distance between | |
| the original input and perturbed input, with a kernel width | |
| of 1.0. | |
| A similarity function applying an exponential | |
| kernel to cosine / euclidean distances can be constructed | |
| using the provided get_exp_kernel_similarity_function in | |
| captum.attr._core.lime. | |
| Alternately, a custom callable can also be provided. | |
| The expected signature of this callable is: | |
| >>> def similarity_func( | |
| >>> original_input: Tensor or tuple of Tensors, | |
| >>> perturbed_input: Tensor or tuple of Tensors, | |
| >>> perturbed_interpretable_input: | |
| >>> Tensor [2D 1 x num_interp_features], | |
| >>> **kwargs: Any | |
| >>> ) -> float or Tensor containing float scalar | |
| perturbed_input and original_input will be the same type and | |
| contain tensors of the same shape, with original_input | |
| being the same as the input provided when calling attribute. | |
| kwargs includes baselines, feature_mask, num_interp_features | |
| (integer, determined from feature mask). | |
| perturb_func (optional, callable): Function which returns a single | |
| sampled input, which is a binary vector of length | |
| num_interp_features, or a generator of such tensors. | |
| This function is optional, the default function returns | |
| a binary vector where each element is selected | |
| independently and uniformly at random. Custom | |
| logic for selecting sampled binary vectors can | |
| be implemented by providing a function with the | |
| following expected signature: | |
| >>> perturb_func( | |
| >>> original_input: Tensor or tuple of Tensors, | |
| >>> **kwargs: Any | |
| >>> ) -> Tensor [Binary 2D Tensor 1 x num_interp_features] | |
| >>> or generator yielding such tensors | |
| kwargs includes baselines, feature_mask, num_interp_features | |
| (integer, determined from feature mask). | |
| """ | |
| if interpretable_model is None: | |
| interpretable_model = SkLearnLasso(alpha=0.01) | |
| if similarity_func is None: | |
| similarity_func = get_exp_kernel_similarity_function() | |
| if perturb_func is None: | |
| perturb_func = default_perturb_func | |
| LimeBase.__init__( | |
| self, | |
| forward_func, | |
| interpretable_model, | |
| similarity_func, | |
| perturb_func, | |
| True, | |
| default_from_interp_rep_transform, | |
| None, | |
| ) | |
def attribute(  # type: ignore
    self,
    inputs: TensorOrTupleOfTensorsGeneric,
    baselines: BaselineType = None,
    target: TargetType = None,
    additional_forward_args: Any = None,
    feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
    n_samples: int = 50,
    perturbations_per_eval: int = 1,
    return_input_shape: bool = True,
    show_progress: bool = False,
) -> TensorOrTupleOfTensorsGeneric:
    r"""
    This method attributes the output of the model with given target index
    (in case it is provided, otherwise it assumes that output is a
    scalar) to the inputs of the model using the approach described above,
    training an interpretable model and returning a representation of the
    interpretable model.

    It is recommended to only provide a single example as input (tensors
    with first dimension or batch size = 1). This is because LIME is generally
    used for sample-based interpretability, training a separate interpretable
    model to explain a model's prediction on each individual example.

    A batch of inputs can also be provided as inputs, similar to
    other perturbation-based attribution methods. In this case, if forward_func
    returns a scalar per example, attributions will be computed for each
    example independently, with a separate interpretable model trained for each
    example. Note that provided similarity and perturbation functions will be
    provided each example separately (first dimension = 1) in this case.
    If forward_func returns a scalar per batch (e.g. loss), attributions will
    still be computed using a single interpretable model for the full batch.
    In this case, similarity and perturbation functions will be provided the
    same original input containing the full batch.

    The number of interpretable features is determined from the provided
    feature mask, or if none is provided, from the default feature mask,
    which considers each scalar input as a separate feature. It is
    generally recommended to provide a feature mask which groups features
    into a small number of interpretable features / components (e.g.
    superpixels in images).

    Args:

        inputs (tensor or tuple of tensors): Input for which LIME
                    is computed. If forward_func takes a single
                    tensor as input, a single input tensor should be provided.
                    If forward_func takes multiple tensors as input, a tuple
                    of the input tensors should be provided. It is assumed
                    that for all given input tensors, dimension 0 corresponds
                    to the number of examples, and if multiple input tensors
                    are provided, the examples must be aligned appropriately.
        baselines (scalar, tensor, tuple of scalars or tensors, optional):
                    Baselines define reference value which replaces each
                    feature when the corresponding interpretable feature
                    is set to 0.
                    Baselines can be provided as:

                    - a single tensor, if inputs is a single tensor, with
                      exactly the same dimensions as inputs or the first
                      dimension is one and the remaining dimensions match
                      with inputs.

                    - a single scalar, if inputs is a single tensor, which will
                      be broadcasted for each input value in input tensor.

                    - a tuple of tensors or scalars, the baseline corresponding
                      to each tensor in the inputs' tuple can be:

                      - either a tensor with matching dimensions to
                        corresponding tensor in the inputs' tuple
                        or the first dimension is one and the remaining
                        dimensions match with the corresponding
                        input tensor.

                      - or a scalar, corresponding to a tensor in the
                        inputs' tuple. This scalar value is broadcasted
                        for corresponding input tensor.

                    In the cases when `baselines` is not provided, we internally
                    use zero scalar corresponding to each input tensor.
                    Default: None
        target (int, tuple, tensor or list, optional): Output indices for
                    which surrogate model is trained
                    (for classification cases,
                    this is usually the target class).
                    If the network returns a scalar value per example,
                    no target index is necessary.
                    For general 2D outputs, targets can be either:

                    - a single integer or a tensor containing a single
                      integer, which is applied to all input examples

                    - a list of integers or a 1D tensor, with length matching
                      the number of examples in inputs (dim 0). Each integer
                      is applied as the target for the corresponding example.

                    For outputs with > 2 dimensions, targets can be either:

                    - A single tuple, which contains #output_dims - 1
                      elements. This target index is applied to all examples.

                    - A list of tuples with length equal to the number of
                      examples in inputs (dim 0), and each tuple containing
                      #output_dims - 1 elements. Each tuple is applied as the
                      target for the corresponding example.

                    Default: None
        additional_forward_args (any, optional): If the forward function
                    requires additional arguments other than the inputs for
                    which attributions should not be computed, this argument
                    can be provided. It must be either a single additional
                    argument of a Tensor or arbitrary (non-tuple) type or a
                    tuple containing multiple additional arguments including
                    tensors or any arbitrary python types. These arguments
                    are provided to forward_func in order following the
                    arguments in inputs.
                    For a tensor, the first dimension of the tensor must
                    correspond to the number of examples. For all other
                    types, the given argument is used for all forward
                    evaluations.
                    Note that attributions are not computed with respect
                    to these arguments.
                    Default: None
        feature_mask (tensor or tuple of tensors, optional):
                    feature_mask defines a mask for the input, grouping
                    features which correspond to the same
                    interpretable feature. feature_mask
                    should contain the same number of tensors as inputs.
                    Each tensor should
                    be the same size as the corresponding input or
                    broadcastable to match the input tensor. Values across
                    all tensors should be integers in the range 0 to
                    num_interp_features - 1, and indices corresponding to the
                    same feature should have the same value.
                    Note that features are grouped across tensors
                    (unlike feature ablation and occlusion), so
                    if the same index is used in different tensors, those
                    features are still grouped and added simultaneously.
                    If None, then a feature mask is constructed which assigns
                    each scalar within a tensor as a separate feature.
                    Default: None
        n_samples (int, optional): The number of samples of the original
                    model used to train the surrogate interpretable model.
                    Default: `50` if `n_samples` is not provided.
        perturbations_per_eval (int, optional): Allows multiple samples
                    to be processed simultaneously in one call to forward_func.
                    Each forward pass will contain a maximum of
                    perturbations_per_eval * #examples samples.
                    For DataParallel models, each batch is split among the
                    available devices, so evaluations on each available
                    device contain at most
                    (perturbations_per_eval * #examples) / num_devices
                    samples.
                    If the forward function returns a single scalar per batch,
                    perturbations_per_eval must be set to 1.
                    Default: 1
        return_input_shape (bool, optional): Determines whether the returned
                    tensor(s) only contain the coefficients for each
                    interpretable feature from the trained surrogate model, or
                    whether the returned attributions match the input shape.
                    When return_input_shape is True, the return type of attribute
                    matches the input shape, with each element containing the
                    coefficient of the corresponding interpretable feature.
                    All elements with the same value in the feature mask
                    will contain the same coefficient in the returned
                    attributions. If return_input_shape is False, a 1D
                    tensor is returned, containing only the coefficients
                    of the trained interpretable models, with length
                    num_interp_features.
                    Default: True
        show_progress (bool, optional): Displays the progress of computation.
                    It will try to use tqdm if available for advanced features
                    (e.g. time estimation). Otherwise, it will fallback to
                    a simple output of progress.
                    Default: False

    Returns:
        *tensor* or tuple of *tensors* of **attributions**:
        - **attributions** (*tensor* or tuple of *tensors*):
                    The attributions with respect to each input feature.
                    If return_input_shape = True, attributions will be
                    the same size as the provided inputs, with each value
                    providing the coefficient of the corresponding
                    interpretable feature.
                    If return_input_shape is False, a 1D
                    tensor is returned, containing only the coefficients
                    of the trained interpretable models, with length
                    num_interp_features.

    Examples::

        >>> # SimpleClassifier takes a single input tensor of size Nx4x4,
        >>> # and returns an Nx3 tensor of class probabilities.
        >>> net = SimpleClassifier()

        >>> # Generating random input with size 1 x 4 x 4
        >>> input = torch.randn(1, 4, 4)

        >>> # Defining Lime interpreter
        >>> lime = Lime(net)
        >>> # Computes attribution, with each of the 4 x 4 = 16
        >>> # features as a separate interpretable feature
        >>> attr = lime.attribute(input, target=1, n_samples=200)

        >>> # Alternatively, we can group each 2x2 square of the inputs
        >>> # as one 'interpretable' feature and perturb them together.
        >>> # This can be done by creating a feature mask as follows, which
        >>> # defines the feature groups, e.g.:
        >>> # +---+---+---+---+
        >>> # | 0 | 0 | 1 | 1 |
        >>> # +---+---+---+---+
        >>> # | 0 | 0 | 1 | 1 |
        >>> # +---+---+---+---+
        >>> # | 2 | 2 | 3 | 3 |
        >>> # +---+---+---+---+
        >>> # | 2 | 2 | 3 | 3 |
        >>> # +---+---+---+---+
        >>> # With this mask, all inputs with the same value are set to their
        >>> # baseline value, when the corresponding binary interpretable
        >>> # feature is set to 0.
        >>> # The attributions can be calculated as follows:
        >>> # feature mask has dimensions 1 x 4 x 4
        >>> feature_mask = torch.tensor([[[0,0,1,1],[0,0,1,1],
        >>>                             [2,2,3,3],[2,2,3,3]]])

        >>> # Computes interpretable model and returning attributions
        >>> # matching input shape.
        >>> attr = lime.attribute(input, target=1, feature_mask=feature_mask)
    """
    # All work is delegated to _attribute_kwargs, which also accepts extra
    # keyword arguments that are forwarded to the parent implementation.
    return self._attribute_kwargs(
        inputs=inputs,
        baselines=baselines,
        target=target,
        additional_forward_args=additional_forward_args,
        feature_mask=feature_mask,
        n_samples=n_samples,
        perturbations_per_eval=perturbations_per_eval,
        return_input_shape=return_input_shape,
        show_progress=show_progress,
    )
def _attribute_kwargs(  # type: ignore
    self,
    inputs: TensorOrTupleOfTensorsGeneric,
    baselines: BaselineType = None,
    target: TargetType = None,
    additional_forward_args: Any = None,
    feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
    n_samples: int = 25,
    perturbations_per_eval: int = 1,
    return_input_shape: bool = True,
    show_progress: bool = False,
    **kwargs,
) -> TensorOrTupleOfTensorsGeneric:
    """Shared implementation behind ``attribute``.

    Formats inputs and baselines and builds the feature mask (plus the
    number of interpretable features). It then trains the surrogate model
    through the parent class's ``attribute``, either once for the whole
    batch or once per example.

    Args:
        inputs, baselines, target, additional_forward_args, feature_mask,
        n_samples, perturbations_per_eval, return_input_shape, show_progress:
            Same meaning as in ``attribute``.
        **kwargs: Forwarded unchanged to the parent ``attribute``
            implementation.

    Returns:
        If ``return_input_shape`` is True, attributions shaped like
        ``inputs`` (see ``_convert_output_shape``). Otherwise the raw
        coefficients. In the per-example path, each example's coefficients
        are reshaped to ``(1, -1)`` and combined with ``_reduce_list``.

    Raises:
        AssertionError: If the batch has more than one example and the
            forward output has more than one element but not exactly one
            per example. Also raised if ``perturbations_per_eval != 1``
            when a multi-example batch yields a single output value.
    """
    is_inputs_tuple = _is_tuple(inputs)
    formatted_inputs, baselines = _format_input_baseline(inputs, baselines)
    bsz = formatted_inputs[0].shape[0]

    feature_mask, num_interp_features = construct_feature_mask(
        feature_mask, formatted_inputs
    )

    if num_interp_features > 10000:
        warnings.warn(
            "Attempting to construct interpretable model with > 10000 features. "
            "This can be very slow or lead to OOM issues. Please provide a feature "
            "mask which groups input features to reduce the number of interpretable "
            "features. "
        )

    coefs: Tensor
    if bsz > 1:
        # Probe the forward function once to learn whether it returns one
        # value per example (train one surrogate per example) or a single
        # value for the whole batch (train one surrogate for the batch).
        test_output = _run_forward(
            self.forward_func, inputs, target, additional_forward_args
        )
        if isinstance(test_output, Tensor) and torch.numel(test_output) > 1:
            if torch.numel(test_output) != bsz:
                raise AssertionError(
                    "Invalid number of outputs, forward function should return a "
                    "scalar per example or a scalar per input batch."
                )
            warnings.warn(
                "You are providing multiple inputs for Lime / Kernel SHAP "
                "attributions. This trains a separate interpretable model "
                "for each example, which can be time consuming. It is "
                "recommended to compute attributions for one example at a time."
            )
            output_list = []
            for (
                curr_inps,
                curr_target,
                curr_additional_args,
                curr_baselines,
                curr_feature_mask,
            ) in _batch_example_iterator(
                bsz,
                formatted_inputs,
                target,
                additional_forward_args,
                baselines,
                feature_mask,
            ):
                # Call the parent's undecorated implementation (via
                # __wrapped__); presumably this avoids nested usage logging
                # from the decorator - NOTE(review): confirm.
                coefs = super().attribute.__wrapped__(
                    self,
                    inputs=curr_inps if is_inputs_tuple else curr_inps[0],
                    target=curr_target,
                    additional_forward_args=curr_additional_args,
                    n_samples=n_samples,
                    perturbations_per_eval=perturbations_per_eval,
                    baselines=curr_baselines
                    if is_inputs_tuple
                    else curr_baselines[0],
                    feature_mask=curr_feature_mask
                    if is_inputs_tuple
                    else curr_feature_mask[0],
                    num_interp_features=num_interp_features,
                    show_progress=show_progress,
                    **kwargs,
                )
                if return_input_shape:
                    output_list.append(
                        self._convert_output_shape(
                            curr_inps,
                            curr_feature_mask,
                            coefs,
                            num_interp_features,
                            is_inputs_tuple,
                        )
                    )
                else:
                    output_list.append(coefs.reshape(1, -1))  # type: ignore
            return _reduce_list(output_list)

        # A single value for a multi-example batch: the perturbations
        # cannot be split back into separate evaluations. Use an explicit
        # raise rather than `assert` so the check survives `python -O`.
        if perturbations_per_eval != 1:
            raise AssertionError(
                "Perturbations per eval must be 1 when forward function "
                "returns single value per batch!"
            )

    # One surrogate model for the whole input (single example, or a batch
    # whose forward function returns a single value).
    coefs = super().attribute.__wrapped__(
        self,
        inputs=inputs,
        target=target,
        additional_forward_args=additional_forward_args,
        n_samples=n_samples,
        perturbations_per_eval=perturbations_per_eval,
        baselines=baselines if is_inputs_tuple else baselines[0],
        feature_mask=feature_mask if is_inputs_tuple else feature_mask[0],
        num_interp_features=num_interp_features,
        show_progress=show_progress,
        **kwargs,
    )
    if return_input_shape:
        return self._convert_output_shape(
            formatted_inputs,
            feature_mask,
            coefs,
            num_interp_features,
            is_inputs_tuple,
        )
    return coefs
@typing.overload
def _convert_output_shape(
    self,
    formatted_inp: Tuple[Tensor, ...],
    feature_mask: Tuple[Tensor, ...],
    coefs: Tensor,
    num_interp_features: int,
    is_inputs_tuple: Literal[True],
) -> Tuple[Tensor, ...]:
    ...


@typing.overload
def _convert_output_shape(
    self,
    formatted_inp: Tuple[Tensor, ...],
    feature_mask: Tuple[Tensor, ...],
    coefs: Tensor,
    num_interp_features: int,
    is_inputs_tuple: Literal[False],
) -> Tensor:
    ...


def _convert_output_shape(
    self,
    formatted_inp: Tuple[Tensor, ...],
    feature_mask: Tuple[Tensor, ...],
    coefs: Tensor,
    num_interp_features: int,
    is_inputs_tuple: bool,
) -> Union[Tensor, Tuple[Tensor, ...]]:
    """Expand per-feature surrogate coefficients back to the input shape.

    For each input tensor, every element receives the coefficient of the
    interpretable feature that the matching ``feature_mask`` entry selects.
    Elements whose mask value is not in ``range(num_interp_features)`` keep
    the value 0.

    Args:
        formatted_inp: Input tensors (always a tuple here).
        feature_mask: One mask per input tensor, broadcastable to it.
        coefs: Surrogate coefficients. They are flattened, and the first
            ``num_interp_features`` entries are used.
        num_interp_features: Number of interpretable features.
        is_inputs_tuple: Whether the caller's original ``inputs`` was a
            tuple. Controls the return form via ``_format_output``.

    Returns:
        Float tensor(s) shaped like the inputs, as a tuple or a single
        tensor depending on ``is_inputs_tuple``.
    """
    # Convert all coefficients to Python scalars in one transfer instead of
    # calling ``.item()`` (one host sync each) inside the nested loop.
    coef_values = coefs.flatten().tolist()
    attr = []
    for tensor_ind, single_inp in enumerate(formatted_inp):
        single_attr = torch.zeros_like(single_inp, dtype=torch.float)
        single_mask = feature_mask[tensor_ind]
        for single_feature in range(num_interp_features):
            single_attr += (
                coef_values[single_feature]
                * (single_mask == single_feature).float()
            )
        attr.append(single_attr)
    return _format_output(is_inputs_tuple, tuple(attr))