File size: 4,193 Bytes
c817048
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Evaluate reconstructed audio with several metrics (UTMOS, PESQ, periodicity F1, STOI).
import os
import glob
from UTMOS import UTMOSScore
from periodicity import calculate_periodicity_metrics
import torchaudio
from pesq import pesq
import numpy as np
import torch
import math
from pystoi import stoi

# Prefer the first CUDA GPU; fall back to the CPU so that tensors can still be
# placed on a valid device on hosts without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

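# NOTE: for LJSpeech, switch the paths and the file-matching logic in main(), and make sure STOI is computed at a sample rate shared by both signals.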

def _mean_or_nan(total, count):
    """Return ``total / count``, or NaN when ``count`` is zero."""
    return total / count if count else float("nan")


def main(prepath="./Result/Minicodec/infer/dac_nq4_all",
         rawpath="./Data/libritts/test-clean"):
    """Score every reconstructed file in ``prepath`` against its reference.

    For each file the function prints UTMOS (reference and reconstruction),
    wide-band PESQ, the periodicity metrics and STOI, then prints the totals
    and averages over all files.

    Args:
        prepath: Directory holding the reconstructed audio files.
        rawpath: Root directory of the reference audio. A reference is looked
            up as ``<rawpath>/<id1>/<id2>/<filename>``, where ``id1`` and
            ``id2`` are the first two ``_``-separated fields of the file name
            (the LibriTTS layout). For a flat layout such as LJSpeech, build
            the reference path as ``<rawpath>/<filename>`` instead.
    """
    preaudio = os.listdir(prepath)

    # Use the same device as the module-level ``device`` (passed as a string,
    # as the original call did).
    utmos_model = UTMOSScore(device=str(device))

    # Map each reconstructed file to its reference file.
    rawaudio = []
    for name in preaudio:
        fields = name.split('_')
        rawaudio.append(rawpath + "/" + fields[0] + "/" + fields[1] + "/" + name)

    utmos_sumgt = 0
    utmos_sumencodec = 0
    pesq_sumpre = 0
    f1score_sumpre = 0
    f1score_filt = 0  # number of files whose F1 score was NaN
    stoi_scores = []

    for i, (name, rawfile) in enumerate(zip(preaudio, rawaudio)):
        print(i)
        rawwav, rawwav_sr = torchaudio.load(rawfile)
        prewav, prewav_sr = torchaudio.load(prepath + "/" + name)
        rawwav = rawwav.to(device)
        prewav = prewav.to(device)

        # UTMOS, PESQ and the periodicity metrics are all computed at 16 kHz.
        rawwav_16k = torchaudio.functional.resample(
            rawwav, orig_freq=rawwav_sr, new_freq=16000)
        prewav_16k = torchaudio.functional.resample(
            prewav, orig_freq=prewav_sr, new_freq=16000)

        # 1. UTMOS -- score each signal once and reuse the value for both the
        # per-file print and the running total.
        utmos_raw = utmos_model.score(rawwav_16k.unsqueeze(1))[0].item()
        utmos_pre = utmos_model.score(prewav_16k.unsqueeze(1))[0].item()
        print("****UTMOS_raw", i, utmos_raw)
        print("****UTMOS_encodec", i, utmos_pre)
        utmos_sumgt += utmos_raw
        utmos_sumencodec += utmos_pre

        # Trim both 16 kHz signals to a common length for the paired metrics.
        min_len = min(rawwav_16k.size()[1], prewav_16k.size()[1])
        rawwav_16k_cut = rawwav_16k[:, :min_len]
        prewav_16k_cut = prewav_16k[:, :min_len]

        # 2. PESQ (wide-band mode).
        # NOTE(review): on_error=1 is kept from the original call -- confirm
        # its exact meaning against the pesq package documentation.
        pesq_score = pesq(
            16000,
            rawwav_16k_cut.squeeze(0).cpu().numpy(),
            prewav_16k_cut.squeeze(0).cpu().numpy(),
            "wb",
            on_error=1,
        )
        print("****PESQ", i, pesq_score)
        pesq_sumpre += pesq_score

        # 3. Periodicity metrics / F1 score.
        periodicity_loss, pitch_loss, f1_score = calculate_periodicity_metrics(
            rawwav_16k_cut, prewav_16k_cut)
        print("****f1", periodicity_loss, pitch_loss, f1_score, f1score_sumpre)
        if math.isnan(f1_score):
            # NaN scores are counted and excluded from the F1 average.
            f1score_filt += 1
            print("*****", f1score_filt)
        else:
            f1score_sumpre += f1_score

        # 4. STOI at the reconstruction's sample rate. When the reference has
        # a different rate, resample it first so both signals share one rate;
        # when the rates already match this is identical to the original.
        if rawwav_sr != prewav_sr:
            rawwav_ref = torchaudio.functional.resample(
                rawwav, orig_freq=rawwav_sr, new_freq=prewav_sr)
        else:
            rawwav_ref = rawwav
        min_len = min(rawwav_ref.size()[1], prewav.size()[1])
        rawwav_stoi = rawwav_ref[:, :min_len].squeeze(0)
        prewav_stoi = prewav[:, :min_len].squeeze(0)
        # Hand pystoi NumPy arrays rather than torch tensors.
        tmp_stoi = stoi(
            rawwav_stoi.cpu().numpy(),
            prewav_stoi.cpu().numpy(),
            prewav_sr,
            extended=False,
        )
        print("****stoi", tmp_stoi)
        stoi_scores.append(tmp_stoi)

    # Summary: total and mean per metric. Means are NaN instead of raising
    # ZeroDivisionError when there is nothing to average.
    n_files = len(preaudio)
    print("*************UTMOS_raw", utmos_sumgt, _mean_or_nan(utmos_sumgt, n_files))
    print("*************UTMOS_encodec", utmos_sumencodec,
          _mean_or_nan(utmos_sumencodec, n_files))
    print("*************PESQ:", pesq_sumpre, _mean_or_nan(pesq_sumpre, n_files))
    print("*************F1_score:", f1score_sumpre,
          _mean_or_nan(f1score_sumpre, n_files - f1score_filt), f1score_filt)
    print("*************STOI:",
          np.mean(stoi_scores) if stoi_scores else float("nan"))
    
    

# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()