Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| from abc import ABC, abstractmethod | |
| from typing import Any | |
| from torch.nn import Module | |
| from torch.utils.data import Dataset | |
class DataInfluence(ABC):
    r"""
    An abstract class to define model data influence skeleton.

    Concrete subclasses must implement :meth:`influence`. The base
    constructor only stores the model and the influence source dataset on
    the instance.
    """

    def __init__(
        self, model: Module, influence_src_dataset: Dataset, **kwargs: Any
    ) -> None:
        r"""
        Args:
            model (torch.nn.Module): An instance of pytorch model.
            influence_src_dataset (torch.utils.data.Dataset): PyTorch Dataset that is
                        used to create a PyTorch Dataloader to iterate over the dataset and
                        its labels. This is the dataset in which we search for
                        influential instances. In most cases this is the training dataset.
            **kwargs: Additional key-value arguments that are necessary for specific
                        implementation of `DataInfluence` abstract class. The base
                        constructor accepts them but does not store or use them.
        """
        self.model = model
        self.influence_src_dataset = influence_src_dataset

    # This constructor used to be misspelled ``__init_``. Python never calls
    # that name automatically, and name-mangles it to
    # ``_DataInfluence__init_``. Keep the old (mangled) attribute as an alias
    # so any code that reached it explicitly keeps working.
    __init_ = __init__

    @abstractmethod
    def influence(self, inputs: Any = None, **kwargs: Any) -> Any:
        r"""
        Args:
            inputs (Any): Batch of examples for which influential
                        instances are computed. They are passed to the forward_func. If
                        `inputs` is a tensor or tuple of tensors, the first dimension
                        of a tensor corresponds to the batch dimension.
            **kwargs: Additional key-value arguments that are necessary for specific
                        implementation of `DataInfluence` abstract class.

        Returns:
            influences (Any): We do not add restrictions on the return type for now,
                        though this may change in the future.
        """
        pass