import os

import torch

from .data_utils import trivial_batch_collator, worker_init_reset_seed

# Registry mapping a dataset name to the class registered under that name.
datasets = {}


def register_dataset(name):
    """Return a class decorator that registers the class under ``name``.

    The decorated class is stored in the module-level ``datasets`` dict and
    returned unchanged. Registering a second class under an existing name
    replaces the earlier entry.
    """
    def decorator(cls):
        datasets[name] = cls
        return cls
    return decorator


def make_dataset(name, is_training, split, **kwargs):
    """Build the dataset registered under ``name``.

    Args:
        name: Key previously passed to ``register_dataset``.
        is_training: Forwarded as the first positional argument of the
            dataset constructor.
        split: Forwarded as the second positional argument of the dataset
            constructor.
        **kwargs: Extra keyword arguments forwarded to the constructor.

    Returns:
        The constructed dataset instance.

    Raises:
        KeyError: If no dataset has been registered under ``name``.
    """
    try:
        dataset_cls = datasets[name]
    except KeyError:
        # Same exception type as a plain dict lookup, but with a message
        # that tells the caller which names are actually available.
        raise KeyError(
            f"Unknown dataset {name!r}; registered datasets: {sorted(datasets)}"
        ) from None
    return dataset_cls(is_training, split, **kwargs)


def make_data_loader(dataset, is_training, generator, batch_size, num_workers):
    """Build a ``torch.utils.data.DataLoader`` for ``dataset``.

    Args:
        dataset: Dataset to load from.
        is_training: When true, the loader shuffles, drops the last
            incomplete batch, and installs ``worker_init_reset_seed`` as the
            worker init function; otherwise none of these are enabled.
        generator: Random generator passed through to the ``DataLoader``.
        batch_size: Number of samples per batch.
        num_workers: Number of loader worker processes (0 loads in the main
            process).

    Returns:
        The configured ``DataLoader``; batches are assembled by
        ``trivial_batch_collator``.
    """
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=trivial_batch_collator,
        worker_init_fn=(worker_init_reset_seed if is_training else None),
        shuffle=is_training,
        drop_last=is_training,
        generator=generator,
        # DataLoader raises ValueError when persistent_workers=True and
        # num_workers == 0, so only keep workers alive when there are any.
        persistent_workers=num_workers > 0,
    )
    return loader