# Evaluate reconstructed speech against its reference recordings with several
# objective metrics: UTMOS, PESQ, periodicity/pitch F1 and STOI.
import os
import glob
from UTMOS import UTMOSScore
from periodicity import calculate_periodicity_metrics
import torchaudio
from pesq import pesq
import numpy as np
import torch
import math
from pystoi import stoi

device = torch.device('cuda:0')


def _mean(total, count):
    """Return ``total / count``, or NaN when ``count`` is zero.

    Keeps the summary printout from raising ZeroDivisionError when the
    prediction directory is empty or every F1 score was filtered out.
    """
    return total / count if count else float('nan')


# For LJSpeech: change the paths, change the reference-file lookup and change
# the STOI sample rate (see the notes inside main()).
def main():
    """Score every file in the prediction directory against its reference.

    For each predicted file the matching reference is loaded, both signals
    are resampled to 16 kHz, and UTMOS, PESQ (wide-band), periodicity F1 and
    STOI are computed and printed. Sums and averages are printed at the end.
    F1 values that are NaN are counted separately and excluded from the F1
    average.
    """
    prepath = "./Result/Minicodec/infer/dac_nq4_all"
    rawpath = "./Data/libritts/test-clean"
    # rawpath = "./Data/LJSpeech-1.1/wavs"

    preaudio = os.listdir(prepath)
    utmos_scorer = UTMOSScore(device='cuda:0')

    # LibriTTS layout: the first two '_'-separated fields of the file name
    # are the sub-directories under rawpath that hold the reference file.
    rawaudio = []
    for name in preaudio:
        fields = name.split('_')
        rawaudio.append(rawpath + "/" + fields[0] + "/" + fields[1] + "/" + name)
    # LJSpeech layout (flat directory) would instead be:
    #     rawaudio = [rawpath + "/" + name for name in preaudio]

    utmos_sumgt = 0
    utmos_sumencodec = 0
    pesq_sumpre = 0
    f1score_sumpre = 0
    stoi_sumpre = []
    f1score_filt = 0  # number of files whose F1 score came back as NaN

    for i, (pre_name, raw_file) in enumerate(zip(preaudio, rawaudio)):
        print(i)
        rawwav, rawwav_sr = torchaudio.load(raw_file)
        prewav, prewav_sr = torchaudio.load(prepath + "/" + pre_name)
        rawwav = rawwav.to(device)
        prewav = prewav.to(device)

        # Resampling to 16 kHz is required when scoring with UTMOS.
        rawwav_16k = torchaudio.functional.resample(
            rawwav, orig_freq=rawwav_sr, new_freq=16000)
        prewav_16k = torchaudio.functional.resample(
            prewav, orig_freq=prewav_sr, new_freq=16000)

        # 1. UTMOS -- score each signal once and reuse the value for both the
        # per-file printout and the running sum.
        utmos_raw = utmos_scorer.score(rawwav_16k.unsqueeze(1))[0].item()
        utmos_pre = utmos_scorer.score(prewav_16k.unsqueeze(1))[0].item()
        print("****UTMOS_raw", i, utmos_raw)
        print("****UTMOS_encodec", i, utmos_pre)
        utmos_sumgt += utmos_raw
        utmos_sumencodec += utmos_pre

        # Both 16 kHz signals are cut to the shorter length before the
        # sample-aligned metrics below.
        min_len_16k = min(rawwav_16k.size()[1], prewav_16k.size()[1])
        raw_16k_cut = rawwav_16k[:, :min_len_16k]
        pre_16k_cut = prewav_16k[:, :min_len_16k]

        # 2. PESQ (wide-band mode at 16 kHz).
        # NOTE(review): squeeze(0) only yields a 1-D signal for single-channel
        # audio -- confirm the data is mono.
        # NOTE(review): on_error=1 is kept from the original; confirm against
        # the pesq package which error mode this selects.
        pesq_score = pesq(
            16000,
            raw_16k_cut.squeeze(0).cpu().numpy(),
            pre_16k_cut.squeeze(0).cpu().numpy(),
            "wb",
            on_error=1,
        )
        print("****PESQ", i, pesq_score)
        pesq_sumpre += pesq_score

        # 3. Periodicity / pitch / voiced-unvoiced F1.
        periodicity_loss, pitch_loss, f1_score = calculate_periodicity_metrics(
            raw_16k_cut, pre_16k_cut)
        print("****f1", periodicity_loss, pitch_loss, f1_score, f1score_sumpre)
        if math.isnan(f1_score):
            # NaN scores are counted and left out of the F1 average.
            f1score_filt += 1
            print("*****", f1score_filt)
        else:
            f1score_sumpre += f1_score

        # 4. STOI on the signals at their native rate (LibriTTS case: 24 kHz).
        # For resampled LJSpeech, first resample the reference to 24 kHz:
        #     rawwav_24k = torchaudio.functional.resample(
        #         rawwav, orig_freq=rawwav_sr, new_freq=24000)
        # then truncate rawwav_24k / prewav and call stoi(..., 24000, ...).
        # NOTE(review): assumes prewav_sr == rawwav_sr here -- confirm.
        min_len = min(rawwav.size()[1], prewav.size()[1])
        rawwav_stoi = rawwav[:, :min_len].squeeze(0)
        prewav_stoi = prewav[:, :min_len].squeeze(0)
        # Hand NumPy arrays (not torch tensors) to pystoi.
        tmp_stoi = stoi(
            rawwav_stoi.cpu().numpy(),
            prewav_stoi.cpu().numpy(),
            rawwav_sr,
            extended=False,
        )
        print("****stoi", tmp_stoi)
        stoi_sumpre.append(tmp_stoi)

    total = len(preaudio)
    print("*************UTMOS_raw", utmos_sumgt, _mean(utmos_sumgt, total))
    # Report the codec's own sum here (the reference sum was printed before).
    print("*************UTMOS_encodec", utmos_sumencodec,
          _mean(utmos_sumencodec, total))
    print("*************PESQ:", pesq_sumpre, _mean(pesq_sumpre, total))
    print("*************F1_score:", f1score_sumpre,
          _mean(f1score_sumpre, total - f1score_filt), f1score_filt)
    # An empty list has no mean; report NaN without NumPy's empty-slice warning.
    print("*************STOI:",
          np.mean(stoi_sumpre) if stoi_sumpre else float('nan'))


if __name__ == "__main__":
    main()