Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| import warnings | |
| from functools import partial | |
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union | |
| import captum._utils.common as common | |
| import torch | |
| from captum._utils.av import AV | |
| from captum.attr import LayerActivation | |
| from captum.influence._core.influence import DataInfluence | |
| from torch import Tensor | |
| from torch.nn import Module | |
| from torch.utils.data import DataLoader, Dataset | |
| r""" | |
| Additional helper functions to calculate similarity metrics. | |
| """ | |
def euclidean_distance(test, train) -> Tensor:
    r"""
    Compute pairwise Euclidean distances between two batches of feature vectors.

    Tensors ``test`` and ``train`` have shapes (batch_size_1, *) and
    (batch_size_2, *); all feature dimensions are flattened before the
    distance computation.

    Returns:
        Tensor of shape (batch_size_1, batch_size_2) holding the distance
        between every (test, train) example pair.
    """
    # Collapse all feature dimensions so each example is a single vector.
    flat_test = test.view(test.shape[0], -1)
    flat_train = train.view(train.shape[0], -1)
    # torch.cdist accepts 2-D inputs directly and returns the full
    # (batch_size_1, batch_size_2) pairwise distance matrix.
    return torch.cdist(flat_test, flat_train)
def cosine_similarity(test, train, replace_nan=0) -> Tensor:
    r"""
    Compute pairwise cosine similarity between two batches of feature vectors.

    Tensors ``test`` and ``train`` have shapes (batch_size_1, *) and
    (batch_size_2, *); all feature dimensions are flattened before the
    computation.

    Args:
        test: First batch of feature vectors.
        train: Second batch of feature vectors.
        replace_nan: Value substituted for the normalized entries of any
            example whose feature vector has zero norm, where cosine
            similarity is otherwise undefined (division by zero).

    Returns:
        Tensor of shape (batch_size_1, batch_size_2) of pairwise cosine
        similarities.
    """
    test = test.view(test.shape[0], -1)
    train = train.view(train.shape[0], -1)

    # Feature-detect torch.linalg.norm (added in torch 1.7) instead of
    # comparing version strings: the old check `torch.__version__ <= "1.6.0"`
    # is lexicographic, so e.g. "1.10.0" <= "1.6.0" evaluates True and sent
    # newer releases down the legacy branch.
    if hasattr(torch, "linalg") and hasattr(torch.linalg, "norm"):
        test_norm = torch.linalg.norm(test, ord=2, dim=1, keepdim=True)
        train_norm = torch.linalg.norm(train, ord=2, dim=1, keepdim=True)
    else:
        # Legacy path for torch <= 1.6; p=None computes the same 2-norm.
        test_norm = torch.norm(test, p=None, dim=1, keepdim=True)
        train_norm = torch.norm(train, p=None, dim=1, keepdim=True)

    # Build the fill value with new_tensor so it matches the input's device
    # and dtype; the previous `Tensor([replace_nan])` was always a CPU
    # float32 tensor and raised on CUDA inputs or mismatched dtypes.
    test = torch.where(
        test_norm != 0.0, test / test_norm, test.new_tensor(replace_nan)
    )
    train = torch.where(
        train_norm != 0.0, train / train_norm, train.new_tensor(replace_nan)
    ).T
    similarity = torch.mm(test, train)
    return similarity
| r""" | |
| Implements abstract DataInfluence class and provides implementation details for | |
| similarity metric-based influence computation. Similarity metrics can be used to compare | |
| intermediate or final activation vectors of a model for different sets of input. Then, | |
| these can be used to draw conclusions about influential instances. | |
| Some standard similarity metrics such as dot product similarity or euclidean distance | |
| are provided, but the user can provide any custom similarity metric as well. | |
| """ | |
class SimilarityInfluence(DataInfluence):
    r"""
    Similarity metric-based influence computation. Compares the activation
    vectors that `module` produces at the given `layers` for test inputs
    against activations computed over every example in
    `influence_src_dataset`, and returns the top-k most (or least) similar
    training examples per test example. Source-dataset activations are
    computed (and cached on disk) via captum's AV utility.
    """

    def __init__(
        self,
        module: Module,
        layers: Union[str, List[str]],
        influence_src_dataset: Dataset,
        activation_dir: str,
        model_id: str = "",
        similarity_metric: Callable = cosine_similarity,
        similarity_direction: str = "max",
        batch_size: int = 1,
        **kwargs: Any,
    ):
        r"""
        Args:
            module (torch.nn.Module): An instance of pytorch model. This model should
                    define all of its layers as attributes of the model.
            layers (str or List of str): The fully qualified layer(s) for which the
                    activation vectors are computed.
            influence_src_dataset (torch.utils.data.Dataset): PyTorch Dataset that is
                    used to create a PyTorch Dataloader to iterate over the dataset and
                    its labels. This is the dataset for which we will be seeking for
                    influential instances. In most cases this is the training dataset.
            activation_dir (str): The directory of the path to store
                    and retrieve activation computations. Best practice would be to use
                    an absolute path.
            model_id (str): The name/version of the model for which layer
                    activations are being computed. Activations will be stored and
                    loaded under the subdirectory with this name if provided.
            similarity_metric (Callable): This is a callable function that computes a
                    similarity metric between two representations. For example, the
                    representations pair could be from the training and test sets.

                    This function must adhere to certain standards. The inputs should be
                    torch Tensors with shape (batch_size_i/j, feature dimensions). The
                    output Tensor should have shape (batch_size_i, batch_size_j) with
                    scalar values corresponding to the similarity metric used for each
                    pairwise combination from the two batches.

                    For example, suppose we use `batch_size_1 = 16` for iterating
                    through `influence_src_dataset`, and for the `inputs` argument
                    we pass in a Tensor with 3 examples, i.e. batch_size_2 = 3. Also,
                    suppose that our inputs and intermediate activations throughout the
                    model will have dimension (N, C, H, W). Then, the feature dimensions
                    should be flattened within this function. For example::

                        >>> av_test.shape
                        torch.Size([3, N, C, H, W])
                        >>> av_src.shape
                        torch.Size([16, N, C, H, W])
                        >>> av_test = torch.view(av_test.shape[0], -1)
                        >>> av_test.shape
                        torch.Size([3, N x C x H x W])

                    and similarly for av_src. The similarity_metric should then use
                    these flattened tensors to return the pairwise similarity matrix.
                    For example, `similarity_metric(av_test, av_src)` should return a
                    tensor of shape (3, 16).

            batch_size (int): Batch size for iterating through `influence_src_dataset`.
            **kwargs: Additional key-value arguments that are necessary for specific
                    implementation of `DataInfluence` abstract class.
        """
        self.module = module
        # Accept a single layer name or a list; store uniformly as a list so
        # single- and multi-layer cases share the same code paths below.
        self.layers = [layers] if isinstance(layers, str) else layers
        self.influence_src_dataset = influence_src_dataset
        self.activation_dir = activation_dir
        self.model_id = model_id
        self.batch_size = batch_size

        if similarity_direction == "max" or similarity_direction == "min":
            self.similarity_direction = similarity_direction
        else:
            raise ValueError(
                f"{similarity_direction} is not a valid value. "
                "Must be either 'max' or 'min'"
            )

        if similarity_metric is cosine_similarity:
            if "replace_nan" in kwargs:
                self.replace_nan = kwargs["replace_nan"]
            else:
                # Cosine similarity lies in [-1, 1]; a sentinel outside that
                # range guarantees zero-activation rows sort last for the
                # chosen direction and stay detectable in `influence`.
                self.replace_nan = -2 if self.similarity_direction == "max" else 2
            similarity_metric = partial(cosine_similarity, replace_nan=self.replace_nan)

        self.similarity_metric = similarity_metric

        self.influence_src_dataloader = DataLoader(
            influence_src_dataset, batch_size, shuffle=False
        )

    def influence(  # type: ignore[override]
        self,
        inputs: Union[Tensor, Tuple[Tensor, ...]],
        top_k: int = 1,
        additional_forward_args: Optional[Any] = None,
        load_src_from_disk: bool = True,
        **kwargs: Any,
    ) -> Dict:
        r"""
        Args:
            inputs (tensor or tuple of tensors): Batch of examples for which influential
                    instances are computed. They are passed to the forward_func. The
                    first dimension in `inputs` tensor or tuple of tensors corresponds
                    to the batch size. A tuple of tensors is only passed in if this
                    is the input form that `module` accepts.
            top_k (int): The number of top-matching activations to return
            additional_forward_args (optional): Additional arguments that will be
                    passed to forward_func after inputs.
            load_src_from_disk (bool): Loads activations for `influence_src_dataset`
                    where possible. Setting to False would force regeneration of
                    activations.
            load_input_from_disk (bool): Regenerates activations for inputs by default
                    and removes previous `inputs` activations that are flagged with
                    `inputs_id`. Setting to True will load prior matching inputs
                    activations. Note that this could lead to unexpected behavior if
                    `inputs_id` is not configured properly and activations are loaded
                    for a different, prior `inputs`.
            inputs_id (str): Used to identify inputs for loading activations.

            **kwargs: Additional key-value arguments that are necessary for specific
                    implementation of `DataInfluence` abstract class.

        Returns:
            influences (dict): Returns the influential instances retrieved from
                    `influence_src_dataset` for each test example represented through a
                    tensor or a tuple of tensor in `inputs`. Returned influential
                    examples are represented as dict, with keys corresponding to
                    the layer names passed in `layers`. Each value in the dict is a
                    tuple containing the indices and values for the top k similarities
                    from `influence_src_dataset` by the chosen metric. The first value
                    in the tuple corresponds to the indices corresponding to the top k
                    most similar examples, and the second value is the similarity score.
                    The batch dimension corresponds to the batch dimension of `inputs`.
                    If inputs.shape[0] == 5, then dict[`layer_name`][0].shape[0] == 5.
                    These tensors will be of shape (inputs.shape[0], top_k).
        """
        inputs_batch_size = (
            inputs[0].shape[0] if isinstance(inputs, tuple) else inputs.shape[0]
        )

        influences: Dict[str, Any] = {}

        # Compute (or load cached) activations for every source example, one
        # AVDataset per requested layer.
        layer_AVDatasets = AV.generate_dataset_activations(
            self.activation_dir,
            self.module,
            self.model_id,
            self.layers,
            DataLoader(self.influence_src_dataset, self.batch_size, shuffle=False),
            identifier="src",
            load_from_disk=load_src_from_disk,
            return_activations=True,
        )

        # With return_activations=True and a list of layers, AV should hand
        # back a list of AVDatasets (one per layer), never a bare AVDataset.
        assert layer_AVDatasets is not None and not isinstance(
            layer_AVDatasets, AV.AVDataset
        )

        layer_modules = [
            common._get_module_from_name(self.module, layer) for layer in self.layers
        ]
        test_activations = LayerActivation(self.module, layer_modules).attribute(
            inputs, additional_forward_args
        )

        # largest=True for topk when seeking the most similar ("max") examples.
        minmax = self.similarity_direction == "max"

        # av_inputs shape: (inputs_batch_size, *) e.g. (inputs_batch_size, N, C, H, W)
        # av_src shape: (self.batch_size, *) e.g. (self.batch_size, N, C, H, W)
        # LayerActivation returns a single tensor for one layer and a list for
        # several; normalize to a list so per-layer indexing works below.
        test_activations = (
            test_activations if len(self.layers) > 1 else [test_activations]
        )

        for i, (layer, layer_AVDataset) in enumerate(
            zip(self.layers, layer_AVDatasets)
        ):
            # Running top-k candidates accumulated across source batches.
            topk_val, topk_idx = torch.Tensor(), torch.Tensor().long()
            # [inputs_idx, src_dataset_idx] pairs where the metric returned
            # the replace_nan sentinel (zero-norm activations).
            zero_acts = torch.Tensor().long()

            av_inputs = test_activations[i]
            src_loader = DataLoader(layer_AVDataset)
            for j, av_src in enumerate(src_loader):
                # DataLoader adds a leading batch dim of 1 around each stored
                # activation batch; drop it.
                av_src = av_src.squeeze(0)

                similarity = self.similarity_metric(av_inputs, av_src)
                msg = (
                    "Output of custom similarity does not meet required dimensions. "
                    f"Your output has shape {similarity.shape}.\nPlease ensure the "
                    "output shape matches (inputs_batch_size, src_dataset_batch_size), "
                    f"which should be {(inputs_batch_size, self.batch_size)}."
                )
                assert similarity.shape == (inputs_batch_size, av_src.shape[0]), msg
                if hasattr(self, "replace_nan"):
                    # replace_nan is only set when the default cosine
                    # similarity is used; matching entries flag zero-vector
                    # activations for the warning below.
                    idx = (similarity == self.replace_nan).nonzero()
                    zero_acts = torch.cat((zero_acts, idx))

                r"""
                TODO: For models that can have tuples as activations, we should
                allow similarity metrics to accept tuples, support topk selection.
                """

                # Take at most top_k per source batch; the global top_k is
                # recovered by the sort/gather after the loop.
                topk_batch = min(top_k, self.batch_size)
                values, indices = torch.topk(
                    similarity, topk_batch, dim=1, largest=minmax
                )
                # Shift batch-local indices to dataset-global positions.
                indices += int(j * self.batch_size)

                topk_val = torch.cat((topk_val, values), dim=1)
                topk_idx = torch.cat((topk_idx, indices), dim=1)

            # can modify how often to sort for efficiency? minor
            sort_idx = torch.argsort(topk_val, dim=1, descending=minmax)
            topk_val = torch.gather(topk_val, 1, sort_idx[:, :top_k])
            topk_idx = torch.gather(topk_idx, 1, sort_idx[:, :top_k])

            influences[layer] = (topk_idx, topk_val)

            # NOTE(review): numel(zero_acts != 0) equals numel(zero_acts), so
            # this is equivalent to "zero_acts is non-empty"; the paren
            # placement looks accidental but the behavior is correct.
            if torch.numel(zero_acts != 0):
                zero_warning = (
                    f"Layer {layer} has zero-vector activations for some inputs. This "
                    "may cause undefined behavior for cosine similarity. The indices "
                    "for the offending inputs will be included under the key "
                    f"'zero_acts-{layer}' in the output dictionary. Indices are "
                    "returned as a tensor with [inputs_idx, src_dataset_idx] pairs "
                    "which may have corrupted similarity scores."
                )
                warnings.warn(zero_warning, RuntimeWarning)
                key = "-".join(["zero_acts", layer])
                influences[key] = zero_acts

        return influences