Spaces:
Build error
Build error
# Evaluate reconstructed audio with several objective metrics
| import os | |
| import glob | |
| from UTMOS import UTMOSScore | |
| from periodicity import calculate_periodicity_metrics | |
| import torchaudio | |
| from pesq import pesq | |
| import numpy as np | |
| import torch | |
| import math | |
| from pystoi import stoi | |
# Device used for all waveform tensors. Use the first CUDA GPU when one is
# available; otherwise fall back to the CPU so that moving tensors to the
# device does not fail on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# For LJSpeech: change the paths, the reference-file lookup logic, and the STOI sampling rate
def main():
    """Score every reconstructed file against its reference and print averages.

    Each file name found in ``prepath`` is matched to a reference located at
    ``rawpath/<id1>/<id2>/<filename>``, where ``id1`` and ``id2`` are the
    first two ``_``-separated fields of the file name (LibriTTS layout).

    For every pair the function prints and accumulates:

    1. UTMOS of the reference and of the reconstruction (16 kHz resamples).
    2. PESQ in ``"wb"`` mode on the 16 kHz signals, cropped to equal length.
    3. The three values returned by ``calculate_periodicity_metrics``; the
       F1 value is skipped (and counted) when it is NaN.
    4. STOI on the original-rate signals, cropped to equal length.

    After the loop the sums and means are printed. Nothing is returned.

    To evaluate LJSpeech instead, set ``rawpath`` to
    ``"./Data/LJSpeech-1.1/wavs"``, look references up directly as
    ``rawpath/<filename>``, and resample the reference to 24 kHz before STOI
    (passing 24000 as the STOI sampling rate).
    """
    prepath = "./Result/Minicodec/infer/dac_nq4_all"
    rawpath = "./Data/libritts/test-clean"

    preaudio = os.listdir(prepath)
    if not preaudio:
        # Nothing to score; avoids loading the model and dividing by zero.
        print("No files found in", prepath)
        return

    # str(device) gives 'cuda:0' for the default GPU device, so the scorer
    # lives on the same device as the waveforms.
    utmos_model = UTMOSScore(device=str(device))

    # Reference paths for LibriTTS: <rawpath>/<id1>/<id2>/<filename>.
    rawaudio = []
    for name in preaudio:
        fields = name.split('_')
        rawaudio.append(rawpath + "/" + fields[0] + "/" + fields[1] + "/" + name)

    utmos_sumgt = 0
    utmos_sumencodec = 0
    pesq_sumpre = 0
    f1score_sumpre = 0
    f1score_filt = 0  # number of files whose F1 score was NaN
    stoi_scores = []

    for i, (name, ref_path) in enumerate(zip(preaudio, rawaudio)):
        print(i)
        rawwav, rawwav_sr = torchaudio.load(ref_path)
        prewav, prewav_sr = torchaudio.load(prepath + "/" + name)
        rawwav = rawwav.to(device)
        prewav = prewav.to(device)

        # UTMOS must be evaluated on 16 kHz audio, so resample both signals.
        rawwav_16k = torchaudio.functional.resample(
            rawwav, orig_freq=rawwav_sr, new_freq=16000)
        prewav_16k = torchaudio.functional.resample(
            prewav, orig_freq=prewav_sr, new_freq=16000)

        # 1. UTMOS -- score each signal once and reuse the value.
        utmos_raw = utmos_model.score(rawwav_16k.unsqueeze(1))[0].item()
        utmos_pre = utmos_model.score(prewav_16k.unsqueeze(1))[0].item()
        print("****UTMOS_raw", i, utmos_raw)
        print("****UTMOS_encodec", i, utmos_pre)
        utmos_sumgt += utmos_raw
        utmos_sumencodec += utmos_pre

        # The 16 kHz signals are cropped to a common length for PESQ and F1.
        min_len = min(rawwav_16k.size()[1], prewav_16k.size()[1])
        ref_16k = rawwav_16k[:, :min_len]
        deg_16k = prewav_16k[:, :min_len]

        # 2. PESQ
        # NOTE(review): on_error=1 presumably makes pesq return an error code
        # instead of raising; such a code would be summed like a score --
        # confirm against the pesq documentation.
        pesq_score = pesq(
            16000,
            ref_16k.squeeze(0).cpu().numpy(),
            deg_16k.squeeze(0).cpu().numpy(),
            "wb",
            on_error=1,
        )
        print("****PESQ", i, pesq_score)
        pesq_sumpre += pesq_score

        # 3. F1-score (NaN results are counted and left out of the sum).
        periodicity_loss, pitch_loss, f1_score = calculate_periodicity_metrics(
            ref_16k, deg_16k)
        print("****f1", periodicity_loss, pitch_loss, f1_score, f1score_sumpre)
        if math.isnan(f1_score):
            f1score_filt += 1
            print("*****", f1score_filt)
        else:
            f1score_sumpre += f1_score

        # 4. STOI at the reference's native rate (24 kHz for LibriTTS).
        # Assumes the reconstruction has the same sampling rate as the
        # reference -- TODO confirm for other datasets.
        min_len = min(rawwav.size()[1], prewav.size()[1])
        rawwav_stoi = rawwav[:, :min_len].squeeze(0).cpu().numpy()
        prewav_stoi = prewav[:, :min_len].squeeze(0).cpu().numpy()
        tmp_stoi = stoi(rawwav_stoi, prewav_stoi, rawwav_sr, extended=False)
        print("****stoi", tmp_stoi)
        stoi_scores.append(tmp_stoi)

    total = len(preaudio)
    print("*************UTMOS_raw", utmos_sumgt, utmos_sumgt / total)
    print("*************UTMOS_encodec", utmos_sumencodec, utmos_sumencodec / total)
    print("*************PESQ:", pesq_sumpre, pesq_sumpre / total)

    # The F1 mean is taken over files with a valid score; NaN if none were.
    f1_valid = total - f1score_filt
    f1_mean = f1score_sumpre / f1_valid if f1_valid else float("nan")
    print("*************F1_score:", f1score_sumpre, f1_mean, f1score_filt)
    print("*************STOI:", np.mean(stoi_scores))
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()