| import numbers |
| import os |
| import queue as Queue |
| import threading |
|
|
| import mxnet as mx |
| import numpy as np |
| import torch |
| from torch.utils.data import DataLoader, Dataset |
| from torchvision import transforms |
|
|
|
|
class BackgroundGenerator(threading.Thread):
    """Prefetch items from *generator* on a daemon thread.

    The thread starts as soon as the object is constructed and keeps at most
    *max_prefetch* items buffered in ``self.queue``. Iterating this object
    yields the generator's items in order. If the wrapped generator raises,
    the exception is re-raised in the consuming thread instead of leaving the
    consumer blocked forever on an empty queue.

    Args:
        generator: Any iterable; it is consumed on the background thread.
        local_rank: CUDA device index passed to ``torch.cuda.set_device`` on
            the background thread. ``None`` skips device selection.
        max_prefetch: Capacity of the prefetch queue.
    """

    # Private end-of-stream marker; unlike ``None`` it cannot collide with a
    # real item produced by the generator.
    _END = object()

    def __init__(self, generator, local_rank, max_prefetch=6):
        super(BackgroundGenerator, self).__init__()
        self.queue = Queue.Queue(max_prefetch)
        self.generator = generator
        self.local_rank = local_rank
        self.daemon = True
        self._error = None  # exception raised on the producer side, if any
        self._exhausted = False  # True once the end marker has been consumed
        # Both fields above must exist before the thread starts running.
        self.start()

    def run(self):
        """Producer loop: push every item, then always push the end marker."""
        try:
            # Done here so the device is selected on the producer thread itself.
            if self.local_rank is not None:
                torch.cuda.set_device(self.local_rank)
            for item in self.generator:
                self.queue.put(item)
        except Exception as exc:
            # Hand the failure to the consumer; it is re-raised in next().
            self._error = exc
        finally:
            # Without this the consumer would block forever after a failure.
            self.queue.put(self._END)

    def next(self):
        """Return the next prefetched item.

        Raises:
            StopIteration: when the generator is exhausted (also on every call
                after that, instead of blocking).
            Exception: whatever the wrapped generator raised, re-raised once.
        """
        if self._exhausted:
            raise StopIteration
        item = self.queue.get()
        if item is self._END:
            self._exhausted = True
            if self._error is not None:
                raise self._error
            raise StopIteration
        return item

    def __next__(self):
        return self.next()

    def __iter__(self):
        return self
|
|
|
|
class DataLoaderX(DataLoader):
    """``DataLoader`` that copies the next batch to a CUDA device ahead of use.

    Batches from the parent loader are pulled through a
    ``BackgroundGenerator``. Each element of the upcoming batch is copied to
    device ``local_rank`` on a dedicated CUDA stream while the current batch
    is being consumed.

    Args:
        local_rank: CUDA device index that batches are copied to.
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``.
    """

    def __init__(self, local_rank, **kwargs):
        super(DataLoaderX, self).__init__(**kwargs)
        self.stream = torch.cuda.Stream(local_rank)
        self.local_rank = local_rank

    def __iter__(self):
        self.iter = super(DataLoaderX, self).__iter__()
        self.iter = BackgroundGenerator(self.iter, self.local_rank)
        self.preload()
        return self

    def preload(self):
        """Fetch the next batch into ``self.batch`` and queue its device copy.

        Sets ``self.batch`` to ``None`` once the underlying iterator is
        exhausted. Every batch element has ``.to(...)`` called on it, so this
        assumes the batch is a sequence of tensors.
        """
        self.batch = next(self.iter, None)
        if self.batch is None:
            return None
        with torch.cuda.stream(self.stream):
            # Copies are issued on the side stream so they can overlap with
            # work on the current stream.
            self.batch = [
                tensor.to(device=self.local_rank, non_blocking=True)
                for tensor in self.batch
            ]

    def __next__(self):
        # Make the current stream wait for the side-stream copies to finish.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        if batch is None:
            raise StopIteration
        # These tensors were allocated on self.stream but are used on the
        # current stream. Record that use so the caching allocator does not
        # recycle their memory for the next prefetch while they are in use.
        current = torch.cuda.current_stream()
        for tensor in batch:
            tensor.record_stream(current)
        self.preload()
        return batch
|
|
|
|
class MXFaceDataset(Dataset):
    """Dataset reading images and labels from an MXNet ``train.rec``/``train.idx`` pair.

    Each item is ``(image, label)``:

    - ``image`` is the decoded record passed through ``ToPILImage``, a random
      horizontal flip, ``ToTensor`` and ``Normalize(mean=0.5, std=0.5)``;
    - ``label`` is a ``torch.long`` tensor holding the record's (first) label.

    Attributes:
        header0: ``(int(label[0]), int(label[1]))`` taken from record 0 when its
            ``flag > 0``, otherwise ``None``.
        imgidx: numpy array of record indices served by this dataset.
    """

    def __init__(self, root_dir, local_rank):
        super().__init__()
        self.transform = transforms.Compose(
            [transforms.ToPILImage(),
             transforms.RandomHorizontalFlip(),
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
             ])
        self.root_dir = root_dir
        self.local_rank = local_rank
        path_imgrec = os.path.join(root_dir, 'train.rec')
        path_imgidx = os.path.join(root_dir, 'train.idx')
        self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
        header, _ = mx.recordio.unpack(self.imgrec.read_idx(0))
        self.header0, self.imgidx = self._build_index(header, self.imgrec.keys)

    @staticmethod
    def _build_index(header, keys):
        """Return ``(header0, imgidx)`` given the unpacked header of record 0.

        If ``header.flag > 0``, record 0 is treated as metadata. Images are
        then records ``1 .. int(header.label[0]) - 1``, and ``header0`` is the
        integer pair ``(label[0], label[1])``. Otherwise every key in *keys* is
        an image record and ``header0`` is ``None``.
        """
        if header.flag > 0:
            end = int(header.label[0])
            return (end, int(header.label[1])), np.arange(1, end)
        return None, np.array(list(keys))

    def __getitem__(self, index):
        """Decode record ``self.imgidx[index]`` and return ``(image, label)``."""
        header, img = mx.recordio.unpack(self.imgrec.read_idx(self.imgidx[index]))
        label = header.label
        # Multi-valued labels: only the first entry is used.
        if not isinstance(label, numbers.Number):
            label = label[0]
        label = torch.tensor(label, dtype=torch.long)
        sample = mx.image.imdecode(img).asnumpy()
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label

    def __len__(self):
        return len(self.imgidx)
|
|
|
|
class SyntheticDataset(Dataset):
    """Fixed-size dataset that serves one pre-generated random image.

    Every index returns the same normalised ``(3, 112, 112)`` float tensor
    together with the integer label ``1``; the reported length is one million.
    ``local_rank`` is accepted for signature parity with ``MXFaceDataset`` and
    is not used.
    """

    def __init__(self, local_rank):
        super().__init__()
        # HWC integer noise; randint's upper bound is exclusive, so 0..254.
        pixels = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
        chw = torch.from_numpy(np.transpose(pixels, (2, 0, 1))).float()
        # Scale to [0, 1) and then shift/scale with mean=std=0.5.
        self.img = (chw / 255 - 0.5) / 0.5
        self.label = 1

    def __getitem__(self, index):
        return self.img, self.label

    def __len__(self):
        return 1_000_000
|
|