Spaces:
Runtime error
Runtime error
| import random | |
| from pathlib import Path | |
| from typing import List | |
| import musdb | |
| import torch | |
| from torch.utils.data import Dataset | |
class MusdbDataset(Dataset):
    """Torch dataset serving random fixed-length excerpts of MUSDB tracks.

    Each item is a ``(mixture, stems)`` pair: the mixture waveform of a
    randomly positioned excerpt and a dict mapping every requested target
    name to the waveform of the same excerpt.

    Args:
        root: Directory holding the WAV version of the dataset; it must
            exist.
        is_train: Load the ``"train"`` subset when true, ``"test"`` otherwise.
        targets: Names of the stems to return. Falls back to
            ``["vocals", "accompaniment"]`` when omitted or empty.
        chunk_duration: Excerpt length in seconds (default ``5.0``).

    Raises:
        ValueError: If ``chunk_duration`` is not positive.
        AssertionError: If ``root`` does not exist.
    """

    def __init__(
        self,
        root="data/musdb18-wav",
        is_train: bool = True,
        targets: List[str] = None,
        chunk_duration: float = 5.0,
    ) -> None:
        super().__init__()
        if chunk_duration <= 0:
            raise ValueError(
                f"chunk_duration must be positive, got {chunk_duration}"
            )
        root = Path(root)
        assert root.exists(), f"Path does not exist: {root}"
        self.mus = musdb.DB(
            root=root,
            subsets=["train" if is_train else "test"],
            is_wav=True,
        )
        # Copy so that later mutation of the caller's list cannot affect us.
        self.targets = list(targets) if targets else ["vocals", "accompaniment"]
        self.chunk_duration = chunk_duration

    def __len__(self) -> int:
        """Return the number of tracks in the selected subset."""
        return len(self.mus)

    def __getitem__(self, index):
        """Return a random excerpt of track ``index`` and its target stems.

        Returns:
            A tuple ``(x_wav, y_target_wavs)`` where ``x_wav`` is a float32
            tensor built from the transposed mixture audio and
            ``y_target_wavs`` maps each name in ``self.targets`` to a float32
            tensor built the same way from that target's audio.
        """
        track = self.mus.tracks[index]
        track.chunk_duration = self.chunk_duration
        # Clamp at 0 so a track shorter than the excerpt length never gets a
        # negative start offset.
        # NOTE(review): such a track presumably yields a shorter excerpt,
        # which default batch collation would reject - confirm against musdb.
        latest_start = max(0.0, track.duration - self.chunk_duration)
        track.chunk_start = random.uniform(0.0, latest_start)

        # The chunk attributes are set before any `.audio` access; presumably
        # musdb applies them when reading - verify against the musdb docs.
        x_wav = torch.tensor(track.audio.T, dtype=torch.float32)
        y_target_wavs = {
            name: torch.tensor(track.targets[name].audio.T, dtype=torch.float32)
            for name in self.targets
        }
        # original audio (x) and stems (y == targets)
        return x_wav, y_target_wavs