Spaces:
Running
Running
| import pytest | |
| import threading | |
| import time | |
| import torch | |
| import torch.nn as nn | |
| from functools import partial | |
| from itertools import product | |
| from ding.utils import EasyTimer | |
| from ding.utils.data import AsyncDataLoader | |
# Parameter grids of (batch_size, num_workers, chunk_size) tuples, presumably
# consumed by ``pytest.mark.parametrize`` on the dataloader test methods.
batch_size_args = [3, 6]
num_workers_args = [0, 4]
chunk_size_args = [1, 3]
# Full Cartesian grid: 2 * 2 * 2 = 8 combinations, in itertools.product order.
args = list(product(batch_size_args, num_workers_args, chunk_size_args))
# A single reduced combination, [(3, 2, 1)], for a quicker run.
unittest_args = list(product([3], [2], [1]))
class Dataset:
    """Synthetic map-style dataset with an artificial per-item delay.

    Every index yields the same random 256x256 tensor paired with the index
    itself. Each fetch blocks for ``_LATENCY`` seconds to simulate an
    expensive load.
    """

    _LATENCY = 0.5  # seconds slept by every __getitem__ call

    def __init__(self):
        # A single random tensor shared by all items.
        self.data = torch.randn(256, 256)

    def __len__(self):
        # Fixed nominal length; independent of the shape of self.data.
        return 100

    def __getitem__(self, idx):
        time.sleep(self._LATENCY)
        sample = [self.data, idx]
        return sample
class TestAsyncDataLoader:
    """End-to-end ordering, latency and cleanup test for ``AsyncDataLoader``.

    A slow dataset (0.5 s per item, see ``Dataset.__getitem__``) is paired with
    a slow model (1 s per forward pass). If the loader prefetches in the
    background, part or all of the loading latency should be hidden behind the
    model step. ``entry`` asserts this via the measured data-fetch time.
    """

    def get_data_source(self):
        """Return a callable mapping ``batch_size`` to a list of zero-arg loaders.

        Each loader is ``dataset.__getitem__`` bound to one index in
        ``range(batch_size)``, so every batch requests indices ``0..batch_size-1``.
        """
        dataset = Dataset()

        def data_source_fn(batch_size):
            return [partial(dataset.__getitem__, idx=i) for i in range(batch_size)]

        return data_source_fn

    def get_model(self):
        """Build a 10-layer 256->256 linear stack whose forward sleeps 1 s.

        ``forward`` takes ``x`` where ``x[0]`` is the data and ``x[1]`` the
        indices. Presumably this is the batch as collated by AsyncDataLoader;
        TODO confirm. It returns ``[output, indices]`` so the caller can check
        which items arrived.
        """

        class Model(nn.Module):

            def __init__(self):
                super(Model, self).__init__()
                self.main = nn.Sequential(*[nn.Linear(256, 256) for _ in range(10)])

            def forward(self, x):
                idx = x[1]
                x = self.main(x[0])
                time.sleep(1)  # simulated compute time the loader should overlap with
                return [x, idx]

        return Model()

    # @pytest.mark.unittest
    # Left out of the project's ``unittest`` marker (as in the original): every
    # case sleeps at least 10 s, because Model.forward sleeps 1 s on each of 10 batches.
    @pytest.mark.parametrize('batch_size, num_workers, chunk_size', unittest_args)
    def test_cpu(self, batch_size, num_workers, chunk_size):
        self.entry(batch_size, num_workers, chunk_size, use_cuda=False)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA is not available')
    @pytest.mark.parametrize('batch_size, num_workers, chunk_size', args)
    def test_gpu(self, batch_size, num_workers, chunk_size):
        self.entry(batch_size, num_workers, chunk_size, use_cuda=True)
        torch.cuda.empty_cache()

    def entry(self, batch_size, num_workers, chunk_size, use_cuda):
        """Pull ten batches through the loader and check ordering, latency and cleanup.

        The first three batches are warm-up and are excluded from the latency
        budget. The remaining seven must be fetched quickly enough to show that
        loading overlaps with the 1 s model step. After teardown, no loader
        threads may remain.
        """
        num_iters, num_warmup = 10, 3
        num_measured = num_iters - num_warmup  # 7 batches count towards the budget
        model = self.get_model()
        if use_cuda:
            model.cuda()
        timer = EasyTimer()
        device = 'cuda' if use_cuda else 'cpu'
        dataloader = AsyncDataLoader(
            self.get_data_source(), batch_size, device, num_workers=num_workers, chunk_size=chunk_size
        )
        total_data_time = 0.
        try:
            for count in range(num_iters):
                with timer:
                    data = next(dataloader)
                data_time = timer.value
                if count >= num_warmup:
                    total_data_time += data_time
                with timer:
                    with torch.no_grad():
                        _, idx = model(data)
                    if use_cuda:
                        idx = idx.cpu()
                    # Every batch must contain exactly the indices 0..batch_size-1.
                    sorted_idx = torch.sort(idx)[0]
                    assert sorted_idx.eq(torch.arange(batch_size)).sum() == batch_size, idx
                model_time = timer.value
                print('count {}, data_time: {}, model_time: {}'.format(count, data_time, model_time))
            if num_workers < 1:
                # Items load one after another: batch_size * 0.5 s per batch. At most
                # 1 s of that overlaps with the model step, plus 0.01 s slack per batch.
                assert total_data_time <= (
                    num_measured * batch_size * 0.5 + num_measured * 0.01 - num_measured * 1
                )
            else:
                # With workers, a batch is expected to be ready almost immediately.
                assert total_data_time <= num_measured * 0.008
        finally:
            # Always tear the loader down, even if an assertion above failed. Leaked
            # threads would otherwise break the thread-count check of later cases.
            dataloader.__del__()
        time.sleep(0.5)  # give loader threads time to exit
        assert len(threading.enumerate()) <= 2, threading.enumerate()