Spaces:
Runtime error
Runtime error
| from torch.utils.data.sampler import RandomSampler, Sampler | |
| import numpy as np | |
class FixedLenRandomSampler(RandomSampler):
    """Batch sampler that yields a fixed number of batches per epoch.

    Each iteration produces ``epoch_size`` batches of ``bs`` indices drawn
    without replacement from the indices that have not been handed out yet.
    When fewer than ``epoch_size * bs`` indices remain, the leftovers are
    emitted first, the bookkeeping mask is reset, and the rest of the epoch
    is drawn from the full index range.

    Because ``__iter__`` yields lists of indices (whole batches), instances
    are meant to be used where a batch sampler is expected.

    Code from mnpinto - Miguel
    https://forums.fast.ai/t/epochs-of-arbitrary-length/27777/10

    Args:
        data_source: Sized dataset; only ``len(data_source)`` is used here.
        bs: Number of indices per batch.
        epoch_size: Number of batches produced per iteration.
        *args, **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, data_source, bs, epoch_size, *args, **kwargs):
        super().__init__(data_source)
        self.epoch_size = epoch_size
        self.bs = bs
        # Boolean mask over dataset indices: True means "not handed out yet".
        self.not_sampled = np.ones(len(data_source), dtype=bool)
        self.size_to_sample = self.epoch_size * self.bs

    def _reset_state(self):
        """Mark every index as available again."""
        self.not_sampled[:] = True

    def __iter__(self):
        """Return an iterator over ``epoch_size`` lists of ``bs`` indices.

        Raises:
            ValueError: Propagated from ``np.random.choice`` when
                ``epoch_size * bs`` exceeds the number of indices available
                to draw from (e.g. it is larger than the dataset).
        """
        remaining = np.flatnonzero(self.not_sampled)
        ns = remaining.size
        idx_last = []
        if ns >= self.size_to_sample:
            idx = np.random.choice(
                remaining, size=self.size_to_sample, replace=False
            ).tolist()
            if ns == self.size_to_sample:
                # Pool is exactly exhausted by this epoch: start a new pass.
                # The previous code referenced the method without calling it,
                # so the mask never refilled and a later epoch crashed.
                self._reset_state()
        else:
            # Not enough unsampled indices left: emit the leftovers first,
            # then top up from a freshly reset pool.
            idx_last = remaining.tolist()
            self._reset_state()
            idx = np.random.choice(
                np.flatnonzero(self.not_sampled),
                size=self.size_to_sample - len(idx_last),
                replace=False,
            ).tolist()
        # Applied after any reset, so the indices drawn in this epoch are
        # excluded from the pool used by the following epoch.
        self.not_sampled[idx] = False
        idx = [*idx_last, *idx]

        bs = self.bs
        batches = [idx[i * bs:(i + 1) * bs] for i in range(self.epoch_size)]
        return iter(batches)

    def __len__(self):
        """Return the number of batches produced per iteration."""
        return self.epoch_size