import torch
import torch.multiprocessing as multiprocessing
from torch._C import _set_worker_signal_handlers, \
    _remove_worker_pids, _error_if_any_worker_fails

from packaging import version

# PyTorch 1.0 renamed `_update_worker_pids` to `_set_worker_pids`.
if version.Version(torch.__version__) >= version.Version('1.0.0'):
    from torch._C import _set_worker_pids
else:
    from torch._C import _update_worker_pids as _set_worker_pids

from torch.utils.data import SequentialSampler, RandomSampler, BatchSampler, Sampler
import signal
import collections.abc
import re
import sys
import threading
import traceback
import os
import queue

# On Python 3, the old `torch._six.string_classes` is simply `str`.
string_classes = str

IS_WINDOWS = sys.platform == "win32"
if IS_WINDOWS:
    import ctypes
    from ctypes.wintypes import DWORD, BOOL, HANDLE

__all__ = ['SequentialDataLoader']


class ExceptionWrapper(object):
    r"""Wraps an exception plus traceback to communicate across threads."""

    def __init__(self, exc_info):
        self.exc_type = exc_info[0]
        self.exc_msg = "".join(traceback.format_exception(*exc_info))


_use_shared_memory = False
r"""Whether to use shared memory in default_collate"""

MANAGER_STATUS_CHECK_INTERVAL = 5.0
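# Workers block on their index queues for at most this many seconds at a
# time; after each timeout they ask ManagerWatchdog (below) whether the
# manager process is still alive, and exit if it is not, so that workers
# never outlive a crashed parent.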


if IS_WINDOWS:
    class ManagerWatchdog(object):
        def __init__(self):
            self.manager_pid = os.getppid()

            self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)
            self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
            self.kernel32.OpenProcess.restype = HANDLE
            self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
            self.kernel32.WaitForSingleObject.restype = DWORD

            # SYNCHRONIZE access right, required to wait on a process handle.
            SYNCHRONIZE = 0x00100000
            self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid)

            if not self.manager_handle:
                raise ctypes.WinError(ctypes.get_last_error())

        def is_alive(self):
            # A zero-timeout wait returns WAIT_TIMEOUT (non-zero) while the
            # manager process is running, and WAIT_OBJECT_0 (0) once it exits.
            return self.kernel32.WaitForSingleObject(self.manager_handle, 0) != 0
else:
    class ManagerWatchdog(object):
        def __init__(self):
            self.manager_pid = os.getppid()

        def is_alive(self):
            # If the parent dies, this process is reparented and getppid() changes.
            return os.getppid() == self.manager_pid


def _worker_loop(dataset, index_queue, data_queue, collate_fn, init_fn, worker_id):
    global _use_shared_memory
    _use_shared_memory = True

    # Install C-side handlers for SIGBUS and SIGSEGV so a fatal signal in a
    # worker surfaces as an error instead of a silent hang.
    _set_worker_signal_handlers()

    torch.set_num_threads(1)

    # Note: unlike torch.utils.data, no RNG is seeded here; sequential
    # loading must not touch any random state.
    if init_fn is not None:
        init_fn(worker_id)

    watchdog = ManagerWatchdog()

    while True:
        try:
            r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            if watchdog.is_alive():
                continue
            else:
                break
        if r is None:
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))
            del samples


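# This loop runs in a background *thread* of the main process, not in a
# worker process: pinned (page-locked) buffers must be registered with CUDA
# in the process that launches the transfers, and CUDA cannot safely be
# (re-)initialized in forked worker processes.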
def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id):
    if pin_memory:
        torch.cuda.set_device(device_id)

    while True:
        try:
            r = in_queue.get()
        except Exception:
            if done_event.is_set():
                return
            raise
        if r is None:
            break
        if isinstance(r[1], ExceptionWrapper):
            out_queue.put(r)
            continue
        idx, batch = r
        try:
            if pin_memory:
                batch = pin_memory_batch(batch)
        except Exception:
            out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            out_queue.put((idx, batch))


numpy_type_map = {
    'float64': torch.DoubleTensor,
    'float32': torch.FloatTensor,
    'float16': torch.HalfTensor,
    'int64': torch.LongTensor,
    'int32': torch.IntTensor,
    'int16': torch.ShortTensor,
    'int8': torch.CharTensor,
    'uint8': torch.ByteTensor,
}


def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size."""

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if isinstance(batch[0], torch.Tensor):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy.
            numel = sum(x.numel() for x in batch)
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # Arrays of strings or objects cannot be converted to tensors.
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))

            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # numpy scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], int):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], collections.abc.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], collections.abc.Sequence):
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(error_msg.format(type(batch[0])))
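

# Illustrative only (not executed): collating two (tensor, label) samples
# stacks the tensors along a new batch dimension and tensorizes the labels:
#
#     >>> batch = [(torch.zeros(3), 0), (torch.ones(3), 1)]
#     >>> data, labels = default_collate(batch)
#     >>> data.shape, labels.tolist()
#     (torch.Size([2, 3]), [0, 1])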


def pin_memory_batch(batch):
    if isinstance(batch, torch.Tensor):
        return batch.pin_memory()
    elif isinstance(batch, string_classes):
        return batch
    elif isinstance(batch, collections.abc.Mapping):
        return {k: pin_memory_batch(sample) for k, sample in batch.items()}
    elif isinstance(batch, collections.abc.Sequence):
        return [pin_memory_batch(sample) for sample in batch]
    else:
        return batch


_SIGCHLD_handler_set = False
r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one
handler needs to be set for all DataLoaders in a process."""


def _set_SIGCHLD_handler():
    # Windows does not support SIGCHLD.
    if sys.platform == 'win32':
        return
    # Signal handlers can only be installed from the main thread.
    if not isinstance(threading.current_thread(), threading._MainThread):
        return
    global _SIGCHLD_handler_set
    if _SIGCHLD_handler_set:
        return
    previous_handler = signal.getsignal(signal.SIGCHLD)
    if not callable(previous_handler):
        previous_handler = None

    def handler(signum, frame):
        # Raise an informative error in the main process if any worker died
        # unexpectedly, then chain to any previously installed handler.
        _error_if_any_worker_fails()
        if previous_handler is not None:
            previous_handler(signum, frame)

    signal.signal(signal.SIGCHLD, handler)
    _SIGCHLD_handler_set = True


class _SequentialDataLoaderIter(object):
    r"""Iterates once over the DataLoader's dataset, as specified by the sampler."""

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queues[i],
                          self.worker_result_queue, self.collate_fn, self.worker_init_fn, i))
                for i in range(self.num_workers)]

            if self.pin_memory or self.timeout > 0:
                self.data_queue = queue.Queue()
                if self.pin_memory:
                    maybe_device_id = torch.cuda.current_device()
                else:
                    # A device id is only needed when pinning memory.
                    maybe_device_id = None
                self.worker_manager_thread = threading.Thread(
                    target=_worker_manager_loop,
                    args=(self.worker_result_queue, self.data_queue, self.done_event,
                          self.pin_memory, maybe_device_id))
                self.worker_manager_thread.daemon = True
                self.worker_manager_thread.start()
            else:
                self.data_queue = self.worker_result_queue

            for w in self.workers:
                w.daemon = True  # ensure that workers die with the main process
                w.start()

            _set_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _set_SIGCHLD_handler()
            self.worker_pids_set = True

            # Prime the prefetch loop: keep up to 2 batches in flight per worker.
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __len__(self):
        return len(self.batch_sampler)

    def _get_batch(self):
        if self.timeout > 0:
            try:
                return self.data_queue.get(timeout=self.timeout)
            except queue.Empty:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
        else:
            return self.data_queue.get()

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # Check if the next batch has already arrived out of order.
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # Store out-of-order batches until their turn comes.
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    next = __next__  # Python 2 compatibility

    def __iter__(self):
        return self

    def _put_indices(self):
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
        self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
        self.batches_outstanding += 1
        self.send_idx += 1

    def _process_next_batch(self, batch):
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, ExceptionWrapper):
            # Re-raise the worker's exception type with its formatted traceback.
            raise batch.exc_type(batch.exc_msg)
        return batch

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator across
        # multiple threads. Probably the best way to do this is by moving the
        # sample pushing to a separate thread and then just sharing the data
        # queue, but signalling the end is tricky without a non-blocking API.
        raise NotImplementedError("_SequentialDataLoaderIter cannot be pickled")

    def _shutdown_workers(self):
        try:
            if not self.shutdown:
                self.shutdown = True
                self.done_event.set()
                for q in self.index_queues:
                    q.put(None)
                # If workers are blocked putting results, drain the queue to
                # make room for them so they can exit.
                try:
                    while not self.worker_result_queue.empty():
                        self.worker_result_queue.get()
                except (FileNotFoundError, ImportError):
                    # Many odd errors can happen here when Python is shutting
                    # down, e.g., FileNotFoundError for a removed FIFO, or
                    # ImportError because sys.meta_path is already None.
                    pass
                # done_event should suffice to stop the worker manager thread,
                # but put an extra None to be safe.
                self.worker_result_queue.put(None)
        finally:
            # Remove the registered worker pids no matter what.
            if self.worker_pids_set:
                _remove_worker_pids(id(self))
                self.worker_pids_set = False

    def __del__(self):
        if self.num_workers > 0:
            self._shutdown_workers()


class SequentialDataLoader(object):
    r"""
    Sequential data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    This is modified from PyTorch's DataLoader: since sequential data loading
    should be deterministic, it never touches any random state (in particular,
    workers are not seeded).

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with batch_size, shuffle,
            sampler, and drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: 0)
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of the dataset is not divisible by the batch size, then the last
            batch will be smaller. (default: False)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: 0)
        worker_init_fn (callable, optional): If not None, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``)
            as input, before data loading. (default: None)

    .. note:: Unlike ``torch.utils.data.DataLoader``, this loader does not seed
              its workers, so no RNG state is read or modified anywhere. If your
              :attr:`worker_init_fn` needs randomness, it must seed whatever
              libraries it uses itself.

    .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
                 cannot be an unpicklable object, e.g., a lambda function.
    """

    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            self.batch_size = None
            self.drop_last = None

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers option cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.__initialized = True

    def __setattr__(self, attr, val):
        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))

        super(SequentialDataLoader, self).__setattr__(attr, val)

    def __iter__(self):
        return _SequentialDataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)
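

# Minimal usage sketch, assuming a simple map-style dataset; `_RangeDataset`
# below is hypothetical and only for illustration. With the default `fork`
# start method on Linux, num_workers > 0 works as shown; under `spawn`, use
# num_workers=0 or define the dataset in an importable module.
if __name__ == "__main__":
    class _RangeDataset(object):
        def __len__(self):
            return 10

        def __getitem__(self, i):
            # Each sample is a (feature tensor, integer label) pair.
            return torch.full((3,), float(i)), i

    loader = SequentialDataLoader(_RangeDataset(), batch_size=4, num_workers=2)
    for data, labels in loader:
        # default_collate yields [stacked float tensor, LongTensor of labels].
        print(data.shape, labels.tolist())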