Spaces:
Runtime error
Runtime error
| import logging | |
| import random | |
| from torch.utils.data import DataLoader | |
| from ..hparams import HParams | |
| from .dataset import Dataset | |
| from .utils import mix_fg_bg, rglob_audio_files | |
| logger = logging.getLogger(__name__) | |
def _create_datasets(hp: HParams, mode, val_size=10, seed=123):
    """Discover foreground audio files and split them into train/val datasets.

    Args:
        hp: Hyper-parameters. ``hp.fg_dir`` is searched with
            ``rglob_audio_files``, and ``hp`` itself is forwarded to ``Dataset``.
        mode: Forwarded unchanged to both ``Dataset`` instances.
        val_size: Number of files held out for validation. They are taken from
            the end of the shuffled list. ``0`` yields an empty validation set.
        seed: Seed for a private RNG that shuffles the file list. The split is
            therefore repeatable for the same input order.

    Returns:
        ``(train_ds, val_ds)``. ``train_ds`` is ``Dataset(..., training=True)``
        built from the training paths. ``val_ds`` is
        ``Dataset(..., training=False)`` built from the held-out paths.

    Raises:
        ValueError: If ``val_size`` is negative, or if there are not more audio
            files than ``val_size`` (the training split would be empty).
    """
    if val_size < 0:
        raise ValueError(f"val_size must be non-negative, got {val_size}")

    # Materialise as a list because len() and the in-place shuffle below need a
    # mutable sequence.
    # NOTE(review): if rglob_audio_files does not sort its output, the split
    # also depends on filesystem listing order; consider sorting before the
    # shuffle for cross-machine reproducibility.
    paths = list(rglob_audio_files(hp.fg_dir))
    logger.info("Found %d audio files in %s", len(paths), hp.fg_dir)

    if len(paths) <= val_size:
        raise ValueError(
            f"Need more than {val_size} audio files in {hp.fg_dir} to form a "
            f"non-empty training split; found {len(paths)}"
        )

    # Use a private Random instance so the split does not depend on, or
    # disturb, the global RNG state.
    random.Random(seed).shuffle(paths)

    # Split at an explicit index. `paths[:-val_size]` would return an empty
    # training list when val_size == 0, because -0 == 0.
    split = len(paths) - val_size
    train_paths = paths[:split]
    val_paths = paths[split:]

    train_ds = Dataset(train_paths, hp, training=True, mode=mode)
    val_ds = Dataset(val_paths, hp, training=False, mode=mode)
    logger.info(
        "Train set: %d samples - Val set: %d samples", len(train_ds), len(val_ds)
    )
    return train_ds, val_ds
def create_dataloaders(hp: HParams, mode):
    """Build the training and validation DataLoaders.

    Both loaders wrap the datasets produced by ``_create_datasets`` for the given
    ``hp`` and ``mode``. Each uses its own dataset's ``collate_fn`` and
    ``hp.nj`` worker processes.

    Training loader: batches of ``hp.batch_size_per_gpu``, reshuffled each
    epoch, with the trailing partial batch dropped.

    Validation loader: one sample per batch, in fixed order, with nothing
    dropped.

    Returns:
        ``(train_dl, val_dl)``.
    """
    training_set, validation_set = _create_datasets(hp=hp, mode=mode)

    # Options shared by both loaders.
    shared = dict(num_workers=hp.nj)

    train_loader = DataLoader(
        training_set,
        collate_fn=training_set.collate_fn,
        batch_size=hp.batch_size_per_gpu,
        drop_last=True,
        shuffle=True,
        **shared,
    )
    validation_loader = DataLoader(
        validation_set,
        collate_fn=validation_set.collate_fn,
        batch_size=1,
        drop_last=False,
        shuffle=False,
        **shared,
    )
    return train_loader, validation_loader