Spaces:
Sleeping
Sleeping
| import logging | |
| import numpy as np | |
| from PIL import Image, ImageFile | |
| ImageFile.LOAD_TRUNCATED_IMAGES = True | |
| from torch.utils.data import Dataset | |
| from torchvision import transforms | |
| from utils.data import CDDB_benchmark, TrueFake_benchmark | |
| import pickle | |
| import os | |
class DataManager(object):
    """Splits a benchmark dataset into a sequence of incremental-learning tasks.

    Loads train/test data for *dataset_name*, fixes a class order (optionally
    shuffled with *seed*), and partitions the classes into an initial task of
    ``init_cls`` classes followed by tasks of ``increment`` classes each; a
    smaller final task absorbs any remainder.
    """

    def __init__(self, dataset_name, shuffle, seed, init_cls, increment, args):
        self.args = args
        self.dataset_name = dataset_name
        self._setup_data(dataset_name, shuffle, seed)
        assert init_cls <= len(self._class_order), "No enough classes."
        # Task sizes: [init_cls, increment, increment, ..., remainder].
        self._increments = [init_cls]
        while sum(self._increments) + increment < len(self._class_order):
            self._increments.append(increment)
        offset = len(self._class_order) - sum(self._increments)
        if offset > 0:
            self._increments.append(offset)

    def nb_tasks(self):
        """Return the total number of incremental tasks."""
        return len(self._increments)

    def get_task_size(self, task):
        """Return the number of classes introduced at task index *task*."""
        return self._increments[task]

    def get_dataset(self, indices, source, mode, appendent=None, ret_data=False):
        """Build a DummyDataset over the classes listed in *indices*.

        Args:
            indices: iterable of (remapped) class ids to include.
            source: "train" or "test" split.
            mode: "train", "test", or "flip" (test transforms plus a forced
                horizontal flip).
            appendent: optional (data, targets) pair appended verbatim, e.g.
                exemplar memory.
            ret_data: if True, also return the raw (data, targets) arrays.

        Raises:
            ValueError: for an unknown *source* or *mode*.
        """
        if source == "train":
            x, y = self._train_data, self._train_targets
        elif source == "test":
            x, y = self._test_data, self._test_targets
        else:
            raise ValueError("Unknown data source {}.".format(source))

        if mode == "train":
            trsf = transforms.Compose([*self._train_trsf, *self._common_trsf])
        elif mode == "flip":
            trsf = transforms.Compose(
                [
                    *self._test_trsf,
                    transforms.RandomHorizontalFlip(p=1.0),
                    *self._common_trsf,
                ]
            )
        elif mode == "test":
            trsf = transforms.Compose([*self._test_trsf, *self._common_trsf])
        else:
            raise ValueError("Unknown mode {}.".format(mode))

        data, targets = [], []
        for idx in indices:
            class_data, class_targets = self._select(
                x, y, low_range=idx, high_range=idx + 1
            )
            data.append(class_data)
            targets.append(class_targets)

        if appendent is not None and len(appendent) != 0:
            appendent_data, appendent_targets = appendent
            data.append(appendent_data)
            targets.append(appendent_targets)

        data, targets = np.concatenate(data), np.concatenate(targets)
        dataset = DummyDataset(
            data,
            targets,
            trsf,
            self._object_classes_data,
            self.use_path,
            self.args,
        )
        # BUG FIX: *ret_data* was previously ignored (the branch was commented
        # out), so callers asking for the raw arrays got only the dataset.
        # This restores the same contract get_anchor_dataset honors.
        if ret_data:
            return data, targets, dataset
        return dataset

    def get_anchor_dataset(self, mode, appendent=None, ret_data=False):
        """Build a DummyDataset from *appendent* alone (anchor/exemplar data).

        Args:
            mode: "train", "test", or "flip" (test transforms plus a forced
                horizontal flip).
            appendent: (data, targets) pair providing all samples.
            ret_data: if True, also return the raw (data, targets) arrays.

        Raises:
            ValueError: for an unknown *mode*.
        """
        if mode == "train":
            trsf = transforms.Compose([*self._train_trsf, *self._common_trsf])
        elif mode == "flip":
            trsf = transforms.Compose(
                [
                    *self._test_trsf,
                    transforms.RandomHorizontalFlip(p=1.0),
                    *self._common_trsf,
                ]
            )
        elif mode == "test":
            trsf = transforms.Compose([*self._test_trsf, *self._common_trsf])
        else:
            raise ValueError("Unknown mode {}.".format(mode))

        data, targets = [], []
        if appendent is not None and len(appendent) != 0:
            appendent_data, appendent_targets = appendent
            data.append(appendent_data)
            targets.append(appendent_targets)
        data, targets = np.concatenate(data), np.concatenate(targets)

        # BUG FIX: DummyDataset's signature is (images, labels, trsf, classes,
        # use_path, args); use_path was previously passed in the *classes*
        # slot, leaving args=None and crashing DummyDataset.__init__ at
        # args["data_path"].
        dataset = DummyDataset(
            data,
            targets,
            trsf,
            self._object_classes_data,
            self.use_path,
            self.args,
        )
        if ret_data:
            return data, targets, dataset
        return dataset

    def get_dataset_with_split(
        self, indices, source, mode, appendent=None, val_samples_per_class=0
    ):
        """Build (train, val) DummyDatasets, holding out
        *val_samples_per_class* random samples of every class for validation.

        Args:
            indices: iterable of (remapped) class ids to include.
            source: "train" or "test" split.
            mode: "train" or "test" transform pipeline.
            appendent: optional (data, targets) pair, split per class as well.
            val_samples_per_class: held-out sample count per class.

        Raises:
            ValueError: for an unknown *source* or *mode*.
        """
        if source == "train":
            x, y = self._train_data, self._train_targets
        elif source == "test":
            x, y = self._test_data, self._test_targets
        else:
            raise ValueError("Unknown data source {}.".format(source))

        if mode == "train":
            trsf = transforms.Compose([*self._train_trsf, *self._common_trsf])
        elif mode == "test":
            trsf = transforms.Compose([*self._test_trsf, *self._common_trsf])
        else:
            raise ValueError("Unknown mode {}.".format(mode))

        train_data, train_targets = [], []
        val_data, val_targets = [], []
        for idx in indices:
            class_data, class_targets = self._select(
                x, y, low_range=idx, high_range=idx + 1
            )
            val_indx = np.random.choice(
                len(class_data), val_samples_per_class, replace=False
            )
            train_indx = list(set(np.arange(len(class_data))) - set(val_indx))
            val_data.append(class_data[val_indx])
            val_targets.append(class_targets[val_indx])
            train_data.append(class_data[train_indx])
            train_targets.append(class_targets[train_indx])

        if appendent is not None:
            appendent_data, appendent_targets = appendent
            for idx in range(0, int(np.max(appendent_targets)) + 1):
                append_data, append_targets = self._select(
                    appendent_data, appendent_targets, low_range=idx, high_range=idx + 1
                )
                val_indx = np.random.choice(
                    len(append_data), val_samples_per_class, replace=False
                )
                train_indx = list(set(np.arange(len(append_data))) - set(val_indx))
                val_data.append(append_data[val_indx])
                val_targets.append(append_targets[val_indx])
                train_data.append(append_data[train_indx])
                train_targets.append(append_targets[train_indx])

        train_data, train_targets = np.concatenate(train_data), np.concatenate(
            train_targets
        )
        val_data, val_targets = np.concatenate(val_data), np.concatenate(val_targets)

        # BUG FIX: same mis-ordered DummyDataset call as get_anchor_dataset --
        # classes/use_path/args are now passed in the correct slots.
        return DummyDataset(
            train_data,
            train_targets,
            trsf,
            self._object_classes_data,
            self.use_path,
            self.args,
        ), DummyDataset(
            val_data,
            val_targets,
            trsf,
            self._object_classes_data,
            self.use_path,
            self.args,
        )

    def _setup_data(self, dataset_name, shuffle, seed):
        """Load benchmark data, class metadata, transforms, and class order,
        then remap all targets onto that order."""
        idata = _get_idata(dataset_name, self.args)
        idata.download_data()

        # Data
        self._train_data, self._train_targets = idata.train_data, idata.train_targets
        self._test_data, self._test_targets = idata.test_data, idata.test_targets
        self.use_path = idata.use_path

        # NOTE(review): presumably maps relative image paths to object-class
        # lists (see DummyDataset.__getitem__) -- confirm the pkl's schema.
        with open("./src/utils/classes.pkl", "rb") as f:
            self._object_classes_data = pickle.load(f)

        # Transforms
        self._train_trsf = idata.train_trsf
        self._test_trsf = idata.test_trsf
        self._common_trsf = idata.common_trsf

        # Class order: a seeded random permutation when shuffling, otherwise
        # the benchmark's canonical order.
        order = [i for i in range(len(np.unique(self._train_targets)))]
        if shuffle:
            np.random.seed(seed)
            order = np.random.permutation(len(order)).tolist()
        else:
            order = idata.class_order
        self._class_order = order
        logging.info(self._class_order)

        # Map original labels to their index in the class order.
        self._train_targets = _map_new_class_index(
            self._train_targets, self._class_order
        )
        self._test_targets = _map_new_class_index(self._test_targets, self._class_order)

    def _select(self, x, y, low_range, high_range):
        """Return the (x, y) samples whose target is in [low_range, high_range)."""
        idxes = np.where(np.logical_and(y >= low_range, y < high_range))[0]
        return x[idxes], y[idxes]
class DummyDataset(Dataset):
    """Dataset wrapper returning (object_classes, image, label) triples.

    Args:
        images: image paths relative to ``args["data_path"]`` (when
            *use_path*), or in-memory image arrays otherwise.
        labels: per-image target labels; must match *images* in length.
        trsf: transform callable applied to every loaded image.
        classes: mapping from relative image path to its list of object
            classes (truncated to the top-k on lookup).
        use_path: when True, load images from disk via pil_loader.
        args: config dict; reads "data_path" and "topk_classes".
    """

    def __init__(self, images, labels, trsf, classes, use_path=False, args=None):
        assert len(images) == len(labels), "Data size error!"
        self.images = images
        self.labels = labels
        self.trsf = trsf
        self.use_path = use_path
        self.classes = classes
        self.dataset_path = args["data_path"]
        # Clamp non-positive configs to 1 so slicing always keeps something.
        self.topk_classes = args["topk_classes"] if args["topk_classes"] > 0 else 1

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.dataset_path, self.images[idx])
        if self.use_path:
            image = self.trsf(pil_loader(img_path))
        else:
            # BUG FIX: this branch previously passed the joined path *string*
            # to Image.fromarray, which requires the raw array data.
            image = self.trsf(Image.fromarray(self.images[idx]))
        label = self.labels[idx]
        # Look up object classes by the path relative to the dataset root.
        classes = self.classes[img_path.replace(self.dataset_path, "")][: self.topk_classes]
        return classes, image, label
| def _map_new_class_index(y, order): | |
| return np.array(list(map(lambda x: order.index(x), y))) | |
| def _get_idata(dataset_name, args=None): | |
| name = dataset_name.lower() | |
| if name == "cddb": | |
| return CDDB_benchmark(args) | |
| elif name == "truefake": | |
| return TrueFake_benchmark(args) | |
| else: | |
| raise NotImplementedError("Unknown dataset {}.".format(dataset_name)) | |
def pil_loader(path):
    """Load the image at *path* and return it converted to RGB.

    Ref:
    https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder
    """
    # Open through an explicit file handle so the descriptor is closed
    # promptly and no ResourceWarning is emitted
    # (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, "rb") as handle:
        return Image.open(handle).convert("RGB")