# NOTE: Hugging Face Spaces page-status banner ("Spaces: Sleeping") removed
# from the scraped source; the Python module begins below.
| import os | |
| import os.path as osp | |
| import tarfile | |
| import zipfile | |
| from collections import defaultdict | |
| import gdown | |
| import json | |
| import torch | |
| from torch.utils.data import Dataset as TorchDataset | |
| import torchvision.transforms as T | |
| from PIL import Image | |
| import numpy as np | |
| import torchvision.transforms as transforms | |
| from datasets.augmix_ops import augmentations | |
def listdir_nohidden(path, sort=False):
    """List non-hidden items in a directory.

    Args:
        path (str): directory path.
        sort (bool): sort the items.

    Returns:
        list[str]: entry names whose basename does not start with ``.``.
    """
    visible = []
    for entry in os.listdir(path):
        # Skip dotfiles such as .DS_Store or .git.
        if not entry.startswith('.'):
            visible.append(entry)
    if sort:
        visible.sort()
    return visible
def read_json(fpath):
    """Deserialize and return the JSON object stored at ``fpath``."""
    with open(fpath, 'r') as fin:
        return json.load(fin)
def write_json(obj, fpath):
    """Serialize ``obj`` as JSON to ``fpath``, creating parent dirs as needed.

    Args:
        obj: any JSON-serializable object.
        fpath (str): destination file path.
    """
    dirname = osp.dirname(fpath)
    # BUGFIX: the original called os.makedirs(osp.dirname(fpath)) without a
    # guard for a bare filename (dirname == ''), which raises; exist_ok=True
    # also removes the check-then-create race of the original osp.exists test.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    with open(fpath, 'w') as f:
        json.dump(obj, f, indent=4, separators=(',', ': '))
def read_image(path):
    """Read image from path using ``PIL.Image``.

    Args:
        path (str): path to an image.

    Returns:
        PIL image (converted to RGB).

    Raises:
        IOError: if the file does not exist, or it still cannot be decoded
            after the retry budget is exhausted.
    """
    if not osp.exists(path):
        raise IOError('No file exists at {}'.format(path))
    # BUGFIX: the original `while True` busy-looped forever when the file was
    # persistently unreadable (e.g. truncated/corrupt, not just transient IO
    # pressure). Retry a bounded number of times, then surface the failure.
    num_retries = 10
    for _ in range(num_retries):
        try:
            img = Image.open(path).convert('RGB')
            return img
        except IOError:
            print(
                'Cannot read image from {}, '
                'probably due to heavy IO. Will re-try'.format(path)
            )
    raise IOError(
        'Failed to read image from {} after {} retries'.format(path, num_retries)
    )
def listdir_nohidden(path, sort=False):
    """List non-hidden directory entries, excluding names containing ``'sh'``.

    NOTE(review): this redefinition shadows the identically-named function
    defined earlier in the file. The extra ``'sh' not in f`` filter
    (presumably meant to skip shell scripts) also drops names like
    ``shape.png`` -- confirm that is intended.

    Args:
        path (str): directory path.
        sort (bool): sort the items.
    """
    entries = [
        name for name in os.listdir(path)
        if not name.startswith('.') and 'sh' not in name
    ]
    return sorted(entries) if sort else entries
class Datum:
    """Data instance which defines the basic attributes.

    Args:
        impath (str): image path.
        label (int): class label.
        domain (int): domain label.
        classname (str): class name.
    """

    def __init__(self, impath='', label=0, domain=-1, classname=''):
        assert isinstance(impath, str)
        assert isinstance(label, int)
        assert isinstance(domain, int)
        assert isinstance(classname, str)
        self._impath = impath
        self._label = label
        self._domain = domain
        self._classname = classname

    # BUGFIX: the accessors must be read-only properties. The rest of this
    # file uses attribute access (e.g. ``item.label`` in get_num_classes and
    # in DatasetWrapper.__getitem__); without @property those expressions
    # yield bound methods, and e.g. ``max(label_set) + 1`` fails.
    @property
    def impath(self):
        return self._impath

    @property
    def label(self):
        return self._label

    @property
    def domain(self):
        return self._domain

    @property
    def classname(self):
        return self._classname
class DatasetBase:
    """A unified dataset class for
    1) domain adaptation
    2) domain generalization
    3) semi-supervised learning
    """

    dataset_dir = ''  # the directory where the dataset is stored
    domains = []  # string names of all domains

    def __init__(self, train_x=None, train_u=None, val=None, test=None):
        self._train_x = train_x  # labeled training data
        self._train_u = train_u  # unlabeled training data (optional)
        self._val = val  # validation data (optional)
        self._test = test  # test data
        self._num_classes = self.get_num_classes(test)
        self._lab2cname, self._classnames = self.get_lab2cname(test)

    # Read-only properties: consumers access these as attributes, matching
    # the attribute-style access used for Datum fields elsewhere in this
    # file (the @property decorators were evidently lost in transcription).
    @property
    def train_x(self):
        return self._train_x

    @property
    def train_u(self):
        return self._train_u

    @property
    def val(self):
        return self._val

    @property
    def test(self):
        return self._test

    @property
    def lab2cname(self):
        return self._lab2cname

    @property
    def classnames(self):
        return self._classnames

    @property
    def num_classes(self):
        return self._num_classes

    def get_num_classes(self, data_source):
        """Count number of classes.

        Args:
            data_source (list): a list of Datum objects.

        Returns:
            int: ``max(label) + 1`` (assumes labels are 0-based).
        """
        label_set = set()
        for item in data_source:
            label_set.add(item.label)
        return max(label_set) + 1

    def get_lab2cname(self, data_source):
        """Get a label-to-classname mapping (dict).

        Args:
            data_source (list): a list of Datum objects.

        Returns:
            tuple: (mapping dict, classnames list ordered by label).
        """
        container = set()
        for item in data_source:
            container.add((item.label, item.classname))
        mapping = {label: classname for label, classname in container}
        labels = sorted(mapping.keys())
        classnames = [mapping[label] for label in labels]
        return mapping, classnames

    def check_input_domains(self, source_domains, target_domains):
        """Validate that both source and target domains are known."""
        self.is_input_domain_valid(source_domains)
        self.is_input_domain_valid(target_domains)

    def is_input_domain_valid(self, input_domains):
        """Raise ValueError for any domain not listed in ``self.domains``."""
        for domain in input_domains:
            if domain not in self.domains:
                raise ValueError(
                    'Input domain must belong to {}, '
                    'but got [{}]'.format(self.domains, domain)
                )

    def download_data(self, url, dst, from_gdrive=True):
        """Download an archive from ``url`` to ``dst`` and extract it in place.

        Args:
            url (str): download URL (Google Drive only, via gdown).
            dst (str): destination archive path.
            from_gdrive (bool): must be True; other sources are unsupported.
        """
        dst_dir = osp.dirname(dst)
        if dst_dir and not osp.exists(dst_dir):
            os.makedirs(dst_dir)
        if from_gdrive:
            gdown.download(url, dst, quiet=False)
        else:
            raise NotImplementedError
        print('Extracting file ...')
        # Try tar first; anything that fails to open/extract as tar is
        # assumed to be a zip archive. (Was a bare ``except:``, which also
        # swallowed KeyboardInterrupt/SystemExit.)
        try:
            with tarfile.open(dst) as tar:
                tar.extractall(path=osp.dirname(dst))
        except Exception:
            with zipfile.ZipFile(dst, 'r') as zip_ref:
                zip_ref.extractall(osp.dirname(dst))
        print('File extracted to {}'.format(osp.dirname(dst)))

    def split_dataset_by_label(self, data_source):
        """Split a dataset, i.e. a list of Datum objects,
        into class-specific groups stored in a dictionary.

        Args:
            data_source (list): a list of Datum objects.
        """
        output = defaultdict(list)
        for item in data_source:
            output[item.label].append(item)
        return output

    def split_dataset_by_domain(self, data_source):
        """Split a dataset, i.e. a list of Datum objects,
        into domain-specific groups stored in a dictionary.

        Args:
            data_source (list): a list of Datum objects.
        """
        output = defaultdict(list)
        for item in data_source:
            output[item.domain].append(item)
        return output
class DatasetWrapper(TorchDataset):
    """Wrap a list of Datum objects as a torch ``Dataset``.

    Args:
        data_source (list): a list of Datum objects.
        input_size (int or tuple): resize target for the no-augmentation path.
        transform: a single callable, or a list/tuple of callables each
            producing its own ``'img'``/``'img2'``/... view.
        is_train (bool): enables k_tfm > 1 repeated augmentation.
        return_img0 (bool): also compute an un-augmented tensor (stored under
            ``'img0'`` internally; note __getitem__ returns only (img, label)).
        k_tfm (int): apply each transform k times (training only).
    """

    def __init__(self, data_source, input_size, transform=None, is_train=False,
                 return_img0=False, k_tfm=1):
        self.data_source = data_source
        self.transform = transform  # accept list (tuple) as input
        self.is_train = is_train
        # Augmenting an image K>1 times is only allowed during training
        self.k_tfm = k_tfm if is_train else 1
        self.return_img0 = return_img0
        if self.k_tfm > 1 and transform is None:
            raise ValueError(
                'Cannot augment the image {} times '
                'because transform is None'.format(self.k_tfm)
            )
        # Build transform that doesn't apply any data augmentation
        # (CLIP-style normalization constants).
        interp_mode = T.InterpolationMode.BICUBIC
        to_tensor = []
        to_tensor += [T.Resize(input_size, interpolation=interp_mode)]
        to_tensor += [T.ToTensor()]
        normalize = T.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711)
        )
        to_tensor += [normalize]
        self.to_tensor = T.Compose(to_tensor)

    def __len__(self):
        return len(self.data_source)

    def __getitem__(self, idx):
        item = self.data_source[idx]
        output = {
            'label': item.label,
            'domain': item.domain,
            'impath': item.impath
        }
        img0 = read_image(item.impath)
        if self.transform is not None:
            if isinstance(self.transform, (list, tuple)):
                for i, tfm in enumerate(self.transform):
                    img = self._transform_image(tfm, img0)
                    keyname = 'img'
                    if (i + 1) > 1:
                        keyname += str(i + 1)
                    output[keyname] = img
            else:
                img = self._transform_image(self.transform, img0)
                output['img'] = img
        else:
            # BUGFIX: the original never set output['img'] when transform was
            # None, so the return statement below raised KeyError. Fall back
            # to the plain resize+normalize pipeline.
            output['img'] = self.to_tensor(img0)
        if self.return_img0:
            output['img0'] = self.to_tensor(img0)
        return output['img'], output['label']

    def _transform_image(self, tfm, img0):
        """Apply ``tfm`` to ``img0`` k_tfm times; return a single image when
        k_tfm == 1, otherwise the list of augmented images."""
        img_list = [tfm(img0) for _ in range(self.k_tfm)]
        if len(img_list) == 1:
            return img_list[0]
        return img_list
def build_data_loader(
    data_source=None,
    batch_size=64,
    input_size=224,
    tfm=None,
    is_train=True,
    shuffle=False,
    dataset_wrapper=None,
    num_workers=8
):
    """Build a ``torch.utils.data.DataLoader`` over ``data_source``.

    Args:
        data_source (list): a list of Datum objects.
        batch_size (int): batch size.
        input_size (int or tuple): passed to the dataset wrapper.
        tfm: transform(s) passed to the dataset wrapper.
        is_train (bool): passed to the dataset wrapper.
        shuffle (bool): shuffle the data each epoch.
        dataset_wrapper: Dataset class to wrap data_source
            (defaults to DatasetWrapper).
        num_workers (int): worker processes for data loading
            (was hard-coded to 8; default preserved for compatibility).

    Returns:
        torch.utils.data.DataLoader: a non-empty data loader.
    """
    if dataset_wrapper is None:
        dataset_wrapper = DatasetWrapper
    # Build data loader
    data_loader = torch.utils.data.DataLoader(
        dataset_wrapper(data_source, input_size=input_size, transform=tfm,
                        is_train=is_train),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        drop_last=False,
        # Pinned host memory speeds up host-to-device copies when CUDA exists.
        pin_memory=torch.cuda.is_available()
    )
    assert len(data_loader) > 0
    return data_loader
def get_preaugment():
    """Return the random crop-and-flip pre-augmentation applied before AugMix."""
    ops = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
    ]
    return transforms.Compose(ops)
def augmix(image, preprocess, aug_list, severity=1):
    """AugMix: blend three randomly-augmented chains of ``image`` with the
    clean (pre-augmented) view.

    Args:
        image: input PIL image.
        preprocess: callable mapping a PIL image to a tensor.
        aug_list (list): candidate augmentation ops; when empty, the clean
            preprocessed view is returned unchanged.
        severity (int): severity passed to each augmentation op.
    """
    x_orig = get_preaugment()(image)
    x_processed = preprocess(x_orig)
    if not aug_list:
        return x_processed
    # Dirichlet weights over the three chains; Beta coefficient mixes the
    # clean view with the augmented blend. (Random-call order is preserved.)
    w = np.float32(np.random.dirichlet([1.0, 1.0, 1.0]))
    m = np.float32(np.random.beta(1.0, 1.0))
    mix = torch.zeros_like(x_processed)
    for weight in w:
        x_aug = x_orig.copy()
        depth = np.random.randint(1, 4)
        for _ in range(depth):
            op = np.random.choice(aug_list)
            x_aug = op(x_aug, severity)
        mix += weight * preprocess(x_aug)
    return m * x_processed + (1 - m) * mix
class AugMixAugmenter(object):
    """Produce ``1 + n_views`` views of an image: one base view
    (base_transform then preprocess) plus ``n_views`` AugMix views.
    When ``augmix`` is False the AugMix op list is empty, so the extra views
    are only randomly pre-augmented.

    Args:
        base_transform: transform applied to build the base view.
        preprocess: callable mapping a PIL image to a tensor.
        n_views (int): number of additional AugMix views.
        augmix (bool): enable the AugMix augmentation op list.
        severity (int): augmentation severity for AugMix ops.
    """

    def __init__(self, base_transform, preprocess, n_views=2, augmix=False,
                 severity=1):
        self.base_transform = base_transform
        self.preprocess = preprocess
        self.n_views = n_views
        # An empty op list makes augmix() return the clean view only.
        self.aug_list = augmentations if augmix else []
        self.severity = severity

    def __call__(self, x):
        base_view = self.preprocess(self.base_transform(x))
        views = []
        for _ in range(self.n_views):
            views.append(augmix(x, self.preprocess, self.aug_list, self.severity))
        return [base_view] + views