BrainFM / Generator /__init__.py
peirong26's picture
Upload 187 files
2571f24 verified
"""
Datasets interface.
"""
from .constants import dataset_setups
from .datasets import BaseGen, BrainIDGen
# Registry of generator classes, looked up by the `dataset_option` string
# that `build_datasets` reads from its `gen_args` argument.
dataset_options = {
    'default': BaseGen,
    'brain_id': BrainIDGen,
}
def build_datasets(gen_args, device):
    """Build the generator dataset selected by ``gen_args.dataset_option``.

    Args:
        gen_args: Configuration object. Its ``dataset_option`` attribute must
            be a key of ``dataset_options``; the object itself is passed
            unchanged as the first argument of the selected generator class.
        device: Passed unchanged as the second argument of the generator class.

    Returns:
        A dict with a single entry, ``'all'``, holding the constructed
        generator. No separate per-split ('train'/'test') datasets are built
        here.

    Raises:
        KeyError: If ``gen_args.dataset_option`` is not a registered option.
    """
    option = gen_args.dataset_option
    try:
        generator_cls = dataset_options[option]
    except KeyError:
        # Keep raising KeyError (what a plain dict lookup would raise), but
        # tell the caller which value was rejected and what is accepted.
        raise KeyError(
            f"Unknown dataset_option {option!r}; "
            f"expected one of {sorted(dataset_options)}"
        ) from None
    return {'all': generator_cls(gen_args, device)}