# VoxSIM / predict.py
# junseok520's picture
# Update predict.py
# 5121459 verified
# raw
# history blame
# 4.26 kB
import argparse
import pathlib
import tqdm
from torch.utils.data import Dataset, DataLoader
from score import loadWav, Score
import torch
import os
import warnings
warnings.filterwarnings("ignore")
def get_arg() -> argparse.Namespace:
    """Parse the command-line options of the prediction script.

    Only ``--mode`` and ``--out_path`` are mandatory for the parser itself.
    The ``--inp_*`` / ``--ref_*`` options default to ``None``; which pair is
    actually needed depends on ``--mode`` and is not checked here.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()`` (reads
        ``sys.argv``). Path-like options are converted to ``pathlib.Path``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        required=True,
        choices=["predict_file", "predict_dir"],
        type=str,
        help="predict mode",
    )
    parser.add_argument(
        "--ckpt_path",
        required=False,
        default="voxsim_wavlm_ecapa.model",
        type=str,
        help="path to the model checkpoint",
    )
    parser.add_argument(
        "--inp_dir",
        required=False,
        default=None,
        type=pathlib.Path,
        help="input directory when predict_dir mode",
    )
    parser.add_argument(
        "--ref_dir",
        required=False,
        default=None,
        type=pathlib.Path,
        help="reference directory when predict_dir mode",
    )
    parser.add_argument(
        "--inp_path",
        required=False,
        default=None,
        type=pathlib.Path,
        help="input file when predict_file mode",
    )
    parser.add_argument(
        "--ref_path",
        required=False,
        default=None,
        type=pathlib.Path,
        help="reference file when predict_file mode",
    )
    parser.add_argument(
        "--out_path",
        required=True,
        type=pathlib.Path,
        help="output path",
    )
    parser.add_argument(
        "--num_workers",
        required=False,
        default=4,
        type=int,
        help="number of workers for dataloader",
    )
    return parser.parse_args()
class AudioDataset(Dataset):
    """Pairs of same-named ``.wav`` files from an input and a reference directory.

    The dataset is indexed over the ``.wav`` files found directly inside
    ``inp_dir_path`` (sorted by file name). Every one of them must also exist,
    under the same name, in ``ref_dir_path``; extra files in the reference
    directory are ignored.

    Args:
        inp_dir_path: Directory holding the input recordings.
        ref_dir_path: Directory holding the reference recordings.
        max_frames: Only used to derive ``max_audio``.

    Raises:
        ValueError: If some input ``.wav`` file has no counterpart in
            ``ref_dir_path``.
    """

    def __init__(self, inp_dir_path: pathlib.Path, ref_dir_path: pathlib.Path, max_frames: int = 400):
        self.inp_dir_path = inp_dir_path
        self.ref_dir_path = ref_dir_path

        inp_names = [name for name in os.listdir(inp_dir_path) if name.endswith(".wav")]
        ref_names = {name for name in os.listdir(ref_dir_path) if name.endswith(".wav")}

        # Fail early, listing every unmatched input, instead of crashing on
        # the first missing reference file in the middle of a run.
        diff = sorted(set(inp_names) - ref_names)
        if diff:
            raise ValueError(f"Files {diff} are in inp_dir but not in ref_dir.")

        # Sorted so that item order (and therefore output order) is stable.
        self.inp_wavlist = sorted(inp_names)
        # NOTE(review): not read anywhere in the code visible here; presumably a
        # sample budget (160-sample hop, 240-sample margin) — confirm before use.
        self.max_audio = max_frames * 160 + 240

    def __len__(self) -> int:
        """Return the number of input ``.wav`` files."""
        return len(self.inp_wavlist)

    def __getitem__(self, idx: int):
        """Load the ``idx``-th input file and its same-named reference.

        Returns:
            ``(inp_wavs, inp_wav, ref_wavs, ref_wav)``: the two values returned
            by ``loadWav`` for the input file followed by the two values
            returned for the reference file.
        """
        name = self.inp_wavlist[idx]
        inp_wavs, inp_wav = loadWav(os.path.join(self.inp_dir_path, name))
        ref_wavs, ref_wav = loadWav(os.path.join(self.ref_dir_path, name))
        return inp_wavs, inp_wav, ref_wavs, ref_wav
def _require_path(path, arg_name, mode, *, want_dir):
    """Check a mode-specific path argument, raising a descriptive error if unusable.

    Explicit exceptions are used instead of ``assert`` so the checks still run
    under ``python -O`` and always carry a message.

    Args:
        path: The parsed ``pathlib.Path`` value, or ``None`` when not given.
        arg_name: Option name used in the error message (e.g. ``"inp_path"``).
        mode: The selected ``--mode`` value, used in the error message.
        want_dir: ``True`` if ``path`` must be a directory, ``False`` if it
            must be a regular file.

    Raises:
        ValueError: If the option was not given, or is not a regular file
            when a file is expected.
        FileNotFoundError: If the path does not exist.
        NotADirectoryError: If a directory is expected but the path is not one.
    """
    if path is None:
        raise ValueError(f"{arg_name} is required when mode is {mode}.")
    if not path.exists():
        raise FileNotFoundError(f"{arg_name} does not exist: {path}")
    if want_dir:
        if not path.is_dir():
            raise NotADirectoryError(f"{arg_name} is not a directory: {path}")
    elif not path.is_file():
        raise ValueError(f"{arg_name} is not a regular file: {path}")


def main():
    """Score one pair of files or every pair in two directories.

    In ``predict_file`` mode a single score is printed and written to
    ``--out_path``. In ``predict_dir`` mode one score per line is written to
    ``--out_path`` (in sorted file-name order) and the average is printed.

    Raises:
        ValueError, FileNotFoundError, NotADirectoryError: If the path options
            required by the selected mode are missing or unusable.
    """
    args = get_arg()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if args.mode == "predict_file":
        _require_path(args.inp_path, "inp_path", args.mode, want_dir=False)
        _require_path(args.ref_path, "ref_path", args.mode, want_dir=False)

        inp_wavs, inp_wav = loadWav(args.inp_path)
        ref_wavs, ref_wav = loadWav(args.ref_path)

        scorer = Score(ckpt_path=args.ckpt_path, device=device)
        score = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
        print("VoxSIM score: ", score)

        with open(args.out_path, "w", encoding="utf-8") as fw:
            fw.write(str(score))
    else:
        _require_path(args.inp_dir, "inp_dir", args.mode, want_dir=True)
        _require_path(args.ref_dir, "ref_dir", args.mode, want_dir=True)

        dataset = AudioDataset(args.inp_dir, args.ref_dir)
        # One pair per batch, in dataset order, so output lines follow the
        # sorted file names.
        loader = DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            num_workers=args.num_workers,
        )

        scorer = Score(ckpt_path=args.ckpt_path, device=device)
        scores = []
        with open(args.out_path, "w", encoding="utf-8") as fw:
            for inp_wavs, inp_wav, ref_wavs, ref_wav in tqdm.tqdm(loader):
                score = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
                scores.append(score)
                fw.write(str(score) + "\n")

        if scores:
            print("Average VoxSIM score: ", sum(scores) / len(scores))
        else:
            # An input directory without .wav files yields no scores; avoid
            # dividing by zero and say why nothing was averaged.
            print(f"No .wav files found in {args.inp_dir}; nothing was scored.")

    print("save to ", args.out_path)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()