Spaces:
Runtime error
Runtime error
| import numpy as np | |
| from aif360.datasets import BinaryLabelDataset | |
| from aif360.algorithms import Transformer | |
class ARTClassifier(Transformer):
    """Adapter that exposes an :obj:`art.classifiers.Classifier` through the
    :obj:`~aif360.algorithms.Transformer` interface.

    The wrapped classifier is stored on the instance and every call to
    :meth:`fit` / :meth:`predict` is delegated to it.
    """

    def __init__(self, art_classifier):
        """Store the ART classifier that this transformer delegates to.

        Args:
            art_classifier (art.classifier.Classifier): A Classifier
                object from the `adversarial-robustness-toolbox`_.

        .. _adversarial-robustness-toolbox:
           https://github.com/Trusted-AI/adversarial-robustness-toolbox
        """
        # The base class receives the classifier as a keyword argument; a
        # private reference is kept as well for use by fit/predict.
        super(ARTClassifier, self).__init__(art_classifier=art_classifier)
        self._art_classifier = art_classifier

    def fit(self, dataset, batch_size=128, nb_epochs=20):
        """Train the wrapped classifier on the given dataset.

        Args:
            dataset (Dataset): Training dataset; its ``features`` and
                ``labels`` attributes are handed to the ART classifier.
            batch_size (int): Size of batches (passed through to ART).
            nb_epochs (int): Number of epochs to use for training (passed
                through to ART).

        Returns:
            ARTClassifier: Returns self.
        """
        classifier = self._art_classifier
        classifier.fit(
            dataset.features,
            dataset.labels,
            batch_size=batch_size,
            nb_epochs=nb_epochs,
        )
        return self

    def predict(self, dataset, logits=False):
        """Run the wrapped classifier on the given dataset.

        Args:
            dataset (Dataset): Test dataset.
            logits (bool, optional): True if prediction should be done at the
                logits layer (passed through to ART).

        Returns:
            Dataset: A copy of `dataset` whose `labels` field holds the
            predictions. For a :obj:`BinaryLabelDataset` the classifier
            output is reduced with ``argmax`` along axis 1 and reshaped into
            a single column; for any other dataset type the classifier
            output is stored unchanged.
        """
        classifier = self._art_classifier
        # NOTE(review): ``dataset.labels`` is forwarded as the second
        # positional argument of the wrapped ``predict`` -- confirm that the
        # ART version in use actually expects it there.
        output = classifier.predict(
            dataset.features, dataset.labels, logits=logits
        )

        if isinstance(dataset, BinaryLabelDataset):
            # Pick the highest-scoring column per row, as a column vector.
            winning_index = np.argmax(output, axis=1)
            output = winning_index.reshape((-1, 1))

        result = dataset.copy()
        result.labels = output
        return result