import os
import os.path as osp
import tarfile
import zipfile
from collections import defaultdict
import gdown
import json

import torch
from torch.utils.data import Dataset as TorchDataset
import torchvision.transforms as T

from PIL import Image

import numpy as np
import torchvision.transforms as transforms
from datasets.augmix_ops import augmentations


def listdir_nohidden(path, sort=False):
    """List non-hidden items in a directory.

    Entries whose name starts with "." or contains the substring "sh"
    anywhere in the name are skipped.

    Args:
        path (str): directory path.
        sort (bool): sort the items.

    Returns:
        list: the remaining entry names (not full paths).
    """
    # NOTE(review): 'sh' is a plain substring test, so names such as "fish"
    # are dropped as well. Presumably meant to skip shell scripts -- confirm
    # before tightening, since callers may rely on the current filtering.
    items = [
        f for f in os.listdir(path)
        if not f.startswith('.') and 'sh' not in f
    ]
    if sort:
        items.sort()
    return items


def read_json(fpath):
    """Read json file from a path."""
    with open(fpath, 'r') as f:
        obj = json.load(f)
    return obj


def write_json(obj, fpath):
    """Writes to a json file.

    The parent directory of ``fpath`` is created when it does not exist.

    Args:
        obj: JSON-serializable object.
        fpath (str): destination file path.
    """
    dirname = osp.dirname(fpath)
    # A bare file name has an empty dirname; os.makedirs('') would raise,
    # so only create the directory when there is one to create.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    with open(fpath, 'w') as f:
        json.dump(obj, f, indent=4, separators=(',', ': '))


def read_image(path):
    """Read image from path using ``PIL.Image``.

    Args:
        path (str): path to an image.

    Returns:
        PIL image converted to RGB.

    Raises:
        IOError: if ``path`` does not exist.
    """
    if not osp.exists(path):
        raise IOError('No file exists at {}'.format(path))

    # NOTE(review): the retry loop is unbounded on purpose (transient IO
    # failures), which also means a permanently unreadable file retries
    # forever. Kept as-is because callers may rely on it never raising here.
    while True:
        try:
            img = Image.open(path).convert('RGB')
            return img
        except IOError:
            print(
                'Cannot read image from {}, '
                'probably due to heavy IO. Will re-try'.format(path)
            )


class Datum:
    """Data instance which defines the basic attributes.

    Args:
        impath (str): image path.
        label (int): class label.
        domain (int): domain label.
        classname (str): class name.
    """

    def __init__(self, impath='', label=0, domain=-1, classname=''):
        assert isinstance(impath, str)
        assert isinstance(label, int)
        assert isinstance(domain, int)
        assert isinstance(classname, str)

        self._impath = impath
        self._label = label
        self._domain = domain
        self._classname = classname

    @property
    def impath(self):
        return self._impath

    @property
    def label(self):
        return self._label

    @property
    def domain(self):
        return self._domain

    @property
    def classname(self):
        return self._classname


class DatasetBase:
    """A unified dataset class for
    1) domain adaptation
    2) domain generalization
    3) semi-supervised learning

    The number of classes and the label-to-classname mapping are derived
    from the ``test`` split only.
    """
    dataset_dir = ''  # the directory where the dataset is stored
    domains = []  # string names of all domains

    def __init__(self, train_x=None, train_u=None, val=None, test=None):
        self._train_x = train_x  # labeled training data
        self._train_u = train_u  # unlabeled training data (optional)
        self._val = val  # validation data (optional)
        self._test = test  # test data

        self._num_classes = self.get_num_classes(test)
        self._lab2cname, self._classnames = self.get_lab2cname(test)

    @property
    def train_x(self):
        return self._train_x

    @property
    def train_u(self):
        return self._train_u

    @property
    def val(self):
        return self._val

    @property
    def test(self):
        return self._test

    @property
    def lab2cname(self):
        return self._lab2cname

    @property
    def classnames(self):
        return self._classnames

    @property
    def num_classes(self):
        return self._num_classes

    def get_num_classes(self, data_source):
        """Count number of classes.

        The count is ``max(label) + 1``, i.e. labels are treated as
        zero-based indices.

        Args:
            data_source (list): a list of Datum objects.
        """
        label_set = set()
        for item in data_source:
            label_set.add(item.label)
        return max(label_set) + 1

    def get_lab2cname(self, data_source):
        """Get a label-to-classname mapping (dict).

        Args:
            data_source (list): a list of Datum objects.

        Returns:
            tuple: ``(mapping, classnames)`` where ``mapping`` is the
            label-to-classname dict and ``classnames`` lists the class
            names ordered by ascending label.
        """
        container = set()
        for item in data_source:
            container.add((item.label, item.classname))
        mapping = {label: classname for label, classname in container}
        labels = list(mapping.keys())
        labels.sort()
        classnames = [mapping[label] for label in labels]
        return mapping, classnames

    def check_input_domains(self, source_domains, target_domains):
        """Validate both source and target domain names."""
        self.is_input_domain_valid(source_domains)
        self.is_input_domain_valid(target_domains)

    def is_input_domain_valid(self, input_domains):
        """Raise ``ValueError`` if any name is not in ``self.domains``."""
        for domain in input_domains:
            if domain not in self.domains:
                raise ValueError(
                    'Input domain must belong to {}, '
                    'but got [{}]'.format(self.domains, domain)
                )

    def download_data(self, url, dst, from_gdrive=True):
        """Download an archive to ``dst`` and extract it next to it.

        The archive is first opened as a tar file; if it cannot be read as
        one, it is extracted as a zip file instead.

        Args:
            url (str): download URL, passed to ``gdown.download``.
            dst (str): path where the downloaded archive is written.
            from_gdrive (bool): only ``True`` is supported.

        Raises:
            NotImplementedError: if ``from_gdrive`` is False.
        """
        # A bare file name has an empty dirname; extract into the current
        # directory in that case instead of failing in os.makedirs('').
        dst_dir = osp.dirname(dst)
        extract_dir = dst_dir or '.'
        if dst_dir:
            os.makedirs(dst_dir, exist_ok=True)

        if from_gdrive:
            gdown.download(url, dst, quiet=False)
        else:
            raise NotImplementedError

        print('Extracting file ...')

        # Only fall back to zip when the file is not a readable tar archive;
        # other failures (interrupts, disk errors) must propagate. Context
        # managers close the archive handles even when extraction fails.
        try:
            with tarfile.open(dst) as tar:
                tar.extractall(path=extract_dir)
        except tarfile.ReadError:
            with zipfile.ZipFile(dst, 'r') as zip_ref:
                zip_ref.extractall(extract_dir)

        print('File extracted to {}'.format(dst_dir))

    def split_dataset_by_label(self, data_source):
        """Split a dataset, i.e. a list of Datum objects,
        into class-specific groups stored in a dictionary.

        Args:
            data_source (list): a list of Datum objects.
        """
        output = defaultdict(list)

        for item in data_source:
            output[item.label].append(item)

        return output

    def split_dataset_by_domain(self, data_source):
        """Split a dataset, i.e. a list of Datum objects,
        into domain-specific groups stored in a dictionary.

        Args:
            data_source (list): a list of Datum objects.
        """
        output = defaultdict(list)

        for item in data_source:
            output[item.domain].append(item)

        return output


class DatasetWrapper(TorchDataset):
    """Torch dataset over a list of Datum objects.

    ``__getitem__`` returns an ``(img, label)`` pair.

    Args:
        data_source (list): a list of Datum objects.
        input_size: target size given to ``T.Resize`` for the
            no-augmentation pipeline.
        transform: a callable, or a list/tuple of callables, applied to the
            loaded PIL image. May be None.
        is_train (bool): whether the wrapper is used for training.
        return_img0 (bool): also compute the non-augmented tensor.
        k_tfm (int): number of times each transform is applied; forced to 1
            when ``is_train`` is False.
    """

    def __init__(self, data_source, input_size, transform=None, is_train=False,
                 return_img0=False, k_tfm=1):
        self.data_source = data_source
        self.transform = transform  # accept list (tuple) as input
        self.is_train = is_train
        # Augmenting an image K>1 times is only allowed during training
        self.k_tfm = k_tfm if is_train else 1
        self.return_img0 = return_img0

        if self.k_tfm > 1 and transform is None:
            raise ValueError(
                'Cannot augment the image {} times '
                'because transform is None'.format(self.k_tfm)
            )

        # Build transform that doesn't apply any data augmentation
        interp_mode = T.InterpolationMode.BICUBIC
        to_tensor = []
        to_tensor += [T.Resize(input_size, interpolation=interp_mode)]
        to_tensor += [T.ToTensor()]
        normalize = T.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711)
        )
        to_tensor += [normalize]
        self.to_tensor = T.Compose(to_tensor)

    def __len__(self):
        return len(self.data_source)

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(img, label)``.

        ``img`` is the result of the (first) transform. When no transform
        produced an image (``transform`` is None or an empty sequence), the
        non-augmented ``to_tensor`` pipeline is used instead.
        """
        item = self.data_source[idx]

        output = {
            'label': item.label,
            'domain': item.domain,
            'impath': item.impath
        }

        img0 = read_image(item.impath)

        if self.transform is not None:
            if isinstance(self.transform, (list, tuple)):
                # Results are stored as 'img', 'img2', 'img3', ...
                for i, tfm in enumerate(self.transform):
                    img = self._transform_image(tfm, img0)
                    keyname = 'img'
                    if (i + 1) > 1:
                        keyname += str(i + 1)
                    output[keyname] = img
            else:
                img = self._transform_image(self.transform, img0)
                output['img'] = img

        # NOTE(review): 'img0' (and 'img2', ...) are computed but not part
        # of the returned pair; kept so the return shape stays unchanged.
        if self.return_img0:
            output['img0'] = self.to_tensor(img0)

        # Without a usable transform there is no 'img' entry; fall back to
        # the plain resize/to-tensor/normalize pipeline instead of failing
        # with KeyError.
        if 'img' not in output:
            output['img'] = self.to_tensor(img0)

        return output['img'], output['label']

    def _transform_image(self, tfm, img0):
        """Apply ``tfm`` to ``img0`` ``k_tfm`` times.

        Returns a single result when ``k_tfm`` is 1, otherwise a list.
        """
        img_list = [tfm(img0) for _ in range(self.k_tfm)]

        img = img_list
        if len(img) == 1:
            img = img[0]

        return img


def build_data_loader(
    data_source=None,
    batch_size=64,
    input_size=224,
    tfm=None,
    is_train=True,
    shuffle=False,
    dataset_wrapper=None,
    num_workers=8
):
    """Build a ``torch.utils.data.DataLoader`` over ``data_source``.

    Args:
        data_source (list): a list of Datum objects.
        batch_size (int): batch size.
        input_size: forwarded to the dataset wrapper.
        tfm: transform(s) forwarded to the dataset wrapper.
        is_train (bool): forwarded to the dataset wrapper.
        shuffle (bool): whether the loader shuffles the data.
        dataset_wrapper: dataset class to use; defaults to
            ``DatasetWrapper``.
        num_workers (int): number of loader worker processes.

    Returns:
        torch.utils.data.DataLoader: a loader with at least one batch.
    """
    if dataset_wrapper is None:
        dataset_wrapper = DatasetWrapper

    # Build data loader
    data_loader = torch.utils.data.DataLoader(
        dataset_wrapper(data_source, input_size=input_size, transform=tfm, is_train=is_train),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        drop_last=False,
        pin_memory=(torch.cuda.is_available())
    )
    assert len(data_loader) > 0

    return data_loader


def get_preaugment():
    """Return the random crop (224) + horizontal flip applied before AugMix."""
    return transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
    ])


def augmix(image, preprocess, aug_list, severity=1):
    """Create one augmented view of ``image``.

    The image is first passed through ``get_preaugment()``. With an empty
    ``aug_list`` the preprocessed result is returned directly. Otherwise
    three augmentation chains (each 1 to 3 randomly chosen ops) are mixed
    with Dirichlet weights, and the mix is blended with the preprocessed
    image using a Beta-sampled coefficient.

    Args:
        image: input image (presumably a PIL image, since ``.copy()`` and
            torchvision crops are applied to it -- confirm with callers).
        preprocess: callable turning an image into a tensor.
        aug_list (list): augmentation ops called as ``op(img, severity)``.
        severity (int): severity passed to every op.

    Returns:
        The preprocessed (and, if ``aug_list`` is non-empty, mixed) tensor.
    """
    preaugment = get_preaugment()
    x_orig = preaugment(image)
    x_processed = preprocess(x_orig)
    if len(aug_list) == 0:
        return x_processed
    w = np.float32(np.random.dirichlet([1.0, 1.0, 1.0]))
    m = np.float32(np.random.beta(1.0, 1.0))

    mix = torch.zeros_like(x_processed)
    for i in range(3):
        x_aug = x_orig.copy()
        # randint's upper bound is exclusive: 1, 2 or 3 ops per chain.
        for _ in range(np.random.randint(1, 4)):
            x_aug = np.random.choice(aug_list)(x_aug, severity)
        mix += w[i] * preprocess(x_aug)
    mix = m * x_processed + (1 - m) * mix
    return mix


class AugMixAugmenter(object):
    """Callable producing one base view plus ``n_views`` AugMix views.

    Args:
        base_transform: transform applied before ``preprocess`` for the
            base view.
        preprocess: callable turning an image into a tensor.
        n_views (int): number of additional views to generate.
        augmix (bool): use the ``augmentations`` op list when True; with
            False the extra views only get the random crop/flip.
        severity (int): severity passed to the augmentation ops.
    """

    def __init__(self, base_transform, preprocess, n_views=2, augmix=False,
                 severity=1):
        self.base_transform = base_transform
        self.preprocess = preprocess
        self.n_views = n_views
        if augmix:
            self.aug_list = augmentations
        else:
            self.aug_list = []
        self.severity = severity

    def __call__(self, x):
        """Return ``[base_view] + views`` for input image ``x``."""
        image = self.preprocess(self.base_transform(x))
        # The extra views start from the raw input, not from the
        # base-transformed image.
        views = [augmix(x, self.preprocess, self.aug_list, self.severity)
                 for _ in range(self.n_views)]
        return [image] + views