| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| |
|
| | from torch.utils.data import BatchSampler, DataLoader, IterableDataset |
| |
|
| | |
| | _PYTORCH_DATALOADER_KWARGS = { |
| | "batch_size": 1, |
| | "shuffle": False, |
| | "sampler": None, |
| | "batch_sampler": None, |
| | "num_workers": 0, |
| | "collate_fn": None, |
| | "pin_memory": False, |
| | "drop_last": False, |
| | "timeout": 0, |
| | "worker_init_fn": None, |
| | "multiprocessing_context": None, |
| | "generator": None, |
| | "prefetch_factor": 2, |
| | "persistent_workers": False, |
| | } |
| |
|
| |
|
class SkipBatchSampler(BatchSampler):
    """
    A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.

    Args:
        batch_sampler (`torch.utils.data.BatchSampler`):
            The underlying batch sampler whose first batches should be skipped.
        skip_batches (`int`, *optional*, defaults to 0):
            The number of batches to skip at the beginning of iteration.
    """

    def __init__(self, batch_sampler, skip_batches=0):
        # Intentionally does not call `super().__init__()`: iteration and length are
        # fully delegated to the wrapped sampler, so the parent's state is unused.
        self.batch_sampler = batch_sampler
        self.skip_batches = skip_batches

    def __iter__(self):
        for index, samples in enumerate(self.batch_sampler):
            # Consume (and discard) the first `skip_batches` batches so the wrapped
            # sampler's internal order stays identical to an unwrapped run.
            if index >= self.skip_batches:
                yield samples

    @property
    def total_length(self):
        # Length of the underlying sampler, ignoring the skipped batches.
        return len(self.batch_sampler)

    def __len__(self):
        # Clamp at 0: a negative value would make `len()` raise `ValueError` when
        # `skip_batches` exceeds the number of available batches.
        return max(0, len(self.batch_sampler) - self.skip_batches)
| |
|
| |
|
class SkipDataLoader(DataLoader):
    """
    Subclass of a PyTorch `DataLoader` that will skip the first batches.

    Args:
        dataset (`torch.utils.data.dataset.Dataset`):
            The dataset to use to build this dataloader.
        skip_batches (`int`, *optional*, defaults to 0):
            The number of batches to skip at the beginning.
        kwargs:
            All other keyword arguments to pass to the regular `DataLoader` initialization.
    """

    def __init__(self, dataset, skip_batches=0, **kwargs):
        super().__init__(dataset, **kwargs)
        self.skip_batches = skip_batches

    def __iter__(self):
        # Iterate the underlying dataloader in full — so samplers/workers behave exactly
        # as in an unwrapped run — but only yield batches past the first `skip_batches`.
        for index, batch in enumerate(super().__iter__()):
            if index >= self.skip_batches:
                yield batch
| |
|
| |
|
def skip_first_batches(dataloader, num_batches=0):
    """
    Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.

    Args:
        dataloader (`torch.utils.data.DataLoader`):
            The dataloader whose first batches should be skipped.
        num_batches (`int`, *optional*, defaults to 0):
            The number of batches to skip at the beginning of iteration.

    Returns:
        `torch.utils.data.DataLoader`: A new dataloader yielding the same batches as
        `dataloader`, minus the first `num_batches`.
    """
    dataset = dataloader.dataset
    sampler_is_batch_sampler = False
    if isinstance(dataset, IterableDataset):
        # Iterable datasets have no sampler to wrap: batches must be skipped at
        # iteration time via `SkipDataLoader`.
        new_batch_sampler = None
    else:
        # A dataloader may use a `BatchSampler` either directly as its `sampler` or as
        # its `batch_sampler`; wrap whichever one actually drives the batching.
        sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
        batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
        new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches)

    # These would conflict with passing an explicit (batch_)sampler to `DataLoader`.
    ignore_kwargs = [
        "batch_size",
        "shuffle",
        "sampler",
        "batch_sampler",
        "drop_last",
    ]

    # Recover the remaining construction arguments from the live dataloader, falling
    # back to PyTorch's defaults for attributes the instance does not expose.
    kwargs = {
        k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
        for k in _PYTORCH_DATALOADER_KWARGS
        if k not in ignore_kwargs
    }

    if new_batch_sampler is None:
        # No batch sampler to carry the batching arguments, so forward them explicitly
        # and skip the batches manually inside the dataloader.
        kwargs["drop_last"] = dataloader.drop_last
        kwargs["batch_size"] = dataloader.batch_size
        dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
    elif sampler_is_batch_sampler:
        # The original dataloader received its `BatchSampler` through `sampler=`, so the
        # wrapped sampler must be passed back the same way (with the original batch
        # size) to preserve the batching semantics.
        dataloader = DataLoader(
            dataset, sampler=new_batch_sampler, batch_size=dataloader.batch_size, **kwargs
        )
    else:
        dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)

    return dataloader
| |
|