| import re |
| import sys |
| from torch.utils.data import Dataset |
| from torchvision.datasets import CIFAR10 |
| import torchvision.transforms as transforms |
|
|
|
|
class BinaryClassifierDataset(Dataset):
    """CIFAR-10 wrapper that relabels every sample with a binary target.

    A sample's target is 1 when its original CIFAR-10 label is contained in
    ``optimize_class`` and 0 otherwise. Images are returned as produced by
    the transform pipeline built in ``__init__``.
    """

    def __init__(self, root, train, optimize_class: list):
        """Create the wrapped CIFAR-10 dataset.

        Args:
            root: Forwarded to ``CIFAR10`` as ``root``.
            train: Forwarded to ``CIFAR10`` as ``train``.
            optimize_class: Original labels that are mapped to target 1.
        """
        self.optimize_class = optimize_class

        # NOTE(review): the random flip and AutoAugment steps are applied
        # regardless of ``train`` -- confirm this is intended for evaluation.
        pipeline = transforms.Compose([
            transforms.Resize(224),
            transforms.RandomHorizontalFlip(),
            transforms.AutoAugment(policy=transforms.AutoAugmentPolicy("cifar10")),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        self.dataset = CIFAR10(
            root=root,
            train=train,
            download=True,
            transform=pipeline,
        )

    def __getitem__(self, index):
        """Return ``(image, target)`` where ``target`` is 1 or 0."""
        image, label = self.dataset[index]
        # Membership in the selected label set decides the binary target.
        return image, int(label in self.optimize_class)

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)
|
|
|
|
def get_optimize_class():
    """Decode the selected class set from the first command-line argument.

    ``sys.argv[1]`` must contain the text ``class<digits>``. The digits are
    read as a decimal integer whose binary representation, left-padded with
    zeros to 10 characters, acts as a mask: position ``i`` (counted from the
    left, i.e. from the most significant of the 10 bits) is selected when
    that character is ``"1"``.

    Returns:
        A tuple ``(optimize_class, class_int_string)`` where
        ``optimize_class`` is the list of selected positions in ascending
        order and ``class_int_string`` is the matched digits left-padded
        with zeros to at least 4 characters.

    Raises:
        RuntimeError: If ``sys.argv[1]`` is missing.
        ValueError: If ``sys.argv[1]`` does not contain ``class<digits>``.
    """
    try:
        string = sys.argv[1]
    except IndexError:
        # The exception was previously constructed but never raised, which
        # let execution fall through to an unbound local variable.
        raise RuntimeError("sys.argv[1] not found") from None

    match = re.search(r'class(\d+)', string)
    if match is None:
        raise ValueError(
            f"sys.argv[1] does not contain 'class<digits>': {string!r}"
        )

    class_int_string = match.group(1).zfill(4)
    # Binary form of the number, padded to 10 characters.
    bit_string = bin(int(class_int_string))[2:].zfill(10)
    optimize_class = [
        position for position, bit in enumerate(bit_string) if bit == "1"
    ]
    return optimize_class, class_int_string