Spaces:
Sleeping
Sleeping
File size: 2,297 Bytes
5f6b40b 9583919 5f6b40b 9583919 5f6b40b 9583919 5f6b40b 9583919 5f6b40b 9583919 5f6b40b 9583919 5f6b40b 9583919 5f6b40b 9583919 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 | from pathlib import Path
import argbind
import audiotools as at
import torch
import tqdm
from audiotools import AudioSignal
from torch_pitch_shift import get_fast_shifts, pitch_shift
from torch_time_stretch import get_fast_stretches, time_stretch
@argbind.bind(without_prefix=True)
def augment(
    audio_folder: Path = None,
    dest_folder: Path = None,
    n_augmentations: int = 10,
):
    """
    Augment a folder of audio files with random pitch shifts and time stretches.

    Every audio file found under ``audio_folder`` is cut into windows via
    ``AudioSignal.windows(10, 10)`` (10 s window and hop durations). Each
    window is augmented ``n_augmentations`` times. Each augmentation applies a
    randomly chosen pitch-shift ratio (``torch_pitch_shift.pitch_shift``),
    then a randomly chosen time-stretch ratio
    (``torch_time_stretch.time_stretch``).

    The output layout mirrors the input tree. For
    ``audio_folder/a/b/song.wav`` the augmented windows are written to
    ``dest_folder/a/b/song/{i}-{j}.wav``, where ``i`` is the window index and
    ``j`` the augmentation index. Only augmented windows are written; the
    unmodified audio is not copied to ``dest_folder``.

    Parameters
    ----------
    audio_folder : Path
        Root folder searched (via ``audiotools.util.find_audio``) for input
        audio. A plain ``str`` is also accepted.
    dest_folder : Path
        Root folder that receives the augmented files. Subfolders are created
        as needed. A plain ``str`` is also accepted.
    n_augmentations : int
        Number of augmented copies written per window.

    Raises
    ------
    ValueError
        If ``audio_folder`` or ``dest_folder`` is not provided.
    """
    import random

    # Explicit checks instead of ``assert``: asserts are stripped under ``-O``.
    if audio_folder is None:
        raise ValueError("audio_folder must be provided")
    if dest_folder is None:
        raise ValueError("dest_folder must be provided")
    # The path arithmetic below (``/`` and ``relative_to``) needs Path objects.
    audio_folder = Path(audio_folder)
    dest_folder = Path(dest_folder)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    audio_files = at.util.find_audio(audio_folder)
    for audio_file in tqdm.tqdm(audio_files):
        subtree = dest_folder / audio_file.relative_to(audio_folder).parent
        subdir = subtree / audio_file.stem
        subdir.mkdir(parents=True, exist_ok=True)

        src = AudioSignal(audio_file).to(device)
        sample_rate = src.sample_rate

        # The candidate ratios depend only on the sample rate, so build the
        # lists once per file rather than once per augmentation.
        # Pitch-shift ratios are kept in [0.25, 1.0]. This includes 1 (no
        # shift). Assuming torch_pitch_shift's ratio = frequency multiplier
        # convention, the range is up to two octaves *down*, not +/-7
        # semitones.
        shifts = get_fast_shifts(
            sample_rate, condition=lambda x: 0.25 <= x <= 1.0
        )
        # Time-stretch ratios are kept in [0.667, 1.5].
        stretches = get_fast_stretches(
            sample_rate, condition=lambda x: 0.667 <= x <= 1.5
        )

        for i, chunk in tqdm.tqdm(enumerate(src.windows(10, 10))):
            for j in range(n_augmentations):
                # Clone so every augmentation starts from the untouched window.
                dst = chunk.clone()
                dst.samples = pitch_shift(
                    dst.samples,
                    shift=random.choice(shifts),
                    sample_rate=sample_rate,
                )
                dst.samples = time_stretch(
                    dst.samples,
                    stretch=random.choice(stretches),
                    sample_rate=sample_rate,
                )
                dst.cpu().write(subdir / f"{i}-{j}.wav")
if __name__ == "__main__":
    # Read the command-line flags into an argbind config. Inside the scope,
    # that config supplies the parameters of every @argbind.bind function,
    # so ``augment()`` runs with the CLI-provided folders and counts.
    cli_config = argbind.parse_args()
    with argbind.scope(cli_config):
        augment()
|