| """Base class for dataset. |
| |
| See mnist.py for an example of dataset. |
| """ |
|
|
| import collections |
| import logging |
| from abc import ABCMeta, abstractmethod |
|
|
| import omegaconf |
| import torch |
| from omegaconf import OmegaConf |
| from torch.utils.data import DataLoader, Sampler, get_worker_info |
| from torch.utils.data._utils.collate import default_collate_err_msg_format, np_str_obj_array_pattern |
|
|
| from siclib.utils.tensor import string_classes |
| from siclib.utils.tools import set_num_threads, set_seed |
|
|
# Module-level logger; BaseDataset.__init__ uses it to announce dataset creation.
logger = logging.getLogger(__name__)
|
|
| |
|
|
|
|
class LoopSampler(Sampler):
    """Sampler that cycles through the first ``loop_size`` indices.

    It yields ``0, 1, ..., loop_size - 1`` repeatedly until ``total_size``
    indices have been produced. ``total_size`` is rounded down to a multiple
    of ``loop_size`` so that every cycle is complete. The sampler is finite:
    its length is the rounded ``total_size``.
    """

    def __init__(self, loop_size: int, total_size: int = None):
        """Initialize the sampler.

        Args:
            loop_size (int): Number of distinct indices to cycle through.
                Must be positive.
            total_size (int, optional): Total number of indices to yield,
                rounded down to a multiple of ``loop_size``. Defaults to
                ``loop_size`` (a single pass over the loop).

        Raises:
            ValueError: If ``loop_size`` is not positive.
        """
        if loop_size <= 0:
            raise ValueError(f"loop_size must be positive, got {loop_size}.")
        if total_size is None:
            # No explicit size: iterate over the loop exactly once.
            total_size = loop_size
        self.loop_size = loop_size
        # Drop the trailing partial cycle so every pass covers all indices.
        self.total_size = total_size - (total_size % loop_size)

    def __iter__(self):
        """Return an iterator over the sampled indices."""
        return (i % self.loop_size for i in range(self.total_size))

    def __len__(self):
        """Return the number of indices yielded per iteration."""
        return self.total_size
|
|
|
|
def worker_init_fn(i):
    """Configure a freshly started DataLoader worker process.

    If the worker's dataset exposes a ``conf`` attribute, the worker is seeded
    with ``worker id + conf.seed`` (giving each worker a different seed) and
    its thread count is set to ``conf.num_threads``. Otherwise the worker is
    restricted to a single thread.

    Args:
        i: Worker index supplied by the DataLoader. The id actually used is
            read from ``get_worker_info()``.
    """
    worker = get_worker_info()
    dataset = worker.dataset
    if not hasattr(dataset, "conf"):
        set_num_threads(1)
        return
    cfg = dataset.conf
    set_seed(worker.id + cfg.seed)
    set_num_threads(cfg.num_threads)
|
|
|
|
def collate(batch):
    """Collate a list of samples into a batch.

    Difference with PyTorch's ``default_collate``: any object of an otherwise
    unhandled type is batched with ``torch.stack``, so classes that support
    ``torch.stack`` can be collated too.

    Args:
        batch: List of samples. Anything that is not a ``list`` is returned
            unchanged (no batching).

    Returns:
        The batched structure:
            - tensors, numpy arrays and Python/numpy numbers become stacked
              tensors (Python floats as float64);
            - strings are kept as a list;
            - mappings, named tuples and sequences are collated recursively;
            - ``None`` stays ``None``.

    Raises:
        TypeError: For numpy arrays whose dtype holds strings or objects.
        RuntimeError: If sequence samples have different lengths.
    """
    if not isinstance(batch, list):
        # No batching (e.g. auto-batching disabled).
        return batch
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        # torch.stack allocates the output itself. When this runs in a worker,
        # torch.multiprocessing moves the result to shared memory as it is
        # sent back to the main process. A pre-allocated shared buffer that is
        # not passed as ``out=`` would only waste memory, so none is created.
        return torch.stack(batch, dim=0)
    if (
        elem_type.__module__ == "numpy"
        and elem_type.__name__ != "str_"
        and elem_type.__name__ != "string_"
    ):
        if elem_type.__name__ in ["ndarray", "memmap"]:
            # Arrays of strings or Python objects cannot be turned into tensors.
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return collate([torch.as_tensor(b) for b in batch])
        if elem.shape == ():
            # numpy scalars
            return torch.as_tensor(batch)
        # NOTE(review): other numpy objects (non-ndarray, non-scalar) are not
        # batched and yield None — confirm this is intended.
        return None
    if isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    if isinstance(elem, int):
        return torch.tensor(batch)
    if isinstance(elem, string_classes):
        return batch
    if isinstance(elem, collections.abc.Mapping):
        return {key: collate([d[key] for d in batch]) for key in elem}
    # NOTE(review): zip() yields tuples, which the list guard above returns
    # untouched, so fields of named tuples / items of sequences are transposed
    # but not collated further — confirm this is intended.
    if isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(*(collate(samples) for samples in zip(*batch)))
    if isinstance(elem, collections.abc.Sequence):
        # Samples must have equal length to be transposed.
        elem_size = len(elem)
        if any(len(sample) != elem_size for sample in batch[1:]):
            raise RuntimeError("each element in list of batch should be of equal size")
        return [collate(samples) for samples in zip(*batch)]
    if elem is None:
        return elem
    # Fallback: let torch.stack handle custom types that support it.
    return torch.stack(batch, 0)
|
|
|
|
class BaseDataset(metaclass=ABCMeta):
    """Base class for datasets.

    What a dataset subclass is expected to declare:
        default_conf: dictionary with the default configuration of the
            dataset. It overrides ``base_default_conf`` of BaseDataset and is
            itself overridden by the user-provided configuration passed to
            __init__. Configurations can be nested.

        _init(self, conf): initialization method, where conf is the final
            configuration object (also accessible as ``self.conf``). Accessing
            unknown configuration entries will raise an error.

        get_dataset(self, split): method that returns an instance of
            torch.utils.data.Dataset corresponding to the requested split
            string, which can be ``'train'``, ``'val'``, or ``'test'``.
    """

    # "???" marks mandatory entries (OmegaConf missing values).
    base_default_conf = {
        "name": "???",
        "num_workers": "???",
        "train_batch_size": "???",
        "val_batch_size": "???",
        "test_batch_size": "???",
        "shuffle_training": True,
        "batch_size": 1,
        "num_threads": 1,
        "seed": 0,
        "prefetch_factor": 2,
    }
    default_conf = {}

    def __init__(self, conf):
        """Build the final configuration and call the child's ``_init``.

        The configuration is ``base_default_conf``, then ``default_conf``, then
        ``conf`` merged on top. The merged defaults are in struct mode, so
        keys in ``conf`` that are not declared in the defaults are rejected.
        The result is stored read-only in ``self.conf``.

        Args:
            conf: User configuration, as a dict or an OmegaConf config.
        """
        default_conf = OmegaConf.merge(
            OmegaConf.create(self.base_default_conf),
            OmegaConf.create(self.default_conf),
        )
        OmegaConf.set_struct(default_conf, True)
        if isinstance(conf, dict):
            conf = OmegaConf.create(conf)
        self.conf = OmegaConf.merge(default_conf, conf)
        OmegaConf.set_readonly(self.conf, True)
        logger.info(f"Creating dataset {self.__class__.__name__}")
        self._init(self.conf)

    @abstractmethod
    def _init(self, conf):
        """To be implemented by the child class."""
        raise NotImplementedError

    @abstractmethod
    def get_dataset(self, split):
        """To be implemented by the child class."""
        raise NotImplementedError

    def get_data_loader(self, split, shuffle=None, pinned=False, distributed=False):
        """Return a data loader for a given split.

        Args:
            split: One of ``'train'``, ``'val'`` or ``'test'``.
            shuffle: Whether to shuffle the samples. If None, only the
                training split is shuffled, and only when
                ``conf.shuffle_training`` is set.
            pinned: Value forwarded to the DataLoader's ``pin_memory``.
            distributed: If True, sample through a ``DistributedSampler``,
                which then performs the shuffling requested by ``shuffle``.

        Returns:
            A ``torch.utils.data.DataLoader`` using this module's ``collate``
            and ``worker_init_fn``.
        """
        assert split in ["train", "val", "test"]
        dataset = self.get_dataset(split)
        try:
            batch_size = self.conf[f"{split}_batch_size"]
        except omegaconf.MissingMandatoryValue:
            # Split-specific batch size left as "???": use the generic one.
            batch_size = self.conf.batch_size
        # Falls back to the batch size when num_workers is unavailable in conf.
        num_workers = self.conf.get("num_workers", batch_size)
        if shuffle is None:
            shuffle = split == "train" and self.conf.shuffle_training
        if distributed:
            # DataLoader does not accept shuffle=True together with a sampler,
            # so the shuffling decision is delegated to the DistributedSampler.
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset, shuffle=bool(shuffle)
            )
            shuffle = False
        else:
            sampler = None
        # Recent PyTorch versions raise a ValueError if prefetch_factor is given
        # while num_workers == 0, so only pass it to multi-worker loaders.
        extra_kwargs = {}
        if num_workers > 0:
            extra_kwargs["prefetch_factor"] = self.conf.prefetch_factor
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            pin_memory=pinned,
            collate_fn=collate,
            num_workers=num_workers,
            worker_init_fn=worker_init_fn,
            **extra_kwargs,
        )

    def get_overfit_loader(self, split: str):
        """Return an overfit data loader.

        The training set is composed of a single duplicated batch, while
        the validation and test sets contain a single copy of this same batch.
        This is useful to debug a model and make sure that losses and metrics
        correlate well.

        Args:
            split: One of ``'train'``, ``'val'`` or ``'test'``. All splits draw
                from the training dataset.

        Returns:
            A ``torch.utils.data.DataLoader`` that repeatedly yields the first
            ``conf.batch_size`` training samples.
        """
        assert split in {"train", "val", "test"}
        dataset = self.get_dataset("train")
        # Loop over the first batch_size indices: for "train" as many times as
        # fit in the dataset, otherwise exactly once.
        sampler = LoopSampler(
            self.conf.batch_size,
            len(dataset) if split == "train" else self.conf.batch_size,
        )
        num_workers = self.conf.get("num_workers", self.conf.batch_size)
        return DataLoader(
            dataset,
            batch_size=self.conf.batch_size,
            pin_memory=True,
            num_workers=num_workers,
            sampler=sampler,
            worker_init_fn=worker_init_fn,
            collate_fn=collate,
        )
|
|