| |
|
| |
|
| | from torch.utils.data.dataset import IterableDataset |
| | from torch.utils.data import DataLoader |
| | import numpy as np |
| |
|
| | from sudoku.loader import DataIterBuffer, train_dataset, test_dataset, data_loader, get_datasets |
| |
|
| |
|
| |
|
class CustomDataLoader(DataLoader):
    """Interleave batches from several datasets, tagging each with its source index.

    One ``DataLoader`` iterator (default settings, given ``batch_size``) is
    built per dataset in ``__init__``. On every step the source is chosen from
    the current ``len()`` of each dataset:

    * if at least one has ``len() >= batch_size``, the highest-index such
      dataset is used;
    * otherwise the dataset with the largest ``len()`` is used (lowest index
      on ties).

    Each yielded item is ``[source_index] + batch``, where ``source_index`` is
    a plain ``int``. Iteration stops as soon as the chosen source's iterator
    is exhausted.

    NOTE(review): ``DataLoader.__init__`` is intentionally not called (it would
    try to assign ``num_workers``, which is a read-only property here), so most
    inherited ``DataLoader`` attributes and methods are unavailable.
    Presumably the subclassing only exists so ``DataLoader``-typed callers
    accept this object. Confirm.

    Args:
        data_iters: Sequence of datasets. Each must be accepted by
            ``DataLoader`` and support ``len()``. NOTE(review): presumably
            growing buffers whose ``len()`` changes between steps. Verify.
        batch_size: Batch size used for every per-dataset ``DataLoader`` and
            as the "ready" threshold when choosing a source.
    """

    def __init__(self, data_iters, batch_size):
        self.data_iters = data_iters
        self.batch_size = batch_size
        # Iterators are created once here, so iterating this loader a second
        # time resumes the underlying streams instead of restarting them.
        self.data_loaders = [
            iter(DataLoader(data_iter, batch_size=batch_size))
            for data_iter in data_iters
        ]

    def __iter__(self):
        """Yield ``[source_index] + batch`` until a chosen source is exhausted."""
        if not self.data_loaders:
            # Nothing to interleave; np.argmax would fail on an empty array.
            return
        while True:
            sizes = np.array([len(buffer) for buffer in self.data_iters])
            ready = np.flatnonzero(sizes >= self.batch_size)
            # Prefer the highest-index source that can fill a whole batch;
            # otherwise fall back to the fullest source.
            idx = int(ready[-1]) if ready.size else int(np.argmax(sizes))
            try:
                batch = next(self.data_loaders[idx])
            except StopIteration:
                # PEP 479: a StopIteration escaping a generator is re-raised
                # as RuntimeError, so end the iteration explicitly instead.
                return
            yield [idx] + batch

    @property
    def num_workers(self):
        """Always 0: all batches are produced in the calling process."""
        return 0
| |
|