File size: 3,496 Bytes
# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple
import torch.nn.functional as F
import torchvision.transforms as transforms
from backbone.MNISTMLP import MNISTMLP
from PIL import Image
from torchvision.datasets import MNIST
from datasets.utils.continual_dataset import (ContinualDataset,
store_masked_loaders)
from datasets.utils.validation import get_train_val
from utils.conf import base_path_dataset as base_path
class MyMNIST(MNIST):
    """
    Overrides the MNIST dataset to change the getitem function.

    Besides the (possibly transformed) image and its target, every item also
    carries a non-augmented tensor copy of the image and, when a ``logits``
    attribute has been attached to the dataset, the stored logits for that
    index.
    """

    def __init__(self, root, train=True, transform=None,
                 target_transform=None, download=False) -> None:
        # Always applied, independently of ``transform``, so that each item
        # also exposes an un-augmented tensor version of the image.
        self.not_aug_transform = transforms.ToTensor()
        # Keyword arguments keep this call correct even if the parent
        # constructor's positional order changes.
        super().__init__(root, train=train, transform=transform,
                         target_transform=target_transform, download=download)

    def __getitem__(self, index: int) -> Tuple:
        """
        Gets the requested element from the dataset.

        :param index: index of the element to be returned
        :returns: ``(img, target, original_img)``, or
            ``(img, target, original_img, logits)`` when the dataset has a
            ``logits`` attribute. ``img`` is the image after
            ``self.transform`` (a PIL image if no transform is set),
            ``target`` is the entry of ``self.targets`` after
            ``self.target_transform`` (if set), ``original_img`` is the
            un-augmented image converted by ``ToTensor``, and ``logits`` is
            ``self.logits[index]``.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image. Assumes ``self.data`` holds 2-D uint8 images
        # (as torchvision's MNIST loader provides), for which Pillow infers
        # grayscale mode 'L'; the explicit ``mode`` argument is deprecated in
        # recent Pillow releases.
        img = Image.fromarray(img.numpy())
        # Built before ``transform`` runs so it never sees augmentation.
        original_img = self.not_aug_transform(img.copy())

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        # ``logits`` is only present when some caller has attached it to the
        # dataset; the extra element is returned only in that case.
        if hasattr(self, 'logits'):
            return img, target, original_img, self.logits[index]

        return img, target, original_img
class SequentialMNIST(ContinualDataset):
    """
    Sequential MNIST benchmark in the class-incremental setting: the MNIST
    classes are presented as ``N_TASKS`` tasks of ``N_CLASSES_PER_TASK``
    classes each.
    """

    NAME = 'seq-mnist'
    SETTING = 'class-il'
    N_CLASSES_PER_TASK = 2
    N_TASKS = 5
    TRANSFORM = None

    def get_data_loaders(self):
        """
        Build the training and test loaders through ``store_masked_loaders``.

        When ``self.args.validation`` is set, the evaluation split comes from
        ``get_train_val`` applied to the training set; otherwise the official
        MNIST test split is used.

        :returns: the (train, test) pair produced by ``store_masked_loaders``.
        """
        to_tensor = transforms.ToTensor()
        train_split = MyMNIST(base_path() + 'MNIST', train=True,
                              download=True, transform=to_tensor)

        if not self.args.validation:
            # Stock torchvision MNIST rather than MyMNIST, so test items carry
            # no extra non-augmented image.
            eval_split = MNIST(base_path() + 'MNIST', train=False,
                               download=True, transform=to_tensor)
        else:
            train_split, eval_split = get_train_val(train_split, to_tensor,
                                                    self.NAME)

        train_loader, test_loader = store_masked_loaders(train_split,
                                                         eval_split, self)
        return train_loader, test_loader

    @staticmethod
    def get_backbone():
        """Return an MLP over flattened 28x28 inputs with one output per class."""
        flat_input = 28 * 28
        n_classes = SequentialMNIST.N_TASKS * SequentialMNIST.N_CLASSES_PER_TASK
        return MNISTMLP(flat_input, n_classes)

    @staticmethod
    def get_transform():
        """Return None: this benchmark defines no training transform."""
        return None

    @staticmethod
    def get_loss():
        """Return the loss function used for training (cross-entropy)."""
        return F.cross_entropy

    @staticmethod
    def get_normalization_transform():
        """Return None: no normalization transform is applied."""
        return None

    @staticmethod
    def get_denormalization_transform():
        """Return None: there is no normalization to undo."""
        return None

    @staticmethod
    def get_scheduler(model, args):
        """Return None: this dataset provides no scheduler."""
        return None

    @staticmethod
    def get_batch_size():
        """Return the training batch size."""
        return 64

    @staticmethod
    def get_minibatch_size():
        """Return the minibatch size, which matches ``get_batch_size()``."""
        return SequentialMNIST.get_batch_size()
|