| """Data utils for CIFAR-10 and CIFAR-100."""
|
|
|
| from __future__ import absolute_import
|
| from __future__ import division
|
| from __future__ import print_function
|
|
|
| import copy
|
| import cPickle
|
| import os
|
| import augmentation_transforms
|
| import numpy as np
|
| import policies as found_policies
|
| import tensorflow as tf
|
|
|
|
|
|
|
|
|
|
|
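
# The hparams object passed to DataSet is expected to provide (inferred from
# its usage below): dataset ('cifar10' or 'cifar100'), data_path, train_size,
# validation_size, eval_test and batch_size.
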
class DataSet(object):
  """Dataset object that produces augmented training and eval data."""

  def __init__(self, hparams):
    self.hparams = hparams
    self.epochs = 0
    self.curr_train_index = 0

    all_labels = []

    self.good_policies = found_policies.good_policies()

    # Determine how many data batches to load.
    num_data_batches_to_load = 5
    total_batches_to_load = num_data_batches_to_load
    train_batches_to_load = total_batches_to_load
    assert hparams.train_size + hparams.validation_size <= 50000
    if hparams.eval_test:
      total_batches_to_load += 1

    # Determine how many images will be loaded.
    total_dataset_size = 10000 * num_data_batches_to_load
    train_dataset_size = total_dataset_size
    if hparams.eval_test:
      total_dataset_size += 10000

    # Allocate the raw data buffers for each dataset.
    if hparams.dataset == 'cifar10':
      all_data = np.empty((total_batches_to_load, 10000, 3072), dtype=np.uint8)
    elif hparams.dataset == 'cifar100':
      assert num_data_batches_to_load == 5
      all_data = np.empty((1, 50000, 3072), dtype=np.uint8)
      if hparams.eval_test:
        test_data = np.empty((1, 10000, 3072), dtype=np.uint8)

    # Pick the pickled batch files to read for each dataset.
    if hparams.dataset == 'cifar10':
      tf.logging.info('Cifar10')
      datafiles = [
          'data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4',
          'data_batch_5']
      datafiles = datafiles[:train_batches_to_load]
      if hparams.eval_test:
        datafiles.append('test_batch')
      num_classes = 10
    elif hparams.dataset == 'cifar100':
      datafiles = ['train']
      if hparams.eval_test:
        datafiles.append('test')
      num_classes = 100
    else:
      raise NotImplementedError(
          'Unimplemented dataset: {}'.format(hparams.dataset))

    # Read each batch file and collect its images and labels.
    if hparams.dataset != 'test':
      for file_num, f in enumerate(datafiles):
        d = unpickle(os.path.join(hparams.data_path, f))
        if f == 'test':
          test_data[0] = copy.deepcopy(d['data'])
          all_data = np.concatenate([all_data, test_data], axis=1)
        else:
          all_data[file_num] = copy.deepcopy(d['data'])
        if hparams.dataset == 'cifar10':
          labels = np.array(d['labels'])
        else:
          labels = np.array(d['fine_labels'])
        nsamples = len(labels)
        for idx in range(nsamples):
          all_labels.append(labels[idx])

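    # A worked example of the normalization below (illustrative values, not
    # the actual MEANS/STDS from augmentation_transforms): with channel mean
    # 0.49 and std 0.25, a raw pixel of 128 maps to
    # (128 / 255.0 - 0.49) / 0.25 ~= 0.048.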
    # Convert the raw NCHW uint8 buffers to normalized NHWC floats.
    all_data = all_data.reshape(total_dataset_size, 3072)
    all_data = all_data.reshape(-1, 3, 32, 32)
    all_data = all_data.transpose(0, 2, 3, 1).copy()
    all_data = all_data / 255.0
    mean = augmentation_transforms.MEANS
    std = augmentation_transforms.STDS
    tf.logging.info('mean: {} std: {}'.format(mean, std))
    all_data = (all_data - mean) / std

    # Convert integer labels to one-hot vectors.
    all_labels = np.eye(num_classes)[np.array(all_labels, dtype=np.int32)]
    assert len(all_data) == len(all_labels)
    tf.logging.info(
        'In CIFAR loader, number of images: {}'.format(len(all_data)))

    # Break off the test data.
    if hparams.eval_test:
      self.test_images = all_data[train_dataset_size:]
      self.test_labels = all_labels[train_dataset_size:]

    # Shuffle the rest with a fixed seed so the splits are reproducible.
    all_data = all_data[:train_dataset_size]
    all_labels = all_labels[:train_dataset_size]
    np.random.seed(0)
    perm = np.arange(len(all_data))
    np.random.shuffle(perm)
    all_data = all_data[perm]
    all_labels = all_labels[perm]

    # Break into train and validation splits.
    train_size, val_size = hparams.train_size, hparams.validation_size
    assert 50000 >= train_size + val_size
    self.train_images = all_data[:train_size]
    self.train_labels = all_labels[:train_size]
    self.val_images = all_data[train_size:train_size + val_size]
    self.val_labels = all_labels[train_size:train_size + val_size]
    self.num_train = self.train_images.shape[0]

  def next_batch(self):
    """Return the next minibatch of augmented data."""
    next_train_index = self.curr_train_index + self.hparams.batch_size
    if next_train_index > self.num_train:
      # reset() zeroes the epoch counter, so restore the incremented value.
      epoch = self.epochs + 1
      self.reset()
      self.epochs = epoch
    batched_data = (
        self.train_images[self.curr_train_index:
                          self.curr_train_index + self.hparams.batch_size],
        self.train_labels[self.curr_train_index:
                          self.curr_train_index + self.hparams.batch_size])
    final_imgs = []

    images, labels = batched_data
    for data in images:
      # Sample one of the found augmentation policies per image, then apply
      # the standard random flip and pad-and-crop transforms.
      epoch_policy = self.good_policies[np.random.choice(
          len(self.good_policies))]
      final_img = augmentation_transforms.apply_policy(epoch_policy, data)
      final_img = augmentation_transforms.random_flip(
          augmentation_transforms.zero_pad_and_crop(final_img, 4))
      # Apply cutout.
      final_img = augmentation_transforms.cutout_numpy(final_img)
      final_imgs.append(final_img)
    batched_data = (np.array(final_imgs, np.float32), labels)
    self.curr_train_index += self.hparams.batch_size
    return batched_data

  def reset(self):
    """Reset training data and index into the training data."""
    self.epochs = 0
    # Shuffle the training data.
    perm = np.arange(self.num_train)
    np.random.shuffle(perm)
    assert self.num_train == self.train_images.shape[0], (
        'Error: incorrect shuffling mask')
    self.train_images = self.train_images[perm]
    self.train_labels = self.train_labels[perm]
    self.curr_train_index = 0


def unpickle(f):
  """Load a pickled CIFAR batch file and return its dict."""
  tf.logging.info('loading file: {}'.format(f))
  # Pickle data is binary, so open the file in binary mode.
  fo = tf.gfile.Open(f, 'rb')
  d = cPickle.load(fo)
  fo.close()
  return d
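
# Minimal usage sketch (hypothetical values; assumes the TF1-era
# tf.contrib.training.HParams container and CIFAR-10 python batches
# unpacked under data_path):
#
#   hparams = tf.contrib.training.HParams(
#       dataset='cifar10', data_path='/tmp/cifar-10-batches-py',
#       train_size=45000, validation_size=5000, eval_test=1, batch_size=128)
#   dataset = DataSet(hparams)
#   for _ in xrange(dataset.num_train // hparams.batch_size):
#     images, labels = dataset.next_batch()  # images: (128, 32, 32, 3) float32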