Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| import typing | |
| from collections import defaultdict | |
| from typing import Any, cast, List, Tuple, Union | |
| import torch.nn as nn | |
| from captum._utils.common import ( | |
| _format_output, | |
| _format_tensor_into_tuples, | |
| _is_tuple, | |
| _register_backward_hook, | |
| _run_forward, | |
| ) | |
| from captum._utils.gradient import ( | |
| apply_gradient_requirements, | |
| undo_gradient_requirements, | |
| ) | |
| from captum._utils.typing import Literal, TargetType, TensorOrTupleOfTensorsGeneric | |
| from captum.attr._utils.attribution import GradientAttribution | |
| from captum.attr._utils.common import _sum_rows | |
| from captum.attr._utils.custom_modules import Addition_Module | |
| from captum.attr._utils.lrp_rules import EpsilonRule, PropagationRule | |
| from captum.log import log_usage | |
| from torch import Tensor | |
| from torch.nn import Module | |
| from torch.utils.hooks import RemovableHandle | |
class LRP(GradientAttribution):
    r"""
    Layer-wise relevance propagation is based on a backward propagation
    mechanism applied sequentially to all layers of the model. Here, the
    model output score represents the initial relevance which is decomposed
    into values for each neuron of the underlying layers. The decomposition
    is defined by rules that are chosen for each layer, involving its weights
    and activations. Details on the model can be found in the original paper
    [https://doi.org/10.1371/journal.pone.0130140]. The implementation is
    inspired by the tutorial of the same group
    [https://doi.org/10.1016/j.dsp.2017.10.011] and the publication by
    Ancona et al. [https://openreview.net/forum?id=Sy21R9JAW].
    """

    def __init__(self, model: Module) -> None:
        r"""
        Args:
            model (module): The forward function of the model or any modification
                of it. Custom rules for a given layer need to be defined as
                attribute `module.rule` and need to be of type PropagationRule.
                If no rule is specified for a layer, a pre-defined default rule
                for the module type is used. Model cannot contain any in-place
                nonlinear submodules; these are not supported by the
                register_full_backward_hook PyTorch API starting from
                PyTorch v1.9.

        Raises:
            TypeError: If any module carries a ``rule`` that is neither None
                nor an instance of PropagationRule.
        """
        GradientAttribution.__init__(self, model)
        self.model = model
        # Validate user-supplied rules once, up front.
        self._check_rules()

    @property
    def multiplies_by_inputs(self) -> bool:
        # Must be a property: as a plain method, `lrp.multiplies_by_inputs`
        # would evaluate to a bound method object (always truthy) instead of
        # the boolean flag.
        return True

    @typing.overload
    def attribute(
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        target: TargetType = None,
        additional_forward_args: Any = None,
        return_convergence_delta: Literal[False] = False,
        verbose: bool = False,
    ) -> TensorOrTupleOfTensorsGeneric:
        ...

    @typing.overload
    def attribute(
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        target: TargetType = None,
        additional_forward_args: Any = None,
        *,
        return_convergence_delta: Literal[True],
        verbose: bool = False,
    ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]:
        ...

    def attribute(
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        target: TargetType = None,
        additional_forward_args: Any = None,
        return_convergence_delta: bool = False,
        verbose: bool = False,
    ) -> Union[
        TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]
    ]:
        r"""
        Args:
            inputs (tensor or tuple of tensors): Input for which relevance is
                propagated. If forward_func takes a single
                tensor as input, a single input tensor should be provided.
                If forward_func takes multiple tensors as input, a tuple
                of the input tensors should be provided. It is assumed
                that for all given input tensors, dimension 0 corresponds
                to the number of examples, and if multiple input tensors
                are provided, the examples must be aligned appropriately.
            target (int, tuple, tensor or list, optional): Output indices for
                which gradients are computed (for classification cases,
                this is usually the target class).
                If the network returns a scalar value per example,
                no target index is necessary.
                For general 2D outputs, targets can be either:

                - a single integer or a tensor containing a single
                  integer, which is applied to all input examples
                - a list of integers or a 1D tensor, with length matching
                  the number of examples in inputs (dim 0). Each integer
                  is applied as the target for the corresponding example.

                For outputs with > 2 dimensions, targets can be either:

                - A single tuple, which contains #output_dims - 1
                  elements. This target index is applied to all examples.
                - A list of tuples with length equal to the number of
                  examples in inputs (dim 0), and each tuple containing
                  #output_dims - 1 elements. Each tuple is applied as the
                  target for the corresponding example.

                Default: None
            additional_forward_args (tuple, optional): If the forward function
                requires additional arguments other than the inputs for
                which attributions should not be computed, this argument
                can be provided. It must be either a single additional
                argument of a Tensor or arbitrary (non-tuple) type or a tuple
                containing multiple additional arguments including tensors
                or any arbitrary python types. These arguments are provided to
                forward_func in order, following the arguments in inputs.
                Note that attributions are not computed with respect
                to these arguments.
                Default: None
            return_convergence_delta (bool, optional): Indicates whether to return
                convergence delta or not. If `return_convergence_delta`
                is set to True convergence delta will be returned in
                a tuple following attributions.
                Default: False
            verbose (bool, optional): Indicates whether information on application
                of rules is printed during propagation.

        Returns:
            *tensor* or tuple of *tensors* of **attributions**
            or 2-element tuple of **attributions**, **delta**::

            - **attributions** (*tensor* or tuple of *tensors*):
                The propagated relevance values with respect to each
                input feature. The relevances obtained from the backward
                pass are multiplied by the model output for the selected
                target, so for each example their sum is expected to match
                that output score; any deviation is reported as the
                convergence delta. Attributions will always be the same
                size as the provided inputs, with each value providing the
                attribution of the corresponding input index.
                If a single tensor is provided as inputs, a single tensor is
                returned. If a tuple is provided for inputs, a tuple of
                corresponding sized tensors is returned.
            - **delta** (*tensor*, returned if return_convergence_delta=True):
                Delta is calculated per example, meaning that the number of
                elements in returned delta tensor is equal to the number of
                examples in the inputs.

        Examples::

            >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
            >>> # and returns an Nx10 tensor of class probabilities. It has one
            >>> # Conv2D and a ReLU layer.
            >>> net = ImageClassifier()
            >>> lrp = LRP(net)
            >>> input = torch.randn(3, 3, 32, 32)
            >>> # Attribution size matches input size: 3x3x32x32
            >>> attribution = lrp.attribute(input, target=5)
        """
        self.verbose = verbose
        # NOTE(review): state_dict() returns references that share storage with
        # the live parameters; restoring from it relies on rules replacing
        # weight tensors rather than mutating them in place -- confirm in
        # lrp_rules.
        self._original_state_dict = self.model.state_dict()
        self.layers: List[Module] = []
        self._get_layers(self.model)
        self._check_and_attach_rules()
        self.backward_handles: List[RemovableHandle] = []
        self.forward_handles: List[RemovableHandle] = []

        is_inputs_tuple = _is_tuple(inputs)
        inputs = _format_tensor_into_tuples(inputs)
        gradient_mask = apply_gradient_requirements(inputs)

        try:
            # 1. Forward pass: Change weights of layers according to selected rules.
            output = self._compute_output_and_change_weights(
                inputs, target, additional_forward_args
            )
            # 2. Forward pass + backward pass: Register hooks to configure relevance
            # propagation and execute back-propagation.
            self._register_forward_hooks()
            normalized_relevances = self.gradient_func(
                self._forward_fn_wrapper, inputs, target, additional_forward_args
            )
            # 3. Scale by the output: reshape it to (N, 1, ..., 1) so that it
            # broadcasts over every non-batch dimension of each input.
            relevances = tuple(
                normalized_relevance
                * output.reshape((-1,) + (1,) * (normalized_relevance.dim() - 1))
                for normalized_relevance in normalized_relevances
            )
        finally:
            # Always undo weight changes, hooks, rules and gradient flags.
            self._restore_model()
            undo_gradient_requirements(inputs, gradient_mask)

        if return_convergence_delta:
            return (
                _format_output(is_inputs_tuple, relevances),
                self.compute_convergence_delta(relevances, output),
            )
        return _format_output(is_inputs_tuple, relevances)  # type: ignore

    def has_convergence_delta(self) -> bool:
        return True

    def compute_convergence_delta(
        self, attributions: Union[Tensor, Tuple[Tensor, ...]], output: Tensor
    ) -> Tensor:
        """
        Here, we use the completeness property of LRP: The relevance is conserved
        during the propagation through the models' layers. Therefore, the difference
        between the sum of attribution (relevance) values and model output is taken as
        the convergence delta. It should be zero for functional attribution. However,
        when rules with an epsilon value are used for stability reasons, relevance is
        absorbed during propagation and the convergence delta is non-zero.

        Args:
            attributions (tensor or tuple of tensors): Attribution scores that
                are precomputed by an attribution algorithm.
                Attributions can be provided in form of a single tensor
                or a tuple of those. It is assumed that attribution
                tensor's dimension 0 corresponds to the number of
                examples, and if multiple input tensors are provided,
                the examples must be aligned appropriately.
            output (tensor with single element): The output value with respect to
                which the attribution values are computed. This value corresponds
                to the target score of a classification model.

        Returns:
            *tensor*:
            - **delta** Difference of relevance in output layer and input layer.
        """
        if isinstance(attributions, tuple):
            # Per-example row sums, accumulated across all input tensors.
            summed_attr = cast(
                Tensor, sum(_sum_rows(attr) for attr in attributions)
            )
        else:
            summed_attr = _sum_rows(attributions)
        return output.flatten() - summed_attr.flatten()

    def _get_layers(self, model: Module) -> None:
        """Recursively append every leaf submodule (no children) to self.layers."""
        for layer in model.children():
            if len(list(layer.children())) == 0:
                self.layers.append(layer)
            else:
                self._get_layers(layer)

    def _check_and_attach_rules(self) -> None:
        """
        Prepare each collected leaf layer for this propagation run.

        Layers that already carry a ``rule`` keep it. Layers whose type is in
        ``SUPPORTED_LAYERS_WITH_RULES`` get a fresh instance of the default rule.
        Types in ``SUPPORTED_NON_LINEAR_LAYERS`` get ``rule = None``. Every
        layer with a rule object also gets empty ``activations`` and relevance
        stores.

        Raises:
            TypeError: If a layer has no rule and its type has no default rule.
        """
        for layer in self.layers:
            if hasattr(layer, "rule"):
                layer.activations = {}  # type: ignore
                layer.rule.relevance_input = defaultdict(list)  # type: ignore
                layer.rule.relevance_output = {}  # type: ignore
            elif type(layer) in SUPPORTED_LAYERS_WITH_RULES:
                layer.activations = {}  # type: ignore
                layer.rule = SUPPORTED_LAYERS_WITH_RULES[type(layer)]()  # type: ignore
                layer.rule.relevance_input = defaultdict(list)  # type: ignore
                layer.rule.relevance_output = {}  # type: ignore
            elif type(layer) in SUPPORTED_NON_LINEAR_LAYERS:
                layer.rule = None  # type: ignore
            else:
                raise TypeError(
                    f"Module of type {type(layer)} has no rule defined and no "
                    "default rule exists for this module type. Please, set a rule "
                    "explicitly for this module and assure that it is appropriate "
                    "for this type of layer."
                )

    def _check_rules(self) -> None:
        """
        Reject any module whose ``rule`` is neither None nor a PropagationRule.

        Raises:
            TypeError: On the first module with an invalid rule.
        """
        for module in self.model.modules():
            if hasattr(module, "rule"):
                if (
                    not isinstance(module.rule, PropagationRule)
                    and module.rule is not None
                ):
                    raise TypeError(
                        "Please select propagation rules inherited from class "
                        f"PropagationRule for module: {module}"
                    )

    def _register_forward_hooks(self) -> None:
        """
        Install the hooks used by the relevance-propagation pass.

        Non-linear layers get a backward hook
        (PropagationRule.backward_hook_activation). All other layers get their
        rule's forward_hook.
        """
        for layer in self.layers:
            if type(layer) in SUPPORTED_NON_LINEAR_LAYERS:
                backward_handle = _register_backward_hook(
                    layer, PropagationRule.backward_hook_activation, self
                )
                self.backward_handles.append(backward_handle)
            else:
                forward_handle = layer.register_forward_hook(
                    layer.rule.forward_hook  # type: ignore
                )
                self.forward_handles.append(forward_handle)
                if self.verbose:
                    print(f"Applied {layer.rule} on layer {layer}")

    def _register_weight_hooks(self) -> None:
        """Attach each rule's forward_hook_weights for the weight-changing pass."""
        for layer in self.layers:
            if layer.rule is not None:
                forward_handle = layer.register_forward_hook(
                    layer.rule.forward_hook_weights  # type: ignore
                )
                self.forward_handles.append(forward_handle)

    def _register_pre_hooks(self) -> None:
        """Attach each rule's forward_pre_hook_activations to its layer."""
        for layer in self.layers:
            if layer.rule is not None:
                forward_handle = layer.register_forward_pre_hook(
                    layer.rule.forward_pre_hook_activations  # type: ignore
                )
                self.forward_handles.append(forward_handle)

    def _compute_output_and_change_weights(
        self,
        inputs: Tuple[Tensor, ...],
        target: TargetType,
        additional_forward_args: Any,
    ) -> Tensor:
        """
        Run the first forward pass with weight hooks and return its output.

        The weight hooks are removed even if the forward pass raises. On
        success, the activation pre-hooks for the second pass are registered.
        """
        try:
            self._register_weight_hooks()
            output = _run_forward(self.model, inputs, target, additional_forward_args)
        finally:
            self._remove_forward_hooks()
        # Register pre_hooks that pass the initial activations from before weight
        # adjustments as inputs to the layers with adjusted weights. This procedure
        # is important for graph generation in the 2nd forward pass.
        self._register_pre_hooks()
        return output

    def _remove_forward_hooks(self) -> None:
        """Remove all registered forward (and forward-pre) hooks."""
        for forward_handle in self.forward_handles:
            forward_handle.remove()
        # Drop the spent handles so later removals don't revisit them.
        self.forward_handles.clear()

    def _remove_backward_hooks(self) -> None:
        """Remove backward hooks plus any tensor hooks the rules registered."""
        for backward_handle in self.backward_handles:
            backward_handle.remove()
        self.backward_handles.clear()
        for layer in self.layers:
            if hasattr(layer.rule, "_handle_input_hooks"):
                for handle in layer.rule._handle_input_hooks:  # type: ignore
                    handle.remove()
            if hasattr(layer.rule, "_handle_output_hook"):
                layer.rule._handle_output_hook.remove()  # type: ignore

    def _remove_rules(self) -> None:
        """Delete the ``rule`` attribute from every collected layer."""
        # NOTE(review): this also deletes rules the user attached before calling
        # attribute (they are kept by _check_and_attach_rules), so a second call
        # would fall back to default rules -- confirm this is intended.
        for layer in self.layers:
            if hasattr(layer, "rule"):
                del layer.rule

    def _clear_properties(self) -> None:
        """Delete the per-run ``activations`` store from every collected layer."""
        for layer in self.layers:
            # Must match the attribute name assigned in _check_and_attach_rules.
            if hasattr(layer, "activations"):
                del layer.activations

    def _restore_state(self) -> None:
        """Reload the parameters captured at the start of attribute()."""
        self.model.load_state_dict(self._original_state_dict)  # type: ignore

    def _restore_model(self) -> None:
        """Undo every modification attribute() made to the model."""
        self._restore_state()
        self._remove_backward_hooks()
        self._remove_forward_hooks()
        self._remove_rules()
        self._clear_properties()

    def _forward_fn_wrapper(self, *inputs: Tensor) -> Tensor:
        """
        Wraps a forward function with addition of zero as a workaround to
        https://github.com/pytorch/pytorch/issues/35802 discussed in
        https://github.com/pytorch/captum/issues/143#issuecomment-611750044

        #TODO: Remove when bugs are fixed
        """
        adjusted_inputs = tuple(
            inp + 0 if inp is not None else inp for inp in inputs
        )
        return self.model(*adjusted_inputs)
# Leaf module types that receive a default propagation rule when none was
# attached by the user. Every entry currently maps to EpsilonRule, which
# LRP instantiates per layer.
SUPPORTED_LAYERS_WITH_RULES = dict.fromkeys(
    (
        nn.MaxPool1d,
        nn.MaxPool2d,
        nn.MaxPool3d,
        nn.Conv2d,
        nn.AvgPool2d,
        nn.AdaptiveAvgPool2d,
        nn.Linear,
        nn.BatchNorm2d,
        Addition_Module,
    ),
    EpsilonRule,
)

# Non-linear leaf module types that carry no rule (rule = None); LRP attaches
# PropagationRule.backward_hook_activation to them as a backward hook instead.
SUPPORTED_NON_LINEAR_LAYERS = [nn.ReLU, nn.Dropout, nn.Tanh]