| import threading | |
| import random | |
| import torch | |
| import torch.multiprocessing as multiprocessing | |
| from torch.utils.data import DataLoader | |
| from torch.utils.data import SequentialSampler | |
| from torch.utils.data import RandomSampler | |
| from torch.utils.data import BatchSampler | |
| from torch.utils.data import _utils | |
| from torch.utils.data.dataloader import _DataLoaderIter | |
| from torch.utils.data._utils import collate | |
| from torch.utils.data._utils import signal_handling | |
| from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL | |
| from torch.utils.data._utils import ExceptionWrapper | |
| from torch.utils.data._utils import IS_WINDOWS | |
| from torch.utils.data._utils.worker import ManagerWatchdog | |
| from torch._six import queue | |
def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
    """Worker loop that builds batches and tags each one with a scale index.

    Requests of the form ``(idx, batch_indices)`` are read from
    ``index_queue``.  For every request a scale index is chosen, the
    requested samples are collated, the scale index is appended to the
    collated batch, and ``(idx, samples)`` is put on ``data_queue``.  If
    building the batch raises, ``(idx, ExceptionWrapper)`` is put instead so
    the failure can be reported by the consumer of ``data_queue``.

    Args:
        dataset: Indexable dataset.  Must expose a ``train`` attribute and a
            ``set_scale(idx_scale)`` method.
        index_queue: Queue of work requests; a ``None`` item means "stop".
        data_queue: Queue receiving the results (must provide
            ``cancel_join_thread`` and ``put``).
        done_event: Event-like object; once it is set, requests that are
            still queued are dropped without being processed.
        collate_fn: Callable merging a list of samples into a batch.  Its
            result must support ``append``.
        scale: Sequence of scales; only its length is used here.
        seed: Seed applied to both ``random`` and ``torch`` in this worker.
        init_fn: Optional callable invoked once with ``worker_id``.
        worker_id: Identifier passed to ``init_fn``.
    """
    # Local import: the failure path needs ``sys.exc_info()`` and ``sys`` is
    # not imported at module level, which previously turned every dataset
    # error into a NameError inside the worker.
    import sys

    try:
        collate._use_shared_memory = True
        signal_handling._set_worker_signal_handlers()

        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        data_queue.cancel_join_thread()

        if init_fn is not None:
            init_fn(worker_id)

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            # Poll with a timeout so the watchdog condition is re-checked
            # even when no work arrives.
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue

            if r is None:
                # Shutdown sentinel; it is only expected after done_event
                # has been set.
                assert done_event.is_set()
                return
            elif done_event.is_set():
                # Shutting down: discard pending work until the sentinel.
                continue

            idx, batch_indices = r
            try:
                # A random scale is drawn only when several scales exist and
                # the dataset is in training mode; otherwise index 0 is used.
                idx_scale = 0
                if len(scale) > 1 and dataset.train:
                    idx_scale = random.randrange(0, len(scale))
                    dataset.set_scale(idx_scale)

                samples = collate_fn([dataset[i] for i in batch_indices])
                # The scale index travels with the batch as its last element.
                samples.append(idx_scale)
            except Exception:
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
            else:
                data_queue.put((idx, samples))
                del samples
    except KeyboardInterrupt:
        # Ctrl-C in a worker is deliberately ignored.
        pass
class _MSDataLoaderIter(_DataLoaderIter):
    """Iterator for a multi-scale loader whose workers run ``_ms_loop``.

    Only ``__init__`` is defined here; everything else (including the
    ``_put_indices`` method called at the end of ``__init__``) is expected to
    come from ``_DataLoaderIter``.
    """

    def __init__(self, loader):
        """Copy the loader's settings and, if requested, start the workers.

        The base-class ``__init__`` is not called; this method sets up the
        iterator state itself.
        NOTE(review): presumably this mirrors the base initialiser of the
        PyTorch version the file targets -- confirm when upgrading PyTorch.

        Args:
            loader: The owning loader.  Attributes read here: ``dataset``,
                ``scale``, ``collate_fn``, ``batch_sampler``,
                ``num_workers``, ``pin_memory``, ``timeout`` and, when
                workers are used, ``worker_init_fn``.
        """
        self.dataset = loader.dataset
        self.scale = loader.scale
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        # Pinning is only enabled when CUDA is actually available.
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout

        self.sample_iter = iter(self.batch_sampler)

        # Drawn exactly once, as a plain Python int, so that ``base_seed + i``
        # reaches ``random.seed`` / ``torch.manual_seed`` in the worker as an
        # int.  A second draw that replaced this value with a 0-dim tensor
        # has been removed.
        base_seed = torch.LongTensor(1).random_().item()

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.Queue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}
            self.done_event = multiprocessing.Event()

            self.index_queues = []
            self.workers = []
            for i in range(self.num_workers):
                # One private request queue per worker; all workers share
                # the single result queue.
                index_queue = multiprocessing.Queue()
                index_queue.cancel_join_thread()
                w = multiprocessing.Process(
                    target=_ms_loop,
                    args=(
                        self.dataset,
                        index_queue,
                        self.worker_result_queue,
                        self.done_event,
                        self.collate_fn,
                        self.scale,
                        base_seed + i,  # distinct seed per worker
                        self.worker_init_fn,
                        i,
                    ),
                )
                w.daemon = True
                w.start()
                self.index_queues.append(index_queue)
                self.workers.append(w)

            if self.pin_memory:
                # A daemon thread moves batches from the worker result queue
                # into a thread queue via ``_pin_memory_loop``.
                self.data_queue = queue.Queue()
                pin_memory_thread = threading.Thread(
                    target=_utils.pin_memory._pin_memory_loop,
                    args=(
                        self.worker_result_queue,
                        self.data_queue,
                        torch.cuda.current_device(),
                        self.done_event,
                    ),
                )
                pin_memory_thread.daemon = True
                pin_memory_thread.start()
                self.pin_memory_thread = pin_memory_thread
            else:
                self.data_queue = self.worker_result_queue

            # Register the worker PIDs and install the SIGCHLD handler.
            _utils.signal_handling._set_worker_pids(
                id(self), tuple(w.pid for w in self.workers)
            )
            _utils.signal_handling._set_SIGCHLD_handler()
            self.worker_pids_set = True

            # Prime the pipeline with two outstanding batches per worker.
            for _ in range(2 * self.num_workers):
                self._put_indices()
class MSDataLoader(DataLoader):
    """``DataLoader`` variant that carries a list of scales.

    The number of worker processes is taken from ``cfg.n_threads`` and the
    scales from ``cfg.scale``; iteration is delegated to
    ``_MSDataLoaderIter``.
    """

    def __init__(self, cfg, *args, **kwargs):
        """Build the loader.

        Args:
            cfg: Configuration object providing ``n_threads`` (forwarded as
                ``num_workers``) and ``scale`` (stored on the loader).
            *args: Positional arguments forwarded to ``DataLoader``.
            **kwargs: Keyword arguments forwarded to ``DataLoader``.  Passing
                ``num_workers`` here raises ``TypeError`` because it is
                already supplied from ``cfg``.
        """
        worker_count = cfg.n_threads
        super().__init__(*args, num_workers=worker_count, **kwargs)
        self.scale = cfg.scale

    def __iter__(self):
        """Return a fresh multi-scale iterator bound to this loader."""
        iterator = _MSDataLoaderIter(self)
        return iterator