Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.utils.data as data | |
class TopKClassificationWrapper(data.Dataset):
    """Wrap a dataset so each item also carries a caller-supplied attack label and its index.

    Each item of the wrapped dataset must unpack into exactly two values
    ``(image, label)``. ``__getitem__`` returns ``(image, label, attack_label, index)``.

    Args:
        dataset: Source dataset; must support ``len()`` and integer indexing.
            It does not need a ``.classes`` attribute.
        attack_labels: Indexable container looked up with the same index as
            ``dataset``. Presumably holds one entry per dataset item; the length
            is not checked here (TODO: confirm against callers).
        seed: Seed for ``self.generator``, a CPU ``torch.Generator``. Nothing in
            this class draws from the generator; it is kept as a public attribute.
        k: Accepted for interface compatibility; not used by this class.
    """

    def __init__(self, dataset: data.Dataset, attack_labels, seed=0, k=1) -> None:
        super().__init__()
        # Seeded CPU generator exposed as an attribute; unused within this class.
        self.generator = torch.Generator("cpu")
        self.generator.manual_seed(seed)
        # Attack labels come from the caller, so no class count is needed and the
        # source dataset is not required to expose a ``.classes`` attribute.
        self.src_dataset = dataset
        self.attack_labels = attack_labels

    def __getitem__(self, index):
        """Return ``(image, label, attack_label, index)`` for position ``index``."""
        image, label = self.src_dataset[index]
        return image, label, self.attack_labels[index], index

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.src_dataset)