Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| import typing | |
| import warnings | |
| from typing import Any, Callable, Iterator, Tuple, Union | |
| import torch | |
| from captum._utils.common import ( | |
| _format_additional_forward_args, | |
| _format_output, | |
| _format_tensor_into_tuples, | |
| _reduce_list, | |
| ) | |
| from captum._utils.typing import ( | |
| TargetType, | |
| TensorOrTupleOfTensorsGeneric, | |
| TupleOrTensorOrBoolGeneric, | |
| ) | |
| from captum.attr._utils.approximation_methods import approximation_parameters | |
| from torch import Tensor | |
def _batch_attribution(
    attr_method,
    num_examples: int,
    internal_batch_size: int,
    n_steps: int,
    include_endpoint: bool = False,
    **kwargs,
):
    """
    This method applies internal batching to given attribution method, dividing
    the total steps into batches and running each independently and sequentially,
    adding each result to compute the total attribution.

    Step sizes and alphas are spliced for each batch and passed explicitly for each
    call to _attribute.

    kwargs include all argument necessary to pass to each attribute call, except
    for n_steps, which is computed based on the number of steps for the batch.

    include_endpoint ensures that one step overlaps between each batch, which
    is necessary for some methods, particularly LayerConductance.

    Args:
        attr_method: attribution method object exposing a private ``_attribute``
            that accepts ``n_steps`` and ``step_sizes_and_alphas``.
        num_examples: number of examples in the input batch (first dimension).
        internal_batch_size: maximum number of (example x step) evaluations
            per forward pass.
        n_steps: total number of approximation steps.
        include_endpoint: if True, consecutive batches share one overlapping
            step (needed by finite-difference methods such as LayerConductance).
        kwargs: forwarded verbatim to each ``_attribute`` call; must include
            ``method`` (the approximation method name).

    Returns:
        The accumulated attribution: a Tensor or tuple of Tensors, matching
        whatever ``attr_method._attribute`` returns.
    """
    if internal_batch_size < num_examples:
        warnings.warn(
            "Internal batch size cannot be less than the number of input examples. "
            "Defaulting to internal batch size of %d equal to the number of examples."
            % num_examples
        )
    # Number of steps for each batch
    # (max(1, ...) enforces the fallback warned about above).
    step_count = max(1, internal_batch_size // num_examples)
    if include_endpoint:
        if step_count < 2:
            step_count = 2
            warnings.warn(
                "This method computes finite differences between evaluations at "
                "consecutive steps, so internal batch size must be at least twice "
                "the number of examples. Defaulting to internal batch size of %d"
                " equal to twice the number of examples." % (2 * num_examples)
            )

    total_attr = None
    cumulative_steps = 0
    step_sizes_func, alphas_func = approximation_parameters(kwargs["method"])
    full_step_sizes = step_sizes_func(n_steps)
    full_alphas = alphas_func(n_steps)

    while cumulative_steps < n_steps:
        start_step = cumulative_steps
        end_step = min(start_step + step_count, n_steps)
        batch_steps = end_step - start_step

        # With an overlapping endpoint, the batch evaluates end_step - start_step
        # points but only (that - 1) finite-difference intervals.
        if include_endpoint:
            batch_steps -= 1

        step_sizes = full_step_sizes[start_step:end_step]
        alphas = full_alphas[start_step:end_step]
        current_attr = attr_method._attribute(
            **kwargs, n_steps=batch_steps, step_sizes_and_alphas=(step_sizes, alphas)
        )

        if total_attr is None:
            total_attr = current_attr
        else:
            # Accumulate; detach() drops the graph of per-batch results so the
            # running total does not retain every batch's autograd history.
            if isinstance(total_attr, Tensor):
                total_attr = total_attr + current_attr.detach()
            else:
                total_attr = tuple(
                    current.detach() + prev_total
                    for current, prev_total in zip(current_attr, total_attr)
                )
        # Rewind by one so the next batch re-evaluates this batch's endpoint,
        # giving the one-step overlap finite-difference methods require.
        if include_endpoint and end_step < n_steps:
            cumulative_steps = end_step - 1
        else:
            cumulative_steps = end_step
    return total_attr
| def _tuple_splice_range(inputs: None, start: int, end: int) -> None: | |
| ... | |
| def _tuple_splice_range(inputs: Tuple, start: int, end: int) -> Tuple: | |
| ... | |
| def _tuple_splice_range( | |
| inputs: Union[None, Tuple], start: int, end: int | |
| ) -> Union[None, Tuple]: | |
| """ | |
| Splices each tensor element of given tuple (inputs) from range start | |
| (inclusive) to end (non-inclusive) on its first dimension. If element | |
| is not a Tensor, it is left unchanged. It is assumed that all tensor elements | |
| have the same first dimension (corresponding to number of examples). | |
| The returned value is a tuple with the same length as inputs, with Tensors | |
| spliced appropriately. | |
| """ | |
| assert start < end, "Start point must precede end point for batch splicing." | |
| if inputs is None: | |
| return None | |
| return tuple( | |
| inp[start:end] if isinstance(inp, torch.Tensor) else inp for inp in inputs | |
| ) | |
def _batched_generator(
    inputs: TensorOrTupleOfTensorsGeneric,
    additional_forward_args: Any = None,
    target_ind: TargetType = None,
    internal_batch_size: Union[None, int] = None,
) -> Iterator[Tuple[Tuple[Tensor, ...], Any, TargetType]]:
    """
    Returns a generator which returns corresponding chunks of size internal_batch_size
    for both inputs and additional_forward_args. If batch size is None,
    generator only includes original inputs and additional args.
    """
    assert internal_batch_size is None or (
        isinstance(internal_batch_size, int) and internal_batch_size > 0
    ), "Batch size must be greater than 0."
    inputs = _format_tensor_into_tuples(inputs)
    additional_forward_args = _format_additional_forward_args(additional_forward_args)
    num_examples = inputs[0].shape[0]
    # TODO Reconsider this check if _batched_generator is used for non gradient-based
    # attribution algorithms
    if not (inputs[0] * 1).requires_grad:
        warnings.warn(
            """It looks like that the attribution for a gradient-based method is
            computed in a `torch.no_grad` block or perhaps the inputs have no
            requires_grad."""
        )

    if internal_batch_size is None:
        yield inputs, additional_forward_args, target_ind
        return

    def _slice_target(start: int, end: int) -> TargetType:
        # Per-example targets (a list, or a Tensor with more than one element)
        # are sliced alongside the inputs; a scalar/global target is reused
        # unchanged for every chunk.
        if isinstance(target_ind, list) or (
            isinstance(target_ind, torch.Tensor) and target_ind.numel() > 1
        ):
            return target_ind[start:end]
        return target_ind

    for start in range(0, num_examples, internal_batch_size):
        end = start + internal_batch_size
        with torch.autograd.set_grad_enabled(True):
            batch_inputs = _tuple_splice_range(inputs, start, end)
        yield batch_inputs, _tuple_splice_range(
            additional_forward_args, start, end
        ), _slice_target(start, end)
def _batched_operator(
    operator: Callable[..., TupleOrTensorOrBoolGeneric],
    inputs: TensorOrTupleOfTensorsGeneric,
    additional_forward_args: Any = None,
    target_ind: TargetType = None,
    internal_batch_size: Union[None, int] = None,
    **kwargs: Any,
) -> TupleOrTensorOrBoolGeneric:
    """
    Batches the operation of the given operator, applying the given batch size
    to inputs and additional forward arguments, and returning the concatenation
    of the results of each batch.
    """
    chunk_results = []
    batches = _batched_generator(
        inputs, additional_forward_args, target_ind, internal_batch_size
    )
    for batch_inputs, batch_args, batch_target in batches:
        chunk_results.append(
            operator(
                inputs=batch_inputs,
                additional_forward_args=batch_args,
                target_ind=batch_target,
                **kwargs,
            )
        )
    return _reduce_list(chunk_results)
def _select_example(curr_arg: Any, index: int, bsz: int) -> Any:
    """
    Select the single example at ``index`` (kept as a length-1 slice) from
    each batched element of ``curr_arg``.

    Only Tensor or list elements whose length equals ``bsz`` are sliced;
    anything else is assumed to be example-independent and passed through.
    Returns in the same "shape" as the input: a tuple stays a tuple, a bare
    value stays bare, and None stays None.
    """
    if curr_arg is None:
        return None
    was_tuple = isinstance(curr_arg, tuple)
    elements = curr_arg if was_tuple else (curr_arg,)
    picked = tuple(
        # index : index + 1 keeps the batch dimension (size 1) intact.
        elem[index : index + 1]
        if isinstance(elem, (Tensor, list)) and len(elem) == bsz
        else elem
        for elem in elements
    )
    return _format_output(was_tuple, picked)
def _batch_example_iterator(bsz: int, *args) -> Iterator:
    """
    Batches the provided argument.

    Yields one tuple per example index in ``range(bsz)``, where each element
    is the corresponding argument narrowed to that single example via
    ``_select_example``.
    """
    for example_idx in range(bsz):
        yield tuple(_select_example(arg, example_idx, bsz) for arg in args)