English
Shanci's picture
Upload folder using huggingface_hub
26225c5 verified
from torch.utils.data import DataLoader as TorchDataLoader
__all__ = ['DataLoader']
def __identity__(batch_list):
    """Return `batch_list` untouched (no-op collate function).

    This is deliberately a module-level function rather than a lambda:
    on Windows, a lambda can't be pickled, so a top-level function has
    to be used instead.

    See:
    https://discuss.pytorch.org/t/cant-pickle-local-object-dataloader-init-locals-lambda/31857/10?page=2
    https://docs.python.org/3/library/pickle.html#what-can-be-pickled-and-unpickled

    :param batch_list: the batch, as handed over by the caller
    :return: the very same `batch_list` object, unmodified
    """
    return batch_list
class DataLoader(TorchDataLoader):
    """Torch DataLoader whose default `collate_fn` is the identity.

    Behaves like the torch DataLoader, except that leaving
    `collate_fn=None` makes the loader return the plain list of
    elements of each batch instead of collating them.

    The purpose is to move the CPU-hungry NAG.from_nag_list (in
    particular, the level-0 Data.from_nag_list) to GPU. That step is
    instead taken care of in the 'DataModule.on_after_batch_transfer'
    hook, which calls the dataset 'on_device_transform'.

    Pass `collate_fn=NAG.from_data_list` if you want the CPU to do
    this operation (but beware of collisions with our
    'DataModule.on_after_batch_transfer' implementation).

    :param args: positional arguments forwarded to the torch
        DataLoader constructor
    :param collate_fn: keyword-only collate function. If None, the
        module-level `__identity__` function is used
    :param kwargs: keyword arguments forwarded to the torch
        DataLoader constructor
    """

    def __init__(self, *args, collate_fn=None, **kwargs):
        # `collate_fn` is keyword-only here, so it can never already be
        # present in `kwargs`; storing it there is equivalent to passing
        # it explicitly to the parent constructor.
        # The fallback is a named top-level function so that it remains
        # picklable.
        kwargs['collate_fn'] = (
            __identity__ if collate_fn is None else collate_fn)
        super().__init__(*args, **kwargs)