Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
| import os | |
| from subprocess import check_call | |
| import torchaudio | |
| from datasets import load_dataset | |
| from torch.utils.data import Dataset | |
| from torchcodec.decoders import AudioDecoder | |
def cache_file(url, outfile):
    """Download ``url`` to ``outfile`` with curl unless ``outfile`` already exists.

    The transfer is written to ``outfile + ".tmp"`` and renamed to ``outfile``
    only after curl exits successfully, so an interrupted download never
    leaves a partial file under the final name.

    Args:
        url: Address handed to ``curl --url``.
        outfile: Destination path. Missing parent directories are created.

    Raises:
        subprocess.CalledProcessError: If curl exits with a non-zero status.
            Because of ``--fail`` this includes HTTP error responses.
    """
    if os.path.exists(outfile):
        return

    print("Downloading musdb18hq dataset...")

    # A bare filename has an empty dirname, and os.makedirs("") raises.
    parent_dir = os.path.dirname(outfile)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    partial = outfile + ".tmp"
    check_call(
        [
            "curl",
            # Without --fail curl exits 0 on HTTP errors and the error page
            # would be renamed into place as if it were the real file.
            "--fail",
            # Follow redirects instead of saving the redirect response body.
            "--location",
            "--url",
            url,
            "--output",
            partial,
        ]
    )
    os.rename(partial, outfile)
| class MUSDB(Dataset): | |
| def __init__( | |
| self, | |
| collate_fn, | |
| sample_rate: int = 48_000, | |
| cache_path: str = os.path.expanduser("~/.cache/sam_audio"), | |
| ): | |
| self.cache_path = os.path.join(cache_path, "musdb18hq") | |
| self.ds = self.get_dataset(cache_path) | |
| self.captions = ["bass", "drums", "vocals"] | |
| self.collate_fn = collate_fn | |
| self.sample_rate = sample_rate | |
| def visual(self): | |
| return False | |
| def get_dataset(self, cache_path): | |
| zip_file = os.path.join(cache_path, "musdb18hq.zip") | |
| url = "https://zenodo.org/records/3338373/files/musdb18hq.zip?download=1" | |
| cache_file(url, zip_file) | |
| extracted_dir = os.path.join(cache_path, "musdb18hq") | |
| if not os.path.exists(extracted_dir): | |
| check_call(["unzip", zip_file, "-d", extracted_dir + ".tmp"]) | |
| os.rename(extracted_dir + ".tmp", extracted_dir) | |
| return load_dataset("facebook/sam-audio-musdb18hq-test")["test"] | |
| def __len__(self): | |
| return len(self.ds) | |
| def collate(self, items): | |
| audios, descriptions = zip(*items, strict=False) | |
| return self.collate_fn( | |
| audios=audios, | |
| descriptions=descriptions, | |
| ) | |
| def __getitem__(self, idx): | |
| item = self.ds[idx] | |
| path = os.path.join(self.cache_path, "test", item["id"], "mixture.wav") | |
| assert os.path.exists(path), f"{path} does not exist!" | |
| decoder = AudioDecoder(path) | |
| data = decoder.get_samples_played_in_range(item["start_time"], item["end_time"]) | |
| wav = data.data | |
| if data.sample_rate != self.sample_rate: | |
| wav = torchaudio.functional.resample( | |
| wav, data.sample_rate, self.sample_rate | |
| ) | |
| wav = wav.mean(0, keepdim=True) | |
| return wav, item["description"] | |
if __name__ == "__main__":
    # Manual smoke test: build the dataset with a collate function that
    # ignores its input, then print the dataset size and the first example.
    def _discard_batch(**kwargs):
        return None

    musdb = MUSDB(_discard_batch)
    print(len(musdb))
    print(musdb[0])