|
|
from pathlib import Path |
|
|
import random |
|
|
import shutil |
|
|
import os |
|
|
import json |
|
|
|
|
|
import argbind |
|
|
from tqdm import tqdm |
|
|
from tqdm.contrib.concurrent import thread_map |
|
|
|
|
|
from audiotools.core import util |
|
|
|
|
|
|
|
|
@argbind.bind(without_prefix=True)
def train_test_split(
    audio_folder: str = ".",
    test_size: float = 0.2,
    seed: int = 42,
):
    """Split the audio files under a folder into train and test sets.

    The files found by ``util.find_audio`` are shuffled with a seeded RNG and
    partitioned. After an interactive confirmation, each split is
    materialised as a sibling folder ``<audio_folder name>-<split>`` filled
    with symlinks to the original files, and the list of files in each split
    is written to ``<audio_folder>/<split>.json``.

    Parameters
    ----------
    audio_folder : str, optional
        Folder that is searched for audio files, by default ".".
    test_size : float, optional
        Fraction of the files assigned to the test split, by default 0.2.
        The test count is truncated towards zero.
    seed : int, optional
        Seed for the shuffle. Note that this reseeds the global ``random``
        module state. By default 42.

    Raises
    ------
    ValueError
        If ``test_size`` is not within ``[0, 1]``.
    """
    # Fail fast: a fraction outside [0, 1] would yield a negative split size
    # and silently produce nonsensical slices below.
    if not 0.0 <= test_size <= 1.0:
        raise ValueError(f"test_size must be between 0 and 1, got {test_size}")

    print("finding audio")

    audio_folder = Path(audio_folder)
    # NOTE(review): assumes find_audio returns a mutable list of paths
    # (str or Path); it is shuffled in place below -- confirm.
    audio_files = util.find_audio(audio_folder)
    print(f"found {len(audio_files)} audio files")

    # Split sizes according to test_size.
    n_test = int(len(audio_files) * test_size)
    n_train = len(audio_files) - n_test

    # Seeded in-place shuffle so the same seed reproduces the same split.
    random.seed(seed)
    random.shuffle(audio_files)

    train_files = audio_files[:n_train]
    test_files = audio_files[n_train:]

    print(f"Train files: {len(train_files)}")
    print(f"Test files: {len(test_files)}")

    # An empty answer counts as "n"; surrounding whitespace and letter case
    # are ignored so that " Y" is accepted as confirmation.
    answer = input("Continue [yn]? ").strip().lower() or "n"
    if answer != "y":
        return

    # Normalise lexically (without following symlinks) so that "." and ".."
    # have a real folder name; Path(".").name is "" and would otherwise
    # produce output folders called "-train" / "-test".
    base_folder = Path(os.path.abspath(audio_folder))

    for split, files in (("train", train_files), ("test", test_files)):
        out_dir = base_folder.parent / f"{base_folder.name}-{split}"
        # Loop-invariant: create the split folder once, not once per file.
        out_dir.mkdir(exist_ok=True, parents=True)

        for file in tqdm(files):
            # Only the base name is kept, so files with the same name in
            # different subfolders collide; later ones are reported and
            # skipped by the handler below.
            out_file = out_dir / Path(file).name
            try:
                # The target must be absolute: a relative symlink target is
                # resolved relative to the directory containing the link,
                # not the current working directory, which would leave the
                # links in the sibling folder dangling.
                os.symlink(os.path.abspath(file), out_file)
            except FileExistsError:
                print(f"File {out_file} already exists, skipping")

        # Save the split as JSON next to the source audio.
        with open(audio_folder / f"{split}.json", "w") as manifest:
            json.dump([str(path) for path in files], manifest)
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse the command line with argbind, then run the split inside the
    # scope of the parsed arguments so the bound function receives them.
    parsed_args = argbind.parse_args()
    with argbind.scope(parsed_args):
        train_test_split()