English
File size: 1,324 Bytes
26225c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from torch.utils.data import DataLoader as TorchDataLoader


__all__ = ['DataLoader']

def __identity__(batch_list):
    """Return `batch_list` untouched (no-op collate function).

    This is deliberately a named, module-level function rather than a
    lambda: lambdas can't be pickled, which is a problem on Windows,
    whereas a top-level function can be pickled.

    See:
    https://discuss.pytorch.org/t/cant-pickle-local-object-dataloader-init-locals-lambda/31857/10?page=2
    https://docs.python.org/3/library/pickle.html#what-can-be-pickled-and-unpickled

    :param batch_list: the batch to pass through
    :return: the very same `batch_list` object, unmodified
    """
    return batch_list

class DataLoader(TorchDataLoader):
    """Torch DataLoader whose fallback `collate_fn` is a plain identity.

    Everything behaves like the parent torch DataLoader, except when
    `collate_fn=None`: the batch is then passed through `__identity__`,
    so the DataLoader yields the list of elements as-is.

    Rationale given by the authors: the CPU-hungry NAG.from_nag_list
    (in particular, the level-0 Data.from_nag_list) is meant to run on
    GPU rather than in the loader. It is taken care of in the
    'DataModule.on_after_batch_transfer' hook, which calls the dataset
    'on_device_transform'.

    Use `collate_fn=NAG.from_data_list` if you want the CPU to do this
    operation (but beware of collisions with our
    'DataModule.on_after_batch_transfer' implementation).

    NOTE(review): the text above mentions both `from_nag_list` and
    `from_data_list`; confirm which name is the intended one.
    """

    def __init__(self, *args, collate_fn=None, **kwargs):
        # Only substitute the identity when the caller gave no collate
        # function; an explicit `collate_fn` is forwarded unchanged.
        kwargs['collate_fn'] = (
            __identity__ if collate_fn is None else collate_fn)
        super().__init__(*args, **kwargs)