Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,537 Bytes
fc605f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import os
from subprocess import check_call
import torchaudio
from datasets import load_dataset
from torch.utils.data import Dataset
from torchcodec.decoders import AudioDecoder
def cache_file(url, outfile):
    """Download ``url`` to ``outfile`` with curl unless ``outfile`` already exists.

    The payload is written to ``outfile + ".tmp"`` first and renamed to
    ``outfile`` only after curl exits successfully, so an interrupted or
    failed transfer never leaves a partial file at ``outfile``.

    Args:
        url: Address handed to ``curl --url``.
        outfile: Destination path. Its parent directory is created if missing.

    Raises:
        subprocess.CalledProcessError: If curl exits with a non-zero status.
            Because of ``--fail`` this includes HTTP error responses.
        OSError: If the ``curl`` executable cannot be started or the final
            rename fails.
    """
    if os.path.exists(outfile):
        return

    print("Downloading musdb18hq dataset...")

    # os.path.dirname() is "" for a bare filename and os.makedirs("") raises,
    # so only create the parent directory when there is one.
    parent_dir = os.path.dirname(outfile)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    partial = outfile + ".tmp"
    check_call(
        [
            "curl",
            # Without --fail curl exits 0 on HTTP errors and saves the error
            # page, which would then be renamed into place as the "download".
            "--fail",
            # Follow redirects instead of saving the redirect response body.
            "--location",
            "--url",
            url,
            "--output",
            partial,
        ]
    )
    os.rename(partial, outfile)
class MUSDB(Dataset):
def __init__(
self,
collate_fn,
sample_rate: int = 48_000,
cache_path: str = os.path.expanduser("~/.cache/sam_audio"),
):
self.cache_path = os.path.join(cache_path, "musdb18hq")
self.ds = self.get_dataset(cache_path)
self.captions = ["bass", "drums", "vocals"]
self.collate_fn = collate_fn
self.sample_rate = sample_rate
@property
def visual(self):
return False
def get_dataset(self, cache_path):
zip_file = os.path.join(cache_path, "musdb18hq.zip")
url = "https://zenodo.org/records/3338373/files/musdb18hq.zip?download=1"
cache_file(url, zip_file)
extracted_dir = os.path.join(cache_path, "musdb18hq")
if not os.path.exists(extracted_dir):
check_call(["unzip", zip_file, "-d", extracted_dir + ".tmp"])
os.rename(extracted_dir + ".tmp", extracted_dir)
return load_dataset("facebook/sam-audio-musdb18hq-test")["test"]
def __len__(self):
return len(self.ds)
def collate(self, items):
audios, descriptions = zip(*items, strict=False)
return self.collate_fn(
audios=audios,
descriptions=descriptions,
)
def __getitem__(self, idx):
item = self.ds[idx]
path = os.path.join(self.cache_path, "test", item["id"], "mixture.wav")
assert os.path.exists(path), f"{path} does not exist!"
decoder = AudioDecoder(path)
data = decoder.get_samples_played_in_range(item["start_time"], item["end_time"])
wav = data.data
if data.sample_rate != self.sample_rate:
wav = torchaudio.functional.resample(
wav, data.sample_rate, self.sample_rate
)
wav = wav.mean(0, keepdim=True)
return wav, item["description"]
if __name__ == "__main__":
    # Manual smoke run: build the dataset with a collate function that ignores
    # its input, then print the dataset size followed by the first sample.
    musdb = MUSDB(collate_fn=lambda **_ignored: None)
    print(len(musdb))
    first_sample = musdb[0]
    print(first_sample)
|