| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Streaming images and labels from datasets created with dataset_tool.py.""" |
|
|
| import cv2 |
| import os |
| import numpy as np |
| import zipfile |
| import PIL.Image |
| import json |
| import torch |
| import dnnlib |
| from torchvision import transforms |
|
|
| from pdb import set_trace as st |
|
|
| from .shapenet import LMDBDataset_MV_Compressed, decompress_array |
|
|
| try: |
| import pyspng |
| except ImportError: |
| pyspng = None |
|
|
| |
|
|
|
|
| |
def init_dataset_kwargs(data,
                        class_name='datasets.eg3d_dataset.ImageFolderDataset',
                        reso_gt=128):
    """Assemble dataset constructor kwargs and probe the dataset once.

    The dataset named by ``class_name`` is instantiated with labels
    requested, no size limit and no x-flip. Its reported resolution,
    label availability and length are then written back into the kwargs.

    Args:
        data: Passed to the dataset constructor as ``path``.
        class_name: Dotted name of the dataset class to construct.
        reso_gt: Forwarded unchanged to the dataset constructor.

    Returns:
        A ``(kwargs, name)`` tuple: the completed ``dnnlib.EasyDict`` and
        the ``name`` attribute of the probed dataset.
    """
    kwargs = dnnlib.EasyDict(
        class_name=class_name,
        reso_gt=reso_gt,
        path=data,
        use_labels=True,
        max_size=None,
        xflip=False,
    )

    # Instantiate once so the remaining fields come from the real data.
    probe = dnnlib.util.construct_class_by_name(**kwargs)
    kwargs.resolution = probe.resolution
    kwargs.use_labels = probe.has_labels
    kwargs.max_size = len(probe)

    return kwargs, probe.name
| |
| |
|
|
|
|
class Dataset(torch.utils.data.Dataset):
    """Abstract base for datasets yielding an image, a matte and a label.

    Subclasses provide the raw data through ``_load_raw_image``,
    ``_load_raw_matte`` and ``_load_raw_labels``. This class handles
    sub-sampling (``max_size``), optional horizontal-flip augmentation
    (``xflip``), label caching and the per-sample preprocessing done in
    ``__getitem__``.
    """

    def __init__(
        self,
        name,
        raw_shape,
        reso_gt=128,
        max_size=None,
        use_labels=False,
        xflip=False,
        random_seed=0,
    ):
        """Set up index tables and preprocessing pipelines.

        Args:
            name: Value returned by the ``name`` property.
            raw_shape: Sequence whose first entry is the number of raw
                items and whose remaining entries are the per-image shape
                (compared against CHW arrays in ``__getitem__``).
            reso_gt: Side length used when resizing the matte and the
                ground-truth image.
            max_size: If set and smaller than the raw item count, keep a
                seeded random subset of that many items (sorted by index).
            use_labels: Whether ``_load_raw_labels`` should be consulted.
            xflip: If true, every kept item appears twice, the second
                time flipped along its last axis.
            random_seed: Seed for the ``max_size`` subset selection.
        """
        self._name = name
        self._raw_shape = list(raw_shape)
        self._use_labels = use_labels
        self._raw_labels = None
        self._raw_labels_std = None
        self._label_shape = None

        self.reso_gt = reso_gt
        self.reso_encoder = 224

        # Apply max_size: seeded shuffle, truncate, then restore order.
        self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
        if (max_size is not None) and (self._raw_idx.size > max_size):
            np.random.RandomState(random_seed).shuffle(self._raw_idx)
            self._raw_idx = np.sort(self._raw_idx[:max_size])

        # Apply xflip: second half of the index table is the flipped copy.
        self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
        if xflip:
            self._raw_idx = np.tile(self._raw_idx, 2)
            self._xflip = np.concatenate(
                [self._xflip, np.ones_like(self._xflip)])

        # Maps an HWC float image to a normalized tensor at encoder size.
        self.normalize_for_encoder_input = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            transforms.Resize(size=(self.reso_encoder, self.reso_encoder),
                              antialias=True),
        ])

        # Same pipeline at ground-truth size; not used by __getitem__ here.
        self.normalize_for_gt = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            transforms.Resize(size=(self.reso_gt, self.reso_gt),
                              antialias=True),
        ])

    def _get_raw_labels(self):
        """Return the cached label array, loading and validating it once.

        When labels are disabled or the loader returns ``None``, a
        zero-width float32 array with one row per raw item is used.
        """
        if self._raw_labels is None:
            self._raw_labels = self._load_raw_labels(
            ) if self._use_labels else None
            if self._raw_labels is None:
                self._raw_labels = np.zeros([self._raw_shape[0], 0],
                                            dtype=np.float32)
            assert isinstance(self._raw_labels, np.ndarray)
            assert self._raw_labels.dtype in [np.float32, np.int64]
            if self._raw_labels.dtype == np.int64:
                assert self._raw_labels.ndim == 1
                assert np.all(self._raw_labels >= 0)
            self._raw_labels_std = self._raw_labels.std(0)
        return self._raw_labels

    def close(self):
        """Release resources held by the dataset; no-op in the base class."""
        pass

    def _load_raw_image(self, raw_idx):
        """Return the image for ``raw_idx`` as a uint8 CHW array (abstract)."""
        raise NotImplementedError

    def _load_raw_matte(self, raw_idx):
        """Return the matte for ``raw_idx`` as a CHW array in [0, 1] (abstract)."""
        raise NotImplementedError

    def _load_raw_labels(self):
        """Return all labels as an ndarray, or ``None`` if absent (abstract)."""
        raise NotImplementedError

    def __getstate__(self):
        """Pickle without the label cache; it is reloaded on demand."""
        return dict(self.__dict__, _raw_labels=None)

    def __del__(self):
        # Best-effort cleanup: a destructor must never propagate errors.
        try:
            self.close()
        except Exception:
            pass

    def __len__(self):
        return self._raw_idx.size

    def __getitem__(self, idx):
        """Load, optionally flip, and preprocess one sample.

        Returns:
            dict with keys
            ``c`` (label array),
            ``img_to_encoder`` (output of ``normalize_for_encoder_input``),
            ``img_sr`` (first three channels, CHW, scaled to [-1, 1]),
            ``img`` (as ``img_sr`` but resized to ``reso_gt``),
            ``depth`` and ``depth_mask`` (both the resized matte).
        """
        raw_idx = self._raw_idx[idx]
        flip = bool(self._xflip[idx])

        matte = self._load_raw_matte(raw_idx)
        assert isinstance(matte, np.ndarray)
        assert list(matte.shape)[1:] == self.image_shape[1:]
        if flip:
            # The matte is CHW like the image; the former ``ndim == 1``
            # check could never hold and broke every flipped sample.
            assert matte.ndim == 3
            matte = matte[:, :, ::-1]

        # HWC float32 copy (also normalizes the flipped view's strides).
        matte = np.transpose(matte, (1, 2, 0)).astype(np.float32)
        matte = cv2.resize(matte, (self.reso_gt, self.reso_gt),
                           interpolation=cv2.INTER_NEAREST)
        assert matte.min() >= 0 and matte.max(
        ) <= 1, f'{matte.min(), matte.max()}'

        if matte.ndim == 3:
            matte = matte[..., 0]

        image = self._load_raw_image(raw_idx)
        assert isinstance(image, np.ndarray)
        assert list(image.shape) == self.image_shape
        assert image.dtype == np.uint8
        if flip:
            assert image.ndim == 3
            image = image[:, :, ::-1]

        # HWC float image in [0, 1]; the division yields a fresh array.
        image = np.transpose(image, (1, 2, 0)).astype(np.float32) / 255

        image_sr = torch.from_numpy(image)[..., :3].permute(
            2, 0, 1) * 2 - 1
        image_to_encoder = self.normalize_for_encoder_input(image)

        image_gt = cv2.resize(image, (self.reso_gt, self.reso_gt),
                              interpolation=cv2.INTER_AREA)
        image_gt = torch.from_numpy(image_gt)[..., :3].permute(
            2, 0, 1) * 2 - 1

        return dict(
            c=self.get_label(idx),
            img_to_encoder=image_to_encoder,
            img_sr=image_sr,
            img=image_gt,
            depth=matte,
            depth_mask=matte,
        )

    def get_label(self, idx):
        """Return a copy of the label for ``idx``; int64 labels become one-hot."""
        label = self._get_raw_labels()[self._raw_idx[idx]]
        if label.dtype == np.int64:
            onehot = np.zeros(self.label_shape, dtype=np.float32)
            onehot[label] = 1
            label = onehot
        return label.copy()

    def get_details(self, idx):
        """Return an EasyDict with ``raw_idx``, ``xflip`` and ``raw_label``."""
        d = dnnlib.EasyDict()
        d.raw_idx = int(self._raw_idx[idx])
        d.xflip = (int(self._xflip[idx]) != 0)
        d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
        return d

    def get_label_std(self):
        """Return the standard deviation of the raw labels along axis 0.

        Labels are loaded first if necessary, so this is safe to call
        before any other label accessor.
        """
        labels = self._get_raw_labels()
        if getattr(self, '_raw_labels_std', None) is None:
            self._raw_labels_std = labels.std(0)
        return self._raw_labels_std

    @property
    def name(self):
        return self._name

    @property
    def image_shape(self):
        """Per-image shape: ``raw_shape`` without the leading item count."""
        return list(self._raw_shape[1:])

    @property
    def num_channels(self):
        assert len(self.image_shape) == 3
        return self.image_shape[0]

    @property
    def resolution(self):
        """Side length of the (square) raw images."""
        assert len(self.image_shape) == 3
        assert self.image_shape[1] == self.image_shape[2]
        return self.image_shape[1]

    @property
    def label_shape(self):
        """Shape of one label; ``[max_class + 1]`` for int64 class labels."""
        if self._label_shape is None:
            raw_labels = self._get_raw_labels()
            if raw_labels.dtype == np.int64:
                self._label_shape = [int(np.max(raw_labels)) + 1]
            else:
                self._label_shape = raw_labels.shape[1:]
        return list(self._label_shape)

    @property
    def label_dim(self):
        assert len(self.label_shape) == 1
        return self.label_shape[0]

    @property
    def has_labels(self):
        return any(x != 0 for x in self.label_shape)

    @property
    def has_onehot_labels(self):
        return self._get_raw_labels().dtype == np.int64
|
|
|
|
| |
|
|
|
|
class ImageFolderDataset(Dataset):
    """Dataset backed by a directory or zip archive of image files.

    Mattes are read from a sibling directory obtained by replacing
    ``'unzipped_ffhq_512'`` with ``'ffhq_512_seg'`` in ``path``, using the
    same relative file names as the images. Mattes are only available for
    directory datasets.
    """

    def __init__(
        self,
        path,
        resolution=None,
        reso_gt=128,
        **super_kwargs,
    ):
        """Index the image files under ``path``.

        Args:
            path: Directory or ``.zip`` file containing the images (and
                optionally ``dataset.json`` with labels).
            resolution: If given, the height and width of the first image
                must equal it.
            reso_gt: Forwarded to ``Dataset.__init__``.
            **super_kwargs: Forwarded to ``Dataset.__init__``.

        Raises:
            IOError: ``path`` is neither a directory nor a zip, contains
                no image files, or mismatches ``resolution``.
        """
        self._path = path
        self._matte_path = path.replace('unzipped_ffhq_512',
                                        'ffhq_512_seg')
        self._zipfile = None

        if os.path.isdir(self._path):
            self._type = 'dir'
            self._all_fnames = {
                os.path.relpath(os.path.join(root, fname), start=self._path)
                for root, _dirs, files in os.walk(self._path)
                for fname in files
            }
        elif self._file_ext(self._path) == '.zip':
            self._type = 'zip'
            self._all_fnames = set(self._get_zipfile().namelist())
        else:
            raise IOError('Path must point to a directory or zip')

        PIL.Image.init()
        self._image_fnames = sorted(
            fname for fname in self._all_fnames
            if self._file_ext(fname) in PIL.Image.EXTENSION)
        if not self._image_fnames:
            raise IOError('No image files found in the specified path')

        name = os.path.splitext(os.path.basename(self._path))[0]
        # The first image defines the per-item shape for the whole set.
        raw_shape = [len(self._image_fnames)] + list(
            self._load_raw_image(0).shape)
        if resolution is not None and (raw_shape[2] != resolution
                                       or raw_shape[3] != resolution):
            raise IOError('Image files do not match the specified resolution')
        super().__init__(name=name,
                         raw_shape=raw_shape,
                         reso_gt=reso_gt,
                         **super_kwargs)

    @staticmethod
    def _file_ext(fname):
        """Return the lower-cased extension of ``fname`` including the dot."""
        return os.path.splitext(fname)[1].lower()

    def _get_zipfile(self):
        """Open the archive on first use and return the cached handle."""
        assert self._type == 'zip'
        if self._zipfile is None:
            self._zipfile = zipfile.ZipFile(self._path)
        return self._zipfile

    def _open_file(self, fname):
        """Open ``fname`` (relative to the dataset root) for binary reading."""
        if self._type == 'dir':
            return open(os.path.join(self._path, fname), 'rb')
        if self._type == 'zip':
            return self._get_zipfile().open(fname, 'r')
        return None

    def _open_matte_file(self, fname):
        """Open the matte matching image ``fname`` for binary reading.

        Raises:
            NotImplementedError: the dataset is not directory-backed.
                Previously this returned ``None`` and the caller failed
                with an unrelated error inside its ``with`` statement.
        """
        if self._type == 'dir':
            return open(os.path.join(self._matte_path, fname), 'rb')
        raise NotImplementedError(
            'Matte files can only be read from a directory dataset, '
            f'not from type {self._type!r}')

    def close(self):
        """Close the zip handle, if one was opened."""
        try:
            if self._zipfile is not None:
                self._zipfile.close()
        finally:
            self._zipfile = None

    def __getstate__(self):
        """Pickle without the open zip handle; it is reopened on demand."""
        return dict(super().__getstate__(), _zipfile=None)

    def _decode(self, fname, opener):
        """Decode ``fname`` obtained via ``opener`` into an ndarray.

        Uses pyspng for ``.png`` files when it is installed, PIL otherwise.
        """
        with opener(fname) as f:
            if pyspng is not None and self._file_ext(fname) == '.png':
                return pyspng.load(f.read())
            # np.array forces the full decode before the file is closed.
            return np.array(PIL.Image.open(f))

    @staticmethod
    def _to_chw(image):
        """Add a channel axis to 2-D input, then reorder HWC -> CHW."""
        if image.ndim == 2:
            image = image[:, :, np.newaxis]
        return image.transpose(2, 0, 1)

    def _load_raw_image(self, raw_idx):
        """Return image ``raw_idx`` as a CHW array."""
        fname = self._image_fnames[raw_idx]
        return self._to_chw(self._decode(fname, self._open_file))

    def _load_raw_matte(self, raw_idx):
        """Return the matte for ``raw_idx`` binarized to {0, 1} float32, CHW."""
        fname = self._image_fnames[raw_idx]
        matte = self._decode(fname, self._open_matte_file)
        matte = (matte > 0).astype(np.float32)
        return self._to_chw(matte)

    def _load_raw_matte_orig(self, raw_idx):
        """Return the matte for ``raw_idx`` with its stored values, CHW.

        A leftover ``pdb`` breakpoint used to sit in this method and would
        block any non-interactive caller; it has been removed.
        """
        fname = self._image_fnames[raw_idx]
        return self._to_chw(self._decode(fname, self._open_matte_file))

    def _load_raw_labels(self):
        """Load labels from ``dataset.json`` if the dataset has one.

        Returns:
            ``None`` when the file or its ``'labels'`` entry is missing.
            Otherwise an int64 array (1-D labels) or a float32 array
            (2-D labels), in the order the entries appear in the JSON.
            NOTE(review): the order is taken from the JSON, not from
            ``_image_fnames`` - confirm the two agree for your data.
        """
        fname = 'dataset.json'
        if fname not in self._all_fnames:
            return None
        with self._open_file(fname) as f:
            labels = json.load(f)['labels']
        if labels is None:
            return None
        # Entries are (fname, label) pairs; keep one label per file name.
        labels = np.array(list(dict(labels).values()))
        labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
        self._raw_labels = labels
        return labels
|
|
|
|
| |
|
|
|
|
| |
|
|
| |
| |
|
|
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
class ImageFolderDatasetLMDB(ImageFolderDataset):
    """Variant of ``ImageFolderDataset`` that returns unprocessed samples.

    ``__getitem__`` skips the resizing and normalization of the parent
    class and returns the raw uint8 image together with its matte.
    """

    def __init__(self, path, resolution=None, reso_gt=128, **super_kwargs):
        super().__init__(path, resolution, reso_gt, **super_kwargs)

    def __getitem__(self, idx):
        """Load one sample without resizing or normalization.

        Returns:
            dict with keys ``c`` (label array), ``img`` (uint8 CHW image,
            horizontally flipped when the x-flip flag is set for ``idx``)
            and ``depth_mask`` (first matte channel, float32, HW).
        """
        raw_idx = self._raw_idx[idx]
        flip = bool(self._xflip[idx])

        matte = self._load_raw_matte(raw_idx)
        assert isinstance(matte, np.ndarray)
        assert list(matte.shape)[1:] == self.image_shape[1:]
        if flip:
            # The matte is CHW; the former ``ndim == 1`` check could never
            # hold and made every flipped sample fail.
            assert matte.ndim == 3
            matte = matte[:, :, ::-1]

        # HWC float32 copy; astype also normalizes the flipped strides.
        matte = np.transpose(matte, (1, 2, 0)).astype(np.float32)
        assert matte.min() >= 0 and matte.max(
        ) <= 1, f'{matte.min(), matte.max()}'

        if matte.ndim == 3:
            matte = matte[..., 0]

        image = self._load_raw_image(raw_idx)
        assert isinstance(image, np.ndarray)
        assert list(image.shape) == self.image_shape
        assert image.dtype == np.uint8
        if flip:
            assert image.ndim == 3
            # A reversed view has negative strides, which torch cannot
            # wrap; materialize it so the sample can be collated.
            image = np.ascontiguousarray(image[:, :, ::-1])

        return dict(
            c=self.get_label(idx),
            img=image,
            depth_mask=matte,
        )
|
|
class LMDBDataset_MV_Compressed_eg3d(LMDBDataset_MV_Compressed):
    """LMDB-backed dataset reading ``img``, ``depth_mask`` and ``c`` records.

    Each sample is stored under the keys ``'<idx>-img'``,
    ``'<idx>-depth_mask'`` and ``'<idx>-c'``.
    """

    def __init__(self,
                 lmdb_path,
                 reso,
                 reso_encoder,
                 imgnet_normalize=True,
                 **kwargs):
        """Initialise the parent dataset and install the [-1, 1] pipelines.

        All arguments are forwarded unchanged to the parent constructor.
        """
        super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize,
                         **kwargs)

        encoder_size = (self.reso_encoder, self.reso_encoder)
        gt_size = (self.reso, self.reso)

        # Encoder input: tensor -> normalize with 0.5/0.5 -> encoder size.
        encoder_steps = [
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            transforms.Resize(size=encoder_size, antialias=True),
        ]
        self.normalize_for_encoder_input = transforms.Compose(encoder_steps)

        # Same steps at ground-truth size; not used by __getitem__ here.
        gt_steps = [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            transforms.Resize(size=gt_size, antialias=True),
        ]
        self.normalize_for_gt = transforms.Compose(gt_steps)

    def __getitem__(self, idx):
        """Read sample ``idx`` from LMDB and preprocess it.

        Returns:
            dict with keys ``img_to_encoder``, ``img_sr``, ``img``, ``c``,
            ``depth`` and ``depth_mask`` (the last two are the same array).
        """
        img_key = f'{idx}-img'.encode('utf-8')
        depth_key = f'{idx}-depth_mask'.encode('utf-8')
        c_key = f'{idx}-c'.encode('utf-8')

        # Every read from the transaction happens inside this block.
        with self.env.begin(write=False) as txn:
            image = self.load_image_fn(txn.get(img_key))
            # The stored mask is decoded as a 64x64 float32 array.
            depth = decompress_array(txn.get(depth_key), (64, 64),
                                     np.float32)
            # The label record is decoded as 25 float32 values.
            c = decompress_array(txn.get(c_key), (25, ), np.float32)

        target_size = (self.reso, self.reso)
        depth = cv2.resize(depth, target_size,
                           interpolation=cv2.INTER_NEAREST)

        # CHW -> HWC float image scaled to [0, 1].
        image = np.transpose(image, (1, 2, 0)).astype(np.float32) / 255

        image_sr = torch.from_numpy(image)[..., :3].permute(2, 0, 1) * 2 - 1
        image_to_encoder = self.normalize_for_encoder_input(image)

        resized = cv2.resize(image, target_size,
                             interpolation=cv2.INTER_AREA)
        image_gt = torch.from_numpy(resized)[..., :3].permute(2, 0, 1) * 2 - 1

        sample = {
            'img_to_encoder': image_to_encoder,
            'img_sr': image_sr,
            'img': image_gt,
            'c': c,
            'depth': depth,
            'depth_mask': depth,
        }
        return sample
|
|