Spaces:
Sleeping
Sleeping
File size: 1,753 Bytes
541c6b7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import csv
from pathlib import Path
import argbind
import torch
from audiotools.core import util
from audiotools.ml.decorators import Tracker
from train import Accelerator
import scripts.train as train
@torch.no_grad()
def process(batch, accel, test_data):
    """Apply the dataset's transform to one item and return the result via ``.cpu()``.

    Runs with autograd disabled (``torch.no_grad``).

    Args:
        batch: A single item of ``test_data``; must provide ``"signal"`` and
            ``"transform_args"`` keys.
        accel: Object exposing ``device``; forwarded to
            ``util.prepare_batch`` (presumably moves the item onto that
            device -- confirm against audiotools).
        test_data: Dataset whose ``transform`` is applied to a clone of the
            item's signal, with ``transform_args`` passed as keyword arguments.

    Returns:
        The transformed signal after calling ``.cpu()`` on it.
    """
    prepared = util.prepare_batch(batch, accel.device)
    # Clone first, presumably so an in-place transform cannot alter the
    # signal still referenced by ``prepared``.
    source_signal = prepared["signal"].clone()
    transformed = test_data.transform(source_signal, **prepared["transform_args"])
    return transformed.cpu()
@argbind.bind(without_prefix=True)
@torch.no_grad()
def save_test_set(args, accel, sample_rate: int = 44100, output: str = "samples/input"):
    """Render the test split to WAV files and record where each one came from.

    The dataset is built by ``train.build_dataset`` inside the ``"test"``
    argbind scope. Each item is passed through ``process`` and written to
    ``<output>/sample_<i>.wav``. A ``<output>/metadata.csv`` file with columns
    ``path`` and ``original`` maps every written file to
    ``signal.path_to_input_file``.

    Args:
        args: Parsed argbind arguments; used to open the ``"test"`` scope.
        accel: Accelerator whose ``device`` is forwarded to ``process``.
        sample_rate: Sample rate passed to ``train.build_dataset``.
        output: Directory that receives both the WAV files and
            ``metadata.csv``. Created if missing.
    """
    tracker = Tracker()
    with argbind.scope(args, "test"):
        test_data = train.build_dataset(sample_rate)

    # Wrap into a local instead of rebinding the module-level ``process`` via
    # ``global``: the old approach stacked another tracker wrapper (tied to a
    # stale Tracker) onto ``process`` on every call.
    tracked_process = tracker.track("process", len(test_data))(process)

    output = Path(output)
    output.mkdir(parents=True, exist_ok=True)

    # newline="" is required by the csv module so rows are not followed by
    # spurious blank lines; utf-8 keeps non-ASCII dataset paths writable
    # regardless of the platform's locale encoding.
    with open(output / "metadata.csv", "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=["path", "original"])
        writer.writeheader()

        with tracker.live:
            # Index explicitly: the dataset is accessed via __len__/__getitem__.
            for i in range(len(test_data)):
                signal = tracked_process(test_data[i], accel, test_data)
                # Samples go into ``output`` itself, next to metadata.csv
                # (previously a hard-coded sibling ``input/`` directory).
                input_path = output / f"sample_{i}.wav"
                writer.writerow(
                    {
                        "path": str(input_path),
                        "original": str(signal.path_to_input_file),
                    }
                )
                signal.write(input_path)
            tracker.done("test", f"N={len(test_data)}")
if __name__ == "__main__":
    # Script entry point: parse CLI/config arguments, then run the export
    # inside their argbind scope with an Accelerator. In a multi-item ``with``
    # the Accelerator() expression is evaluated only after the scope has been
    # entered, exactly as with nested ``with`` statements.
    args = argbind.parse_args()
    with argbind.scope(args), Accelerator() as accel:
        save_test_set(args, accel)
|