AMontiB
Your original commit message (now includes LFS pointer)
9c4b1c4
import logging
import numpy as np
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from torch.utils.data import Dataset
from torchvision import transforms
from utils.data import CDDB_benchmark, TrueFake_benchmark
import pickle
import os
class DataManager(object):
def __init__(self, dataset_name, shuffle, seed, init_cls, increment, args):
self.args = args
self.dataset_name = dataset_name
self._setup_data(dataset_name, shuffle, seed)
assert init_cls <= len(self._class_order), "No enough classes."
self._increments = [init_cls]
while sum(self._increments) + increment < len(self._class_order):
self._increments.append(increment)
offset = len(self._class_order) - sum(self._increments)
if offset > 0:
self._increments.append(offset)
@property
def nb_tasks(self):
return len(self._increments)
def get_task_size(self, task):
return self._increments[task]
def get_dataset(self, indices, source, mode, appendent=None, ret_data=False):
if source == "train":
x, y = self._train_data, self._train_targets
elif source == "test":
x, y = self._test_data, self._test_targets
else:
raise ValueError("Unknown data source {}.".format(source))
if mode == "train":
trsf = transforms.Compose([*self._train_trsf, *self._common_trsf])
elif mode == "flip":
trsf = transforms.Compose(
[
*self._test_trsf,
transforms.RandomHorizontalFlip(p=1.0),
*self._common_trsf,
]
)
elif mode == "test":
trsf = transforms.Compose([*self._test_trsf, *self._common_trsf])
else:
raise ValueError("Unknown mode {}.".format(mode))
data, targets = [], []
for idx in indices:
class_data, class_targets = self._select(
x, y, low_range=idx, high_range=idx + 1
)
data.append(class_data)
targets.append(class_targets)
if appendent is not None and len(appendent) != 0:
appendent_data, appendent_targets = appendent
data.append(appendent_data)
targets.append(appendent_targets)
data, targets = np.concatenate(data), np.concatenate(targets)
# if ret_data:
# return data, targets, DummyDataset(data, targets, trsf, self.use_path)
# else:
return DummyDataset(
data,
targets,
trsf,
self._object_classes_data,
self.use_path,
self.args,
)
def get_anchor_dataset(self, mode, appendent=None, ret_data=False):
if mode == "train":
trsf = transforms.Compose([*self._train_trsf, *self._common_trsf])
elif mode == "flip":
trsf = transforms.Compose(
[
*self._test_trsf,
transforms.RandomHorizontalFlip(p=1.0),
*self._common_trsf,
]
)
elif mode == "test":
trsf = transforms.Compose([*self._test_trsf, *self._common_trsf])
else:
raise ValueError("Unknown mode {}.".format(mode))
data, targets = [], []
if appendent is not None and len(appendent) != 0:
appendent_data, appendent_targets = appendent
data.append(appendent_data)
targets.append(appendent_targets)
data, targets = np.concatenate(data), np.concatenate(targets)
if ret_data:
return data, targets, DummyDataset(data, targets, trsf, self.use_path)
else:
return DummyDataset(data, targets, trsf, self.use_path)
def get_dataset_with_split(
self, indices, source, mode, appendent=None, val_samples_per_class=0
):
if source == "train":
x, y = self._train_data, self._train_targets
elif source == "test":
x, y = self._test_data, self._test_targets
else:
raise ValueError("Unknown data source {}.".format(source))
if mode == "train":
trsf = transforms.Compose([*self._train_trsf, *self._common_trsf])
elif mode == "test":
trsf = transforms.Compose([*self._test_trsf, *self._common_trsf])
else:
raise ValueError("Unknown mode {}.".format(mode))
train_data, train_targets = [], []
val_data, val_targets = [], []
for idx in indices:
class_data, class_targets = self._select(
x, y, low_range=idx, high_range=idx + 1
)
val_indx = np.random.choice(
len(class_data), val_samples_per_class, replace=False
)
train_indx = list(set(np.arange(len(class_data))) - set(val_indx))
val_data.append(class_data[val_indx])
val_targets.append(class_targets[val_indx])
train_data.append(class_data[train_indx])
train_targets.append(class_targets[train_indx])
if appendent is not None:
appendent_data, appendent_targets = appendent
for idx in range(0, int(np.max(appendent_targets)) + 1):
append_data, append_targets = self._select(
appendent_data, appendent_targets, low_range=idx, high_range=idx + 1
)
val_indx = np.random.choice(
len(append_data), val_samples_per_class, replace=False
)
train_indx = list(set(np.arange(len(append_data))) - set(val_indx))
val_data.append(append_data[val_indx])
val_targets.append(append_targets[val_indx])
train_data.append(append_data[train_indx])
train_targets.append(append_targets[train_indx])
train_data, train_targets = np.concatenate(train_data), np.concatenate(
train_targets
)
val_data, val_targets = np.concatenate(val_data), np.concatenate(val_targets)
return DummyDataset(
train_data, train_targets, trsf, self.use_path
), DummyDataset(val_data, val_targets, trsf, self.use_path)
def _setup_data(self, dataset_name, shuffle, seed):
idata = _get_idata(dataset_name, self.args)
idata.download_data()
# Data
self._train_data, self._train_targets = idata.train_data, idata.train_targets
self._test_data, self._test_targets = idata.test_data, idata.test_targets
self.use_path = idata.use_path
with open("./src/utils/classes.pkl", "rb") as f:
self._object_classes_data = pickle.load(f)
# Transforms
self._train_trsf = idata.train_trsf
self._test_trsf = idata.test_trsf
self._common_trsf = idata.common_trsf
# Order
order = [i for i in range(len(np.unique(self._train_targets)))]
if shuffle:
np.random.seed(seed)
order = np.random.permutation(len(order)).tolist()
else:
order = idata.class_order
self._class_order = order
logging.info(self._class_order)
# Map indices
self._train_targets = _map_new_class_index(
self._train_targets, self._class_order
)
self._test_targets = _map_new_class_index(self._test_targets, self._class_order)
def _select(self, x, y, low_range, high_range):
idxes = np.where(np.logical_and(y >= low_range, y < high_range))[0]
return x[idxes], y[idxes]
class DummyDataset(Dataset):
def __init__(self, images, labels, trsf, classes, use_path=False, args=None):
assert len(images) == len(labels), "Data size error!"
self.images = images
self.labels = labels
self.trsf = trsf
self.use_path = use_path
self.classes = classes
self.dataset_path = args["data_path"]
self.topk_classes = args["topk_classes"] if args["topk_classes"] > 0 else 1
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = os.path.join(self.dataset_path, self.images[idx])
if self.use_path:
image = self.trsf(pil_loader(img_path))
else:
image = self.trsf(Image.fromarray(img_path))
label = self.labels[idx]
classes = self.classes[img_path.replace(self.dataset_path, "")][: self.topk_classes]
return classes, image, label
def _map_new_class_index(y, order):
return np.array(list(map(lambda x: order.index(x), y)))
def _get_idata(dataset_name, args=None):
name = dataset_name.lower()
if name == "cddb":
return CDDB_benchmark(args)
elif name == "truefake":
return TrueFake_benchmark(args)
else:
raise NotImplementedError("Unknown dataset {}.".format(dataset_name))
def pil_loader(path):
"""
Ref:
https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder
"""
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, "rb") as f:
img = Image.open(f)
return img.convert("RGB")