from torch.utils.data import DataLoader as TorchDataLoader


__all__ = ['DataLoader']


def __identity__(batch_list):
    """Return ``batch_list`` unchanged.

    This is the default ``collate_fn`` of :class:`DataLoader`. It is a
    module-level function rather than a lambda on purpose: lambdas cannot
    be pickled, which breaks DataLoader worker processes on Windows.

    See:
    https://discuss.pytorch.org/t/cant-pickle-local-object-dataloader-init-locals-lambda/31857/10?page=2
    https://docs.python.org/3/library/pickle.html#what-can-be-pickled-and-unpickled
    """
    return batch_list


class DataLoader(TorchDataLoader):
    """A torch ``DataLoader`` whose default collation is the identity.

    When ``collate_fn`` is left as ``None``, each batch is yielded as a
    plain list of dataset elements instead of being collated. The goal is
    to keep the CPU-hungry batching of NAG objects (in particular of the
    level-0 ``Data``) off the CPU. That step is instead performed after
    the batch has been moved to the device, in the
    'DataModule.on_after_batch_transfer' hook, which calls the dataset's
    'on_device_transform'.

    To have the CPU perform this batching instead, pass
    ``collate_fn=NAG.from_data_list``. Beware that this can collide with
    our 'DataModule.on_after_batch_transfer' implementation.
    NOTE(review): earlier wording of this docstring also called the method
    ``NAG.from_nag_list``; confirm the actual NAG method name.

    ``collate_fn`` must be given as a keyword argument. A value passed
    positionally would land in ``*args`` and clash with the keyword
    forwarded to the parent constructor.
    """

    def __init__(self, *args, collate_fn=None, **kwargs):
        # Any explicit ``None`` (including the default) means identity collation.
        super().__init__(
            *args,
            collate_fn=__identity__ if collate_fn is None else collate_fn,
            **kwargs,
        )