| | |
| | |
| | |
| | |
| | |
| |
|
| | import json |
| | from fractions import Fraction |
| | from concurrent import futures |
| |
|
| | import musdb |
| | from torch import distributed |
| |
|
| | from .audio import AudioFile |
| |
|
| |
|
def get_musdb_tracks(root, *args, **kwargs):
    """Map every track name of a MusDB database to the path of its file.

    Args:
        root: passed as first argument to `musdb.DB`.
        *args: extra positional arguments forwarded to `musdb.DB`.
        **kwargs: extra keyword arguments forwarded to `musdb.DB`.

    Returns:
        A dict `{track.name: track.path}` built by iterating over the database.
    """
    database = musdb.DB(root, *args, **kwargs)
    paths_by_name = {}
    for track in database:
        paths_by_name[track.name] = track.path
    return paths_by_name
| |
|
| |
|
class StemsSet:
    """Indexable collection of normalized audio excerpts taken from a set of tracks.

    Tracks are ordered by name. When `duration` is given, each track yields
    several excerpts of that length, one every `stride`; otherwise each track
    yields a single example read without an explicit duration.
    """

    def __init__(self, tracks, metadata, duration=None, stride=1,
                 samplerate=44100, channels=2, streams=slice(None)):
        """
        Args:
            tracks: mapping from track name to the path of its audio file.
            metadata: mapping from track name to a dict with at least the keys
                `"duration"`, `"mean"` and `"std"`. Entries are copied, the
                input mapping is left untouched.
            duration: length of one excerpt, in the same unit as the
                `"duration"` metadata entries, or `None` for one example per track.
            stride: offset between the starts of two consecutive excerpts of
                a track, in the same unit as `duration`.
            samplerate: forwarded to `AudioFile.read`.
            channels: forwarded to `AudioFile.read`.
            streams: forwarded to `AudioFile.read`.

        Raises:
            KeyError: if a track has no entry in `metadata`.
            ValueError: if a track is shorter than `duration`.
        """
        self.metadata = []
        for name, path in tracks.items():
            # Copy so that the "path" / "name" keys never leak into the caller's dict.
            meta = dict(metadata[name])
            meta["path"] = path
            meta["name"] = name
            if duration is not None and meta["duration"] < duration:
                raise ValueError(f"Track {name} duration is too small {meta['duration']}")
            self.metadata.append(meta)
        # Sorting by name keeps the index -> excerpt mapping independent of
        # the iteration order of `tracks`.
        self.metadata.sort(key=lambda x: x["name"])
        self.duration = duration
        self.stride = stride
        self.channels = channels
        self.samplerate = samplerate
        self.streams = streams

    def __len__(self):
        """Return the total number of examples over all tracks."""
        return sum(self._examples_count(m) for m in self.metadata)

    def _examples_count(self, meta):
        """Return how many examples the track described by `meta` provides."""
        if self.duration is None:
            return 1
        return int((meta["duration"] - self.duration) // self.stride + 1)

    def _locate(self, index):
        """Resolve a global example index to `(track metadata, index within the track)`.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: if `index` does not designate an existing example.
        """
        position = index
        if position < 0:
            position += len(self)
        if position >= 0:
            for meta in self.metadata:
                examples = self._examples_count(meta)
                if position < examples:
                    return meta, position
                position -= examples
        # Raising (rather than returning None) is also what lets plain
        # `for example in dataset` loops terminate.
        raise IndexError(f"StemsSet index {index} is out of range")

    def track_metadata(self, index):
        """Return the metadata dict of the track that example `index` comes from.

        Raises:
            IndexError: if `index` is out of range.
        """
        meta, _ = self._locate(index)
        return meta

    def __getitem__(self, index):
        """Read example `index` and normalize it with the track's mean and std.

        Raises:
            IndexError: if `index` is out of range.
        """
        meta, offset = self._locate(index)
        streams = AudioFile(meta["path"]).read(seek_time=offset * self.stride,
                                               duration=self.duration,
                                               channels=self.channels,
                                               samplerate=self.samplerate,
                                               streams=self.streams)
        return (streams - meta["mean"]) / meta["std"]
| |
|
| |
|
def _get_track_metadata(path):
    """Compute the duration and normalization statistics of one audio file.

    Stream 0 of the file is read with a single channel at 44100 Hz, and its
    mean and standard deviation are taken over that signal.

    Returns:
        A dict with the keys `"duration"`, `"std"` and `"mean"`.
    """
    audio_file = AudioFile(path)
    reference = audio_file.read(streams=0, channels=1, samplerate=44100)
    stats = {"duration": audio_file.duration}
    stats["std"] = reference.std().item()
    stats["mean"] = reference.mean().item()
    return stats
| |
|
| |
|
def _build_metadata(tracks, workers=10):
    """Compute the metadata of every track using a pool of worker processes.

    Args:
        tracks: mapping from track name to audio file path.
        workers: number of processes given to the `ProcessPoolExecutor`.

    Returns:
        A dict mapping each track name to the result of `_get_track_metadata`.
    """
    with futures.ProcessPoolExecutor(workers) as pool:
        jobs = {
            name: pool.submit(_get_track_metadata, path)
            for name, path in tracks.items()
        }
    # Leaving the `with` block waits for all jobs, so results are ready here.
    return {name: job.result() for name, job in jobs.items()}
| |
|
| |
|
def _build_musdb_metadata(path, musdb, workers):
    """Compute the metadata of all MusDB tracks and store it as JSON at `path`.

    Args:
        path: destination file (a `pathlib.Path`-like object); its parent
            directories are created if needed.
        musdb: root of the MusDB dataset, forwarded to `get_musdb_tracks`.
        workers: number of worker processes used to analyse the tracks.
    """
    tracks = get_musdb_tracks(musdb)
    metadata = _build_metadata(tracks, workers)
    path.parent.mkdir(exist_ok=True, parents=True)
    # Close the file explicitly so the content is flushed to disk as soon as
    # this function returns, instead of relying on garbage collection.
    with open(path, "w") as handle:
        json.dump(metadata, handle)
| |
|
| |
|
def get_compressed_datasets(args, samples):
    """Build the training and validation `StemsSet` for MusDB.

    The metadata file `args.metadata / "musdb.json"` is created by rank 0 if
    it does not exist yet; when `args.world_size > 1` all processes wait on a
    barrier before reading it.

    Args:
        args: configuration object providing `metadata`, `rank`, `musdb`,
            `workers`, `world_size`, `samplerate`, `data_stride` and
            `audio_channels`.
        samples: length of one training excerpt, in samples at `args.samplerate`.

    Returns:
        A `(train_set, valid_set)` tuple. The training set yields excerpts of
        `samples` samples every `args.data_stride` samples and skips stream 0
        (presumably the mixture; confirm against the audio file layout). The
        validation set yields one example per track with all streams.
    """
    metadata_file = args.metadata / "musdb.json"
    if not metadata_file.is_file() and args.rank == 0:
        _build_musdb_metadata(metadata_file, args.musdb, args.workers)
    if args.world_size > 1:
        # Other ranks must not read the file before rank 0 has written it.
        distributed.barrier()
    with open(metadata_file) as handle:
        metadata = json.load(handle)
    # Exact fractions avoid float rounding when converting samples to time.
    duration = Fraction(samples, args.samplerate)
    stride = Fraction(args.data_stride, args.samplerate)
    train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"),
                         metadata,
                         duration=duration,
                         stride=stride,
                         streams=slice(1, None),
                         samplerate=args.samplerate,
                         channels=args.audio_channels)
    valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"),
                         metadata,
                         samplerate=args.samplerate,
                         channels=args.audio_channels)
    return train_set, valid_set
| |
|