| from __future__ import division |
| import torch |
| from torch.utils.data import DataLoader |
| from torch.utils.data.sampler import Sampler |
|
|
|
|
class RandomSampler(Sampler):
    """Sampler that yields dataset indices in random order and can resume mid-epoch.

    Attributes:
        data_source: The dataset (anything supporting ``len``) being sampled.
        dataset_perm: The full index permutation for the current epoch. This
            is the list a caller should store in a checkpoint.
        perm: The indices that will actually be yielded. It equals
            ``dataset_perm`` on a fresh start, or the not-yet-consumed tail of
            ``dataset_perm`` when resuming from a checkpoint.

    Args:
        data_source: Dataset to draw indices for.
        checkpoint: ``None`` for a fresh start, or a mapping with the keys
            ``'dataset_perm'``, ``'batch_size'`` and ``'batch_idx'``. If
            ``checkpoint['dataset_perm']`` is ``None`` a fresh permutation is
            drawn as well.
    """

    def __init__(self, data_source, checkpoint):
        self.data_source = data_source
        if checkpoint is not None and checkpoint['dataset_perm'] is not None:
            self.dataset_perm = checkpoint['dataset_perm']
            # Skip the samples of the batches that were already processed.
            consumed = checkpoint['batch_size'] * checkpoint['batch_idx']
            self.perm = self.dataset_perm[consumed:]
        else:
            self.dataset_perm = torch.randperm(len(self.data_source)).tolist()
            # Iterate over the very permutation exposed as ``dataset_perm``.
            # Drawing a second, independent permutation here would make the
            # stored order disagree with the order actually used, so resuming
            # from a checkpoint would skip/repeat the wrong samples.
            self.perm = self.dataset_perm

    def __iter__(self):
        """Return an iterator over the remaining indices of this epoch."""
        return iter(self.perm)

    def __len__(self):
        """Return the number of indices still to be yielded."""
        return len(self.perm)
|
|
|
|
class SequentialSampler(Sampler):
    """Sampler that walks the dataset in index order and can resume mid-epoch.

    Attributes:
        data_source: The dataset (anything supporting ``len``) being sampled.
        dataset_perm: The full index order for the epoch, either taken from
            the checkpoint or ``[0, 1, ..., len(data_source) - 1]``.
        perm: The indices that will be yielded. On a fresh start this is
            ``dataset_perm`` itself; when resuming it is the tail of
            ``dataset_perm`` after ``batch_size * batch_idx`` entries.

    Args:
        data_source: Dataset to draw indices for.
        checkpoint: ``None`` for a fresh start, or a mapping with the keys
            ``'dataset_perm'``, ``'batch_size'`` and ``'batch_idx'``.
    """

    def __init__(self, data_source, checkpoint):
        self.data_source = data_source

        resuming = checkpoint is not None and checkpoint['dataset_perm'] is not None
        if not resuming:
            # Fresh start: plain ascending order, shared by both attributes.
            order = list(range(len(self.data_source)))
            self.dataset_perm = order
            self.perm = order
            return

        # Resume: drop the indices belonging to already processed batches.
        stored_order = checkpoint['dataset_perm']
        already_seen = checkpoint['batch_size'] * checkpoint['batch_idx']
        self.dataset_perm = stored_order
        self.perm = stored_order[already_seen:]

    def __iter__(self):
        """Return an iterator over the indices that remain to be visited."""
        remaining = self.perm
        return iter(remaining)

    def __len__(self):
        """Return how many indices remain to be visited."""
        remaining = self.perm
        return len(remaining)
|
|
|
|
class CheckpointDataLoader(DataLoader):
    """
    Extends torch.utils.data.DataLoader to handle resuming training from an arbitrary point within an epoch.

    The loader always installs its own sampler (``RandomSampler`` when
    ``shuffle`` is true, ``SequentialSampler`` otherwise), built from the
    given checkpoint so that iteration continues after the batches that were
    already processed.

    Attributes:
        checkpoint_batch_idx: ``checkpoint['batch_idx']`` when a checkpoint is
            given, otherwise ``0``.

    Args:
        dataset: Dataset to load from.
        checkpoint: ``None`` for a fresh start, or a mapping providing
            ``'batch_idx'`` plus the keys read by the samplers
            (``'dataset_perm'`` and ``'batch_size'``).
        batch_size: Number of samples per batch.
        shuffle: Selects the random sampler instead of the sequential one.
        num_workers: Forwarded to ``DataLoader``.
        pin_memory: Forwarded to ``DataLoader``.
        drop_last: Forwarded to ``DataLoader``; defaults to ``True`` here.
        timeout: Forwarded to ``DataLoader``.
        worker_init_fn: Forwarded to ``DataLoader``.
    """

    def __init__(
        self,
        dataset,
        checkpoint=None,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=False,
        drop_last=True,
        timeout=0,
        worker_init_fn=None
    ):
        if shuffle:
            sampler = RandomSampler(dataset, checkpoint)
        else:
            sampler = SequentialSampler(dataset, checkpoint)

        if checkpoint is not None:
            self.checkpoint_batch_idx = checkpoint['batch_idx']
        else:
            self.checkpoint_batch_idx = 0

        super(CheckpointDataLoader, self).__init__(
            dataset,
            sampler=sampler,
            # Ordering is fully decided by the custom sampler above, so the
            # base class must not shuffle on its own.
            shuffle=False,
            batch_size=batch_size,
            num_workers=num_workers,
            drop_last=drop_last,
            pin_memory=pin_memory,
            timeout=timeout,
            # Forward the caller's hook; hard-coding None here silently
            # discarded any worker_init_fn that was passed in.
            worker_init_fn=worker_init_fn
        )
|
|