| |
| |
| |
| |
| from backbone.ResNet18_id2 import resnet18_id2 |
| import os |
| from typing import Optional |
| import torch.optim |
| import numpy as np |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchvision.transforms as transforms |
| from backbone.ResNet18 import resnet18 |
| from PIL import Image |
| from torch.utils.data import Dataset |
| from datasets.transforms.denormalization import DeNormalize |
| from datasets.utils.continual_dataset import (ContinualDataset, |
| store_masked_loaders) |
| from datasets.utils.validation import get_train_val |
| from utils.conf import base_path_dataset as base_path |
| from torchvision.models import mobilenet_v2 |
| import torch |
|
|
class TinyImagenet(Dataset):
    """
    Tiny ImageNet dataset, exposed with the same interface as the other
    PyTorch datasets.

    Images and labels are read from 20 pre-processed ``.npy`` shards per split,
    stored as ``<root>/processed/x_<split>_NN.npy`` (images) and
    ``<root>/processed/y_<split>_NN.npy`` (labels), where ``<split>`` is
    ``train`` or ``val`` and ``NN`` runs from 01 to 20.
    """

    def __init__(self, root: str, train: bool = True, transform: transforms = None,
                 target_transform: transforms = None, download: bool = False) -> None:
        """
        :param root: directory that holds (or will receive) the dataset files
        :param train: load the ``train`` shards if True, the ``val`` shards otherwise
        :param transform: optional transform applied to each PIL image
        :param target_transform: optional transform applied to each label
        :param download: if True and ``root`` is missing or empty, download the
            pre-processed archive and unzip it into ``root``
        """
        self.not_aug_transform = transforms.Compose([transforms.ToTensor()])
        self.root = root
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.download = download

        if download:
            if os.path.isdir(root) and len(os.listdir(root)) > 0:
                print('Download not needed, files already on disk.')
            else:
                # Aliased so it does not shadow the boolean ``download`` argument.
                from onedrivedownloader import download as onedrive_download

                print('Downloading dataset')
                ln = "https://unimore365-my.sharepoint.com/:u:/g/personal/263133_unimore_it/EVKugslStrtNpyLGbgrhjaABqRHcE3PB_r2OEaV7Jy94oQ?e=9K29aD"
                onedrive_download(ln, filename=os.path.join(root, 'tiny-imagenet-processed.zip'),
                                  unzip=True, unzip_path=root, clean=True)

        self.data = self._load_shards('x')
        self.targets = self._load_shards('y')

    def _load_shards(self, prefix: str, n_shards: int = 20) -> np.ndarray:
        """
        Load the ``<prefix>_<split>_NN.npy`` shards of the selected split
        (NN = 01..n_shards) and concatenate them along the first axis.

        :param prefix: ``'x'`` for images, ``'y'`` for labels
        :param n_shards: number of shard files per split
        :return: the concatenated array
        """
        split = 'train' if self.train else 'val'
        shards = [
            np.load(os.path.join(self.root,
                                 'processed/%s_%s_%02d.npy' % (prefix, split, num + 1)))
            for num in range(n_shards)
        ]
        # Concatenate the list directly: wrapping it in np.array first builds an
        # extra stacked copy of every shard, and on NumPy >= 1.24 it raises
        # ValueError if the shards do not all have the same length.
        return np.concatenate(shards)

    def __len__(self):
        """Number of samples in the loaded split."""
        return len(self.data)

    def __getitem__(self, index):
        """
        Return ``(img, target)``, or ``(img, target, original_img, logits)``
        when a ``logits`` attribute has been attached to the dataset.

        ``original_img`` is an untransformed copy of the PIL image.
        """
        img, target = self.data[index], self.targets[index]

        # Stored values are scaled by 255 and cast to uint8 to build a PIL
        # Image before ``transform`` is applied.
        img = Image.fromarray(np.uint8(255 * img))
        original_img = img.copy()

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        if hasattr(self, 'logits'):
            # NOTE(review): unlike MyTinyImagenet, this returns the raw PIL
            # image rather than a tensor; confirm callers expect that.
            return img, target, original_img, self.logits[index]

        return img, target
|
|
|
|
class MyTinyImagenet(TinyImagenet):
    """
    Tiny ImageNet variant whose items also carry the non-augmented image.

    Each item is ``(augmented_img, target, not_aug_img)``, where
    ``not_aug_img`` is an untouched copy of the image passed through
    ``self.not_aug_transform``. If a ``logits`` attribute has been attached to
    the dataset, the matching ``logits[index]`` is appended as a fourth element.
    """

    def __init__(self, root: str, train: bool=True, transform: transforms=None,
                 target_transform: transforms=None, download: bool=False) -> None:
        super().__init__(root, train, transform, target_transform, download)

    def __getitem__(self, index):
        # Rebuild a PIL image from the stored array (values scaled by 255).
        pil_img = Image.fromarray(np.uint8(255 * self.data[index]))
        label = self.targets[index]

        # Tensor view of an unaugmented copy, computed before any augmentation.
        clean_tensor = self.not_aug_transform(pil_img.copy())

        augmented = pil_img if self.transform is None else self.transform(pil_img)
        if self.target_transform is not None:
            label = self.target_transform(label)

        if not hasattr(self, 'logits'):
            return augmented, label, clean_tensor
        return augmented, label, clean_tensor, self.logits[index]
|
|
|
|
class SequentialTinyImagenet(ContinualDataset):
    """
    Sequential Tiny ImageNet benchmark (class-incremental setting).

    The ``N_CLASSES`` classes are divided into ``N_TASKS`` tasks of
    ``N_CLASSES_PER_TASK`` classes each.
    """

    NAME = 'seq-tinyimg'
    SETTING = 'class-il'
    N_TASKS = 10
    N_CLASSES = 200
    N_CLASSES_PER_TASK = N_CLASSES // N_TASKS  # == 20
    # Per-channel normalization statistics, shared by the training transform
    # and the (de)normalization helpers below.
    MEAN = (0.4802, 0.4480, 0.3975)
    STD = (0.2770, 0.2691, 0.2821)
    TRANSFORM = transforms.Compose(
        [transforms.RandomCrop(64, padding=4),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         transforms.Normalize(MEAN, STD)])

    def get_data_loaders(self):
        """
        Build the train/test datasets and return the pair produced by
        ``store_masked_loaders``.

        When ``self.args.validation`` is set, the test dataset comes from
        ``get_train_val`` applied to the training set; otherwise the ``val``
        split of Tiny ImageNet is used.
        """
        transform = self.TRANSFORM

        test_transform = transforms.Compose(
            [transforms.ToTensor(), self.get_normalization_transform()])

        train_dataset = MyTinyImagenet(base_path() + 'TINYIMG',
                                       train=True, download=True, transform=transform)
        if self.args.validation:
            train_dataset, test_dataset = get_train_val(train_dataset,
                                                        test_transform, self.NAME)
        else:
            test_dataset = TinyImagenet(base_path() + 'TINYIMG',
                                        train=False, download=True, transform=test_transform)

        train, test = store_masked_loaders(train_dataset, test_dataset, self)
        return train, test

    @staticmethod
    def get_backbone():
        """ResNet-18 built with ``N_CLASSES_PER_TASK * N_TASKS`` outputs."""
        return resnet18(SequentialTinyImagenet.N_CLASSES_PER_TASK
                        * SequentialTinyImagenet.N_TASKS)

    @staticmethod
    def get_backboneid():
        """
        ``resnet18_id2`` backbone, built with ``N_CLASSES_PER_TASK * N_TASKS``
        (presumably the number of outputs — confirm against resnet18_id2).
        """
        return resnet18_id2(SequentialTinyImagenet.N_CLASSES_PER_TASK
                            * SequentialTinyImagenet.N_TASKS)

    @staticmethod
    def get_loss():
        """Cross-entropy loss function."""
        return F.cross_entropy

    def get_transform(self):
        """Training transform preceded by ``ToPILImage`` for non-PIL inputs."""
        transform = transforms.Compose(
            [transforms.ToPILImage(), self.TRANSFORM])
        return transform

    @staticmethod
    def get_normalization_transform():
        """Normalize with the dataset's per-channel mean/std."""
        transform = transforms.Normalize(SequentialTinyImagenet.MEAN,
                                         SequentialTinyImagenet.STD)
        return transform

    @staticmethod
    def get_denormalization_transform():
        """Inverse of ``get_normalization_transform`` (same mean/std)."""
        transform = DeNormalize(SequentialTinyImagenet.MEAN,
                                SequentialTinyImagenet.STD)
        return transform

    @staticmethod
    def get_epochs():
        return 100

    @staticmethod
    def get_batch_size():
        return 32

    @staticmethod
    def get_minibatch_size():
        return SequentialTinyImagenet.get_batch_size()

    @staticmethod
    def get_scheduler(model, args) -> torch.optim.lr_scheduler.MultiStepLR:
        """
        Replace ``model.opt`` with a fresh SGD optimizer over ``model.net``
        and return a MultiStepLR scheduler (gamma 0.1) attached to it.

        50-epoch runs decay at epochs [35, 45]; any other epoch count decays
        at [35, 60, 75].
        """
        model.opt = torch.optim.SGD(model.net.parameters(), lr=args.lr,
                                    weight_decay=args.optim_wd, momentum=args.optim_mom)
        milestones = [35, 45] if args.n_epochs == 50 else [35, 60, 75]
        # ``verbose`` is omitted: False was its default, and passing it
        # explicitly emits a deprecation warning on recent PyTorch releases.
        return torch.optim.lr_scheduler.MultiStepLR(model.opt, milestones, gamma=0.1)
| |