| |
| |
| import os |
| import json |
|
|
| from torchvision import datasets, transforms |
| from torchvision.datasets.folder import ImageFolder, default_loader |
|
|
| from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
| from timm.data import create_transform |
|
|
|
|
class INatDataset(ImageFolder):
    """iNaturalist-style dataset driven by JSON annotation files under ``root``.

    The constructor reads three files from ``root``:

    * ``train{year}.json`` or ``val{year}.json`` (chosen by ``train``) for the
      list of images to expose,
    * ``categories.json`` for per-category metadata,
    * ``train{year}.json`` (always) to build the label -> class-index mapping,
      so the train and val splits share the same indices.

    Note that ``ImageFolder.__init__`` is intentionally not invoked here; only
    the attributes assigned below are set by this constructor.

    Args:
        root: Directory holding the JSON files and the image folders.
        train: Selects the ``train`` split when true, otherwise ``val``.
        year: Year suffix used in the annotation file names.
        transform: Stored as ``self.transform``.
        target_transform: Stored as ``self.target_transform``.
        category: Key of the category record whose value is used as the label.
        loader: Stored as ``self.loader``.

    Attributes set here: ``transform``, ``loader``, ``target_transform``,
    ``year``, ``nb_classes`` (number of distinct labels in the train
    annotations) and ``samples`` (list of ``(path, class_index)`` tuples).
    """

    def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,
                 category='name', loader=default_loader):
        self.transform = transform
        self.loader = loader
        self.target_transform = target_transform
        self.year = year

        # Image listing for the requested split.
        split_json_path = os.path.join(root, f'{"train" if train else "val"}{year}.json')
        with open(split_json_path) as split_file:
            split_data = json.load(split_file)

        # Category metadata, indexed by integer category id.
        with open(os.path.join(root, 'categories.json')) as categories_file:
            categories = json.load(categories_file)

        # The label mapping always comes from the train annotations,
        # regardless of which split is being loaded.
        train_json_path = os.path.join(root, f"train{year}.json")
        with open(train_json_path) as train_file:
            train_data = json.load(train_file)

        # Assign consecutive indices to labels in first-seen order.
        label_to_index = {}
        for annotation in train_data['annotations']:
            label = categories[int(annotation['category_id'])][category]
            label_to_index.setdefault(label, len(label_to_index))
        self.nb_classes = len(label_to_index)

        self.samples = []
        for image in split_data['images']:
            # file_name is split on '/'; component 2 is parsed as the category
            # id, and components 0, 2 and 3 form the on-disk path (component 1
            # is dropped).
            parts = image['file_name'].split('/')
            category_id = int(parts[2])
            image_path = os.path.join(root, parts[0], parts[2], parts[3])

            category_record = categories[category_id]
            class_index = label_to_index[category_record[category]]
            self.samples.append((image_path, class_index))
|
|
| |
|
|
|
|
def build_dataset(is_train, args):
    """Build the dataset selected by ``args.data_set`` and report its class count.

    Supported values of ``args.data_set``:

    * ``'CIFAR'``  -- ``datasets.CIFAR100`` rooted at ``args.data_path`` (100 classes).
    * ``'IMNET'``  -- ``datasets.ImageFolder`` on ``<data_path>/train`` or
      ``<data_path>/val`` (1000 classes).
    * ``'INAT'``   -- ``INatDataset`` with ``year=2018``.
    * ``'INAT19'`` -- ``INatDataset`` with ``year=2019``.

    For the two iNaturalist variants the class count is taken from
    ``dataset.nb_classes`` and ``args.inat_category`` selects the label key.

    Args:
        is_train: Selects the training split/transform when true, otherwise
            the evaluation one.
        args: Namespace-like object providing ``data_set``, ``data_path``,
            the fields read by ``build_transform`` and, for iNaturalist,
            ``inat_category``.

    Returns:
        A ``(dataset, nb_classes)`` tuple.

    Raises:
        ValueError: If ``args.data_set`` is not one of the supported names.
    """
    inat_years = {'INAT': 2018, 'INAT19': 2019}
    supported = ('CIFAR', 'IMNET', *inat_years)

    # Fail fast with a clear message instead of hitting unbound locals at the
    # return statement after the transform has already been built.
    if args.data_set not in supported:
        raise ValueError(
            f"Unsupported data_set {args.data_set!r}; "
            f"expected one of {', '.join(supported)}"
        )

    transform = build_transform(is_train, args)

    if args.data_set == 'CIFAR':
        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
        nb_classes = 100
    elif args.data_set == 'IMNET':
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    else:
        # Remaining supported names are the iNaturalist variants, which differ
        # only in the annotation year.
        dataset = INatDataset(args.data_path, train=is_train, year=inat_years[args.data_set],
                              category=args.inat_category, transform=transform)
        nb_classes = dataset.nb_classes

    return dataset, nb_classes
|
|
|
|
def build_transform(is_train, args):
    """Build the image preprocessing pipeline for training or evaluation.

    Training: the pipeline comes from ``create_transform`` configured from
    ``args`` (``input_size``, ``color_jitter``, ``aa``, ``train_interpolation``,
    ``reprob``, ``remode``, ``recount``). When ``args.input_size`` is 32 or
    less, the first transform of that pipeline is replaced by a
    ``RandomCrop(input_size, padding=4)``.

    Evaluation: when ``args.input_size`` is greater than 32 the image is
    resized to ``int(input_size / eval_crop_ratio)`` and center-cropped to
    ``input_size``; in all cases it is then converted to a tensor and
    normalized with ``IMAGENET_DEFAULT_MEAN`` / ``IMAGENET_DEFAULT_STD``.

    Args:
        is_train: Build the training pipeline when true, else the eval one.
        args: Namespace-like object carrying the fields listed above
            (``eval_crop_ratio`` is only read on the eval path with resizing).

    Returns:
        The composed transform.
    """
    needs_resize = args.input_size > 32

    if not is_train:
        eval_steps = []
        if needs_resize:
            # Resize to a larger edge first so the center crop keeps the
            # configured crop ratio.
            resized_edge = int(args.input_size / args.eval_crop_ratio)
            eval_steps += [
                # 3 is the integer interpolation code (PIL's BICUBIC constant).
                transforms.Resize(resized_edge, interpolation=3),
                transforms.CenterCrop(args.input_size),
            ]
        eval_steps += [
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
        ]
        return transforms.Compose(eval_steps)

    train_pipeline = create_transform(
        input_size=args.input_size,
        is_training=True,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation=args.train_interpolation,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
    )
    if not needs_resize:
        # Small inputs: swap the leading transform for a padded random crop.
        train_pipeline.transforms[0] = transforms.RandomCrop(
            args.input_size, padding=4)
    return train_pipeline
|
|