Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| from collections import defaultdict | |
| from typing import Any, Dict, List, Optional, Union | |
| from captum._utils.common import _format_tensor_into_tuples | |
| from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric | |
| from captum.attr._utils.stat import Stat | |
| from captum.attr._utils.summarizer import Summarizer | |
| from captum.log import log_usage | |
| from torch import Tensor | |
class ClassSummarizer(Summarizer):
    r"""
    Used to keep track of summaries for associated classes. The
    classes/labels can be of any type that are supported by `dict`
    (i.e. hashable values). Tensor labels are converted to plain Python
    values before being used as class keys.

    This also keeps track of an aggregate of all class summaries.
    """

    @log_usage()
    def __init__(self, stats: List[Stat]) -> None:
        r"""
        Args:
            stats (List[Stat]): The statistics to track. The same list is
                        used for the aggregate summary and for every
                        per-class summary.
        """
        # Calls the undecorated original of the base initializer
        # (``__wrapped__``), presumably so its usage-logging wrapper does not
        # run a second time on top of this initializer's own ``@log_usage()``.
        Summarizer.__init__.__wrapped__(self, stats)
        # One Summarizer per class label, created lazily on first use.
        self.summaries: Dict[Any, Summarizer] = defaultdict(
            lambda: Summarizer(stats=stats)
        )

    @staticmethod
    def _to_hashable(value: Any) -> Any:
        r"""Recursively turn nested lists into tuples so they can be dict keys."""
        if isinstance(value, list):
            return tuple(ClassSummarizer._to_hashable(v) for v in value)
        return value

    @staticmethod
    def _labels_as_list(labels: Any) -> List[Any]:
        r"""
        Normalise `labels` into a list of class keys: either a single key that
        applies to every example, or one key per example.
        """
        if isinstance(labels, Tensor):
            # ``Tensor.__hash__`` is identity-based, so indexing a label tensor
            # would produce a distinct dict key for every example even when the
            # values are equal. Convert to plain Python values instead.
            if labels.dim() == 0:
                return [labels.item()]
            return [ClassSummarizer._to_hashable(v) for v in labels.tolist()]
        if isinstance(labels, list):
            return labels
        # Any other value (int, tuple, str, ...) labels every example in `x`.
        return [labels]

    def update(  # type: ignore
        self,
        x: TensorOrTupleOfTensorsGeneric,
        labels: TargetType = None,
    ):
        r"""
        Updates the stats of the summarizer, optionally associated to classes.

        This accepts either a single tensor to summarise or a tuple of tensors.

        Args:
            x (Tensor or Tuple[Tensor, ...]):
                        The input tensor to be summarised. The first
                        dimension of this input must be associated to
                        the batch size of the inputs.
            labels (int, tuple, Tensor or list, optional):
                        The associated labels for `x`. A list or a
                        non-scalar tensor must contain either exactly one
                        label (applied to every example) or one label per
                        example in the batch. Any other value (e.g. an int,
                        a tuple, or a 0-d tensor) is used as the label for
                        all inputs in `x`. Tensor labels are converted to
                        Python values via ``Tensor.tolist``/``Tensor.item``.
                        If this is None we simply aggregate the total summary.
        """
        if labels is None:
            super().update(x)
            return

        x = _format_tensor_into_tuples(x)
        labels_typed = self._labels_as_list(labels)
        num_labels = len(labels_typed)

        # Unless a single label is broadcast to the whole batch, every input
        # must have exactly one example per label. This also catches an empty
        # label list paired with a non-empty batch.
        if num_labels != 1:
            for x_i in x:
                assert x_i.size(0) == num_labels, (
                    "batch size does not equal amount of labels; "
                    "please ensure length of labels is equal to 1 "
                    "or to the `batch_size` corresponding to the "
                    "number of examples in the input(s)"
                )

        batch_size = x[0].size(0)
        for i in range(batch_size):
            tensors_to_summarize = tuple(tensor[i] for tensor in x)
            # The aggregate summary receives its own clones, presumably so that
            # the per-class and aggregate stats never share tensor storage.
            tensors_to_summarize_copy = tuple(tensor[i].clone() for tensor in x)
            label = labels_typed[0] if num_labels == 1 else labels_typed[i]

            self.summaries[label].update(tensors_to_summarize)
            super().update(tensors_to_summarize_copy)

    def class_summaries(
        self,
    ) -> Dict[
        Any, Union[None, Dict[str, Optional[Tensor]], List[Dict[str, Optional[Tensor]]]]
    ]:
        r"""
        Returns:
             The summaries for each class, keyed by class label. Each value
             is that class's ``Summarizer.summary``.
        """
        return {key: value.summary for key, value in self.summaries.items()}