Spaces:
Sleeping
Sleeping
| import argparse | |
| import pathlib | |
| import tqdm | |
| from torch.utils.data import Dataset, DataLoader | |
| import librosa | |
| import numpy | |
| from score import Score | |
| import torch | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
def get_arg():
    """Parse the command-line options for the scoring script.

    Returns:
        argparse.Namespace carrying: bs, mode, ckpt_path, inp_dir,
        ref_dir, inp_path, ref_path, out_path, num_workers.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--bs", required=False, default=None, type=int)
    parser.add_argument("--mode", required=True,
                        choices=["predict_file", "predict_dir"], type=str)
    parser.add_argument("--ckpt_path", required=False,
                        default="wavlm_ecapa.model", type=pathlib.Path)
    # The four optional path options share the same shape; register in bulk.
    for opt in ("--inp_dir", "--ref_dir", "--inp_path", "--ref_path"):
        parser.add_argument(opt, required=False, default=None, type=pathlib.Path)
    parser.add_argument("--out_path", required=True, type=pathlib.Path)
    parser.add_argument("--num_workers", required=False, default=0, type=int)
    return parser.parse_args()
def loadWav(filename, max_frames: int = 400):
    """Load a waveform and slice it into fixed-length evaluation windows.

    Args:
        filename: Either a path to a wav file (decoded at 16 kHz) or an
            ``(sr, audio)`` tuple holding an already-decoded signal,
            which is peak-normalized before use.
        max_frames: Window length in 10 ms frames; each window spans
            ``max_frames * 160 + 240`` samples.

    Returns:
        A tuple ``(windows, full)``: ``windows`` is a (10, max_audio)
        FloatTensor of evenly spaced crops and ``full`` is a
        (1, num_samples) FloatTensor with the whole (pre-padding) signal.
    """
    # Maximum audio length of one evaluation window, in samples.
    max_audio = max_frames * 160 + 240
    # isinstance() instead of ``type(...) == tuple``: idiomatic, and also
    # accepts tuple subclasses such as collections.namedtuple.
    if isinstance(filename, tuple):
        sr, audio = filename
        audio = librosa.util.normalize(audio)
    else:
        audio, sr = librosa.load(filename, sr=16000)
    audio_org = audio.copy()
    audiosize = audio.shape[0]
    if audiosize <= max_audio:
        # Pad by wrapping so a short clip still fills one full window.
        shortage = max_audio - audiosize + 1
        audio = numpy.pad(audio, (0, shortage), 'wrap')
        audiosize = audio.shape[0]
    # Ten evenly spaced crop offsets covering the whole padded signal.
    startframe = numpy.linspace(0, audiosize - max_audio, num=10)
    feats = [audio[int(asf):int(asf) + max_audio] for asf in startframe]
    feat = numpy.stack(feats, axis=0).astype(numpy.float32)
    return torch.FloatTensor(feat), torch.FloatTensor(
        numpy.stack([audio_org], axis=0).astype(numpy.float32))
class AudioDataset(Dataset):
    """Paired input/reference wav dataset for batched scoring.

    The two directories must hold the same number of ``*.wav`` files;
    pairs are matched by sorted filename order.
    """

    def __init__(self, inp_dir_path: pathlib.Path, ref_dir_path: pathlib.Path, max_frames: int = 400):
        self.inp_wavlist = sorted(inp_dir_path.glob("*.wav"))
        self.ref_wavlist = sorted(ref_dir_path.glob("*.wav"))
        assert len(self.inp_wavlist) == len(self.ref_wavlist)
        # Probe the corpus sample rate from the first input file.
        _, self.sr = librosa.load(self.inp_wavlist[0], sr=None)
        self.max_audio = max_frames * 160 + 240

    def __len__(self):
        return len(self.inp_wavlist)

    def __getitem__(self, idx):
        inp_wavs, inp_wav = loadWav(self.inp_wavlist[idx])
        ref_wavs, ref_wav = loadWav(self.ref_wavlist[idx])
        return inp_wavs, inp_wav, ref_wavs, ref_wav
def main():
    """Entry point: score one wav pair or a directory of paired wavs."""
    args = get_arg()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.mode == "predict_file":
        # Single-pair mode: both file paths must be present and the
        # directory options absent.
        assert args.inp_path is not None
        assert args.ref_path is not None
        assert args.inp_dir is None
        assert args.ref_dir is None
        assert args.inp_path.exists()
        assert args.inp_path.is_file()
        assert args.ref_path.exists()
        assert args.ref_path.is_file()
        inp_wavs, inp_wav = loadWav(args.inp_path)
        ref_wavs, ref_wav = loadWav(args.ref_path)
        scorer = Score(ckpt_path=args.ckpt_path, device=device)
        score = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
        print("Voxsim score: ", score[0])
        with open(args.out_path, "w") as fw:
            fw.write(str(score[0]))
    else:
        # Directory mode: batch over two directories of paired wavs.
        assert args.inp_dir is not None, "inp_dir is required when mode is predict_dir."
        assert args.ref_dir is not None, "ref_dir is required when mode is predict_dir."
        assert args.bs is not None, "bs is required when mode is predict_dir."
        assert args.inp_path is None, "inp_path should be None"
        assert args.ref_path is None, "ref_path should be None"
        assert args.inp_dir.exists()
        assert args.ref_dir.exists()
        assert args.inp_dir.is_dir()
        assert args.ref_dir.is_dir()
        dataset = AudioDataset(args.inp_dir, args.ref_dir)
        loader = DataLoader(
            dataset,
            batch_size=args.bs,
            shuffle=False,
            num_workers=args.num_workers)
        scorer = Score(ckpt_path=args.ckpt_path, device=device)
        # Truncate any previous results before appending batch by batch.
        with open(args.out_path, 'w'):
            pass
        for batch in tqdm.tqdm(loader):
            # BUG FIX: the original did ``scores = score.score(batch.to(device))``.
            # ``score`` is undefined in this branch (the object is ``scorer``),
            # and a default-collated DataLoader batch is a list of four
            # tensors, which has no ``.to()``. Unpack and call the scorer the
            # same way the predict_file branch does; device placement is
            # presumably handled inside Score — TODO confirm.
            inp_wavs, inp_wav, ref_wavs, ref_wav = batch
            scores = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
            with open(args.out_path, 'a') as fw:
                for s in scores:
                    fw.write(str(s) + "\n")
        print("save to ", args.out_path)


if __name__ == "__main__":
    main()