| |
| |
| import os |
| import gzip |
| import numpy as np |
| import io |
| from PIL import Image |
| from torch.utils.data import Dataset |
|
|
# Optional import: if this name cannot be imported from PIL, the flag below
# tells the image-loading code to fall back to a broader exception type.
try:
    from PIL import UnidentifiedImageError
except ImportError:
    unidentified_error_available = False
else:
    unidentified_error_available = True
|
|
class DiskTarDataset(Dataset):
    """Map-style dataset that concatenates the members of several tar archives.

    Each archive listed in ``tarfile_path`` is wrapped in a ``_TarDataset``.
    A global sample index is resolved to an (archive, position-in-archive)
    pair through the cumulative archive sizes.  The position of the archive
    in the list is used as the label of every sample it contains
    (presumably one archive per synset/class -- confirm against the code
    that produced ``tar_files.npy``).

    Attributes:
        chunk_datasets: list with one ``_TarDataset`` per archive.
        dataset_lens: int32 array, number of samples in each archive.
        dataset_cumsums: int64 array, running total of ``dataset_lens``.
        num_samples: total number of samples (Python int).
        labels: int64 array mapping every global index to its archive index.
    """

    def __init__(self,
                 tarfile_path='dataset/imagenet/ImageNet-21k/metadata/tar_files.npy',
                 tar_index_dir='dataset/imagenet/ImageNet-21k/metadata/tarindex_npy',
                 preload=False,
                 num_synsets="all"):
        """Build one ``_TarDataset`` per archive and the global index tables.

        Args:
            tarfile_path: ``.npy`` file holding the array of tar archive paths.
            tar_index_dir: directory with the per-archive ``*_names.npy`` and
                ``*_offsets.npy`` index files.
            preload (bool): forwarded to ``_TarDataset``.  When True every
                archive is memory-mapped right away; when False the mapping
                is deferred until the first sample of that archive is read.
            num_synsets (int or "all"): when an int, only the first
                ``num_synsets`` archives are used (handy for debugging);
                any non-int value keeps all archives.

        Raises:
            ValueError: if ``num_synsets`` exceeds the number of archives.
        """
        tar_files = np.load(tarfile_path)

        if isinstance(num_synsets, int):
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if num_synsets > len(tar_files):
                raise ValueError(
                    f"num_synsets={num_synsets} exceeds the {len(tar_files)} "
                    f"available tar files"
                )
            tar_files = tar_files[:num_synsets]

        self.chunk_datasets = [
            _TarDataset(tar_file, tar_index_dir, preload=preload)
            for tar_file in tar_files
        ]
        self.dataset_lens = np.array(
            [len(chunk) for chunk in self.chunk_datasets]
        ).astype(np.int32)
        # Accumulate in int64 so the totals cannot wrap around even though the
        # per-archive lengths are stored as int32.
        self.dataset_cumsums = np.cumsum(self.dataset_lens, dtype=np.int64)
        self.num_samples = int(self.dataset_lens.sum(dtype=np.int64))
        # labels[i] is the index of the archive that holds global sample i.
        self.labels = np.repeat(
            np.arange(len(self.dataset_lens), dtype=np.int64),
            self.dataset_lens,
        )

    def __len__(self):
        """Return the total number of samples over all archives."""
        return self.num_samples

    def _locate_sample(self, index):
        """Translate a global sample index into (archive index, local index).

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``.
        """
        if not 0 <= index < len(self):
            raise IndexError(
                f"index {index} is out of range for {len(self)} samples"
            )
        # side='right' returns the first archive whose cumulative end is
        # strictly greater than ``index``.  That is the owning archive even
        # when empty archives produce repeated cumulative values.
        d_index = np.searchsorted(self.dataset_cumsums, index, side='right')
        if d_index == 0:
            local_index = index
        else:
            local_index = index - self.dataset_cumsums[d_index - 1]
        return d_index, local_index

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A tuple ``(image, label, index)`` where ``image`` is an RGB PIL
            image, ``label`` is the index of the archive the sample came
            from and ``index`` is the global index that was requested.  If
            the image cannot be decoded, a uniform gray 224x224 placeholder
            is returned and ``label`` is -1.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``.
        """
        d_index, local_index = self._locate_sample(index)
        data_bytes = self.chunk_datasets[d_index][local_index]

        # Without UnidentifiedImageError there is no narrow exception type to
        # catch, so fall back to Exception.
        exception_to_catch = UnidentifiedImageError if unidentified_error_available else Exception
        try:
            image = Image.open(data_bytes).convert("RGB")
        except exception_to_catch:
            image = Image.fromarray(np.full((224, 224, 3), 128, dtype=np.uint8))
            d_index = -1  # marks the sample as undecodable for the caller

        return image, d_index, index

    def __repr__(self):
        return f"DiskTarDataset(subdatasets={len(self.dataset_lens)},samples={self.num_samples})"
|
|
class _TarDataset(object):
    """Random access to the members of a single tar archive.

    The archive is memory-mapped and sliced using a pre-computed index
    stored as ``<basename>_names.npy`` / ``<basename>_offsets.npy`` in
    ``npy_index_dir``.  Offsets are expressed in 512-byte tar blocks and
    member ``i`` spans the blocks ``offsets[i]`` up to ``offsets[i + 1]``
    (so the offsets array is expected to have one more entry than the names
    array -- confirm against the index builder).
    """

    # Size in bytes of one tar block; headers and payloads are aligned to it.
    _BLOCK_SIZE = 512

    def __init__(self, filename, npy_index_dir, preload=False):
        """Read the index and optionally map the archive immediately.

        Args:
            filename: path of the tar archive.
            npy_index_dir: directory that contains the ``.npy`` index files.
            preload: when True the archive is memory-mapped and the offsets
                are kept right away; when False both are set up lazily on
                the first ``__getitem__`` call.
        """
        self.filename = filename
        self.names = []
        self.offsets = []
        self.npy_index_dir = npy_index_dir
        names, offsets = self.load_index()

        self.num_samples = len(names)
        if preload:
            self.data = np.memmap(filename, mode='r', dtype='uint8')
            self.offsets = offsets
        else:
            self.data = None

    def __len__(self):
        """Return the number of members listed in the index."""
        return self.num_samples

    def load_index(self):
        """Load and return the ``(names, offsets)`` index arrays of the archive."""
        basename = os.path.basename(self.filename)
        basename = os.path.splitext(basename)[0]
        names = np.load(os.path.join(self.npy_index_dir, f"{basename}_names.npy"))
        offsets = np.load(os.path.join(self.npy_index_dir, f"{basename}_offsets.npy"))
        return names, offsets

    def __getitem__(self, idx):
        """Return the payload of member ``idx`` as an ``io.BytesIO``.

        The leading tar header block(s) are stripped and gzip-compressed
        payloads are decompressed.  For uncompressed payloads the tar
        padding up to the next block boundary is still part of the result.

        Raises:
            IndexError: if ``idx`` is outside ``[0, len(self))``.
        """
        if not 0 <= idx < self.num_samples:
            raise IndexError(
                f"index {idx} is out of range for {self.num_samples} members"
            )

        if self.data is None:
            # Deferred setup for preload=False.
            self.data = np.memmap(self.filename, mode='r', dtype='uint8')
            _, self.offsets = self.load_index()

        # Convert to Python ints before scaling so a narrow offsets dtype
        # cannot overflow when multiplied by the block size.
        block = self._BLOCK_SIZE
        start = int(self.offsets[idx]) * block
        stop = int(self.offsets[idx + 1]) * block
        data = self.data[start:stop]

        # The comparison must be bytes-to-bytes; comparing against a str
        # literal is always False on Python 3.
        if data[:13].tobytes() == b'././@LongLink':
            # GNU long-name entry: skip the LongLink header, the block that
            # carries the long name, and the real member header.
            # NOTE(review): assumes the long name fits in a single block.
            data = data[3 * block:]
        else:
            # Regular entry: skip only the member header.
            data = data[block:]

        payload = data.tobytes()
        if payload[:2] == b'\x1f\x8b':  # gzip magic number
            with gzip.GzipFile(fileobj=io.BytesIO(payload), mode='rb') as gz:
                payload = gz.read()
        return io.BytesIO(payload)