| | |
| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| | import os |
| | from collections import defaultdict, namedtuple |
| | from pathlib import Path |
| |
|
| | import musdb |
| | import numpy as np |
| | import torch as th |
| | import tqdm |
| | from torch.utils.data import DataLoader |
| |
|
| | from .audio import AudioFile |
| |
|
# Location of one fixed-length chunk inside a Rawset:
#   file_index  -- index into Rawset._entries (which entry/file the chunk is in)
#   offset      -- start position of the chunk, in samples, within that file
#   local_index -- chunk index within that file (0-based)
ChunkInfo = namedtuple("ChunkInfo", ["file_index", "offset", "local_index"])
| |
|
| |
|
class Rawset:
    """
    Dataset of raw, normalized, float32 audio files.

    Scans `path` recursively for files named "<name>.<stream>.raw", each
    holding interleaved float32 samples for `channels` channels. Every entry
    (folder, name) must provide the same set of stream indices, numbered
    contiguously from 0.

    Args:
        path (str or Path): root folder to scan (symlinks are followed).
        samples (int or None): chunk length in samples. Entries shorter than
            this are dropped. If None, each entry yields a single chunk
            covering its whole length.
        stride (int or None): hop between consecutive chunks; defaults to
            `samples` (non-overlapping chunks), or 0 when `samples` is None.
        channels (int): number of interleaved channels per frame.
        streams (list[int] or None): subset of stream indices to load
            (default: all streams found).

    Raises:
        ValueError: if no usable entry is found under `path`.
    """
    def __init__(self, path, samples=None, stride=None, channels=2, streams=None):
        self.path = Path(path)
        self.channels = channels
        self.samples = samples
        if stride is None:
            stride = samples if samples is not None else 0
        self.stride = stride

        # Group stream files by (folder relative to root, base name).
        entries = defaultdict(list)
        for root, folders, files in os.walk(self.path, followlinks=True):
            # Sort in place so the walk (and thus entry order) is deterministic.
            folders.sort()
            files.sort()
            for file in files:
                if file.endswith(".raw"):
                    path = Path(root) / file
                    name, stream = path.stem.rsplit('.', 1)
                    entries[(path.parent.relative_to(self.path), name)].append(int(stream))

        self._entries = list(entries.keys())

        sizes = []
        self._lengths = []
        # All entries must expose the same contiguous set of streams 0..k-1.
        ref_streams = sorted(entries[self._entries[0]])
        assert ref_streams == list(range(len(ref_streams)))
        if streams is None:
            self.streams = ref_streams
        else:
            self.streams = streams
        # Iterate in the same order as self._entries (not a sorted copy) so
        # that `sizes` and `_lengths` stay index-aligned with `_entries`,
        # even when short entries are removed below. A copy is iterated
        # because we may remove entries while looping.
        for entry in list(self._entries):
            entry_streams = entries[entry]
            assert sorted(entry_streams) == ref_streams
            file = self._path(*entry)
            # 4 bytes per float32 sample, `channels` samples per frame.
            length = file.stat().st_size // (4 * channels)
            if samples is None:
                sizes.append(1)
            else:
                if length < samples:
                    # Too short to yield even one chunk: drop the entry.
                    self._entries.remove(entry)
                    continue
                sizes.append((length - samples) // stride + 1)
            self._lengths.append(length)
        if not sizes:
            raise ValueError(f"Empty dataset {self.path}")
        self._cumulative_sizes = np.cumsum(sizes)
        self._sizes = sizes

    def __len__(self):
        """Total number of chunks across all entries."""
        return self._cumulative_sizes[-1]

    @property
    def total_length(self):
        """Total length in samples of all kept entries."""
        return sum(self._lengths)

    def chunk_info(self, index):
        """Map a flat chunk index to a ChunkInfo (file, sample offset, local index)."""
        file_index = np.searchsorted(self._cumulative_sizes, index, side='right')
        if file_index == 0:
            local_index = index
        else:
            local_index = index - self._cumulative_sizes[file_index - 1]
        return ChunkInfo(offset=local_index * self.stride,
                         file_index=file_index,
                         local_index=local_index)

    def _path(self, folder, name, stream=0):
        """Path of the raw file for one (folder, name) entry and stream index."""
        return self.path / folder / (name + f'.{stream}.raw')

    def __getitem__(self, index):
        """Read one chunk; returns a float32 tensor of shape (streams, channels, length)."""
        chunk = self.chunk_info(index)
        entry = self._entries[chunk.file_index]

        # Full file length when no fixed chunk size was requested.
        length = self.samples or self._lengths[chunk.file_index]
        # Byte offset and size: 4 bytes per float32, `channels` values per frame.
        # Both are loop-invariant, so compute them once outside the stream loop.
        byte_offset = chunk.offset * 4 * self.channels
        to_read = length * self.channels * 4
        streams = []
        for stream in self.streams:
            # Context manager so the handle is always closed (the previous
            # implementation leaked one open file per stream read).
            with open(self._path(*entry, stream=stream), 'rb') as file:
                file.seek(byte_offset)
                content = file.read(to_read)
            assert len(content) == to_read
            # Copy so the tensor owns writable memory, not the read buffer.
            content = np.frombuffer(content, dtype=np.float32).copy()
            streams.append(th.from_numpy(content).view(length, self.channels).t())
        return th.stack(streams, dim=0)

    def name(self, index):
        """Relative path (folder/name, without stream suffix) of the entry owning `index`."""
        chunk = self.chunk_info(index)
        folder, name = self._entries[chunk.file_index]
        return folder / name
| |
|
| |
|
class MusDBSet:
    """Expose a musdb.DB as an indexable dataset of (track name, audio) pairs.

    Each item decodes the track's file via AudioFile.read with the configured
    stream selection, sample rate and channel count.
    """

    def __init__(self, mus, streams=slice(None), samplerate=44100, channels=2):
        self.mus = mus
        self.streams = streams
        self.samplerate = samplerate
        self.channels = channels

    def __len__(self):
        """Number of tracks in the wrapped musdb database."""
        return len(self.mus.tracks)

    def __getitem__(self, index):
        """Decode track `index` and return (name, audio)."""
        track = self.mus.tracks[index]
        audio = AudioFile(track.path).read(
            channels=self.channels,
            seek_time=0,
            streams=self.streams,
            samplerate=self.samplerate)
        return (track.name, audio)
| |
|
| |
|
def build_raw(mus, destination, normalize, workers, samplerate, channels):
    """Decode every track of `mus` and dump each stream as a raw float32 file.

    Each track produces one file per stream, named "<track>.<index>.raw",
    containing (time, channels)-interleaved float32 samples.

    Args:
        mus: a musdb.DB instance whose tracks are converted.
        destination (Path): output folder, created if missing.
        normalize (bool): if True, center and scale every stream using the
            mean/std of the mono mix of stream 0 (the mixture).
        workers (int): number of DataLoader worker processes used to decode.
        samplerate (int): target sample rate passed to AudioFile.read.
        channels (int): number of channels to keep.
    """
    destination.mkdir(parents=True, exist_ok=True)
    loader = DataLoader(MusDBSet(mus, channels=channels, samplerate=samplerate),
                        batch_size=1,
                        num_workers=workers,
                        collate_fn=lambda x: x[0])
    for name, streams in tqdm.tqdm(loader):
        if normalize:
            ref = streams[0].mean(dim=0)  # mono mixture of stream 0
            streams = (streams - ref.mean()) / ref.std()
        for index, stream in enumerate(streams):
            # Context manager so each output file is closed (and flushed)
            # deterministically; the previous code leaked the handle.
            # Transpose to (time, channels) so samples are interleaved on disk.
            with open(destination / (name + f'.{index}.raw'), "wb") as f:
                f.write(stream.t().numpy().tobytes())
| |
|
| |
|
def main():
    """Command-line entry point: convert MusDB train/valid splits to raw datasets."""
    parser = argparse.ArgumentParser('rawset')
    parser.add_argument('--workers', type=int, default=10)
    parser.add_argument('--samplerate', type=int, default=44100)
    parser.add_argument('--channels', type=int, default=2)
    parser.add_argument('musdb', type=Path)
    parser.add_argument('destination', type=Path)
    args = parser.parse_args()

    # Both splits come from the musdb "train" subset; only the split differs.
    for split in ("train", "valid"):
        build_raw(musdb.DB(root=args.musdb, subsets=["train"], split=split),
                  args.destination / split,
                  normalize=True,
                  samplerate=args.samplerate,
                  channels=args.channels,
                  workers=args.workers)
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|