breadlicker45 committed on
Commit c817048 · verified · 1 parent: a9beda6

Upload 5 files

metrics/UTMOS.py ADDED
@@ -0,0 +1,223 @@
+ import os
+
+ import fairseq
+ import pytorch_lightning as pl
+ import requests
+ import torch
+ import torch.nn as nn
+ from tqdm import tqdm
+
+ UTMOS_CKPT_URL = "https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/epoch%3D3-step%3D7459.ckpt"
+ WAV2VEC_URL = "https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/wav2vec_small.pt"
+
+ """
+ UTMOS, an automatic Mean Opinion Score (MOS) prediction system,
+ adapted from https://huggingface.co/spaces/sarulab-speech/UTMOS-demo
+ """
+
+
+ class UTMOSScore:
+     """Predicts a MOS for each audio clip."""
+
+     def __init__(self, device, ckpt_path="epoch=3-step=7459.ckpt"):
+         self.device = device
+         filepath = os.path.join(os.path.dirname(__file__), ckpt_path)
+         if not os.path.exists(filepath):
+             download_file(UTMOS_CKPT_URL, filepath)
+         self.model = BaselineLightningModule.load_from_checkpoint(filepath).eval().to(device)
+
+     def score(self, wavs: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             wavs: audio waveform(s) to be evaluated. A 1D or 2D tensor is
+                 treated as a single audio clip; a 3D tensor of shape
+                 [batch, 1, time] is scored as a batch.
+         """
+         if len(wavs.shape) == 1:
+             out_wavs = wavs.unsqueeze(0).unsqueeze(0)
+         elif len(wavs.shape) == 2:
+             out_wavs = wavs.unsqueeze(0)
+         elif len(wavs.shape) == 3:
+             out_wavs = wavs
+         else:
+             raise ValueError("Dimension of input tensor needs to be <= 3.")
+         bs = out_wavs.shape[0]
+         batch = {
+             "wav": out_wavs,
+             "domains": torch.zeros(bs, dtype=torch.int).to(self.device),
+             "judge_id": torch.ones(bs, dtype=torch.int).to(self.device) * 288,
+         }
+         with torch.no_grad():
+             output = self.model(batch)
+
+         return output.mean(dim=1).squeeze(1).cpu().detach() * 2 + 3
+
+
+ def download_file(url, filename):
+     """
+     Downloads a file from the given URL.
+
+     Args:
+         url (str): The URL of the file to download.
+         filename (str): The name to save the file as.
+     """
+     print(f"Downloading file {filename}...")
+     response = requests.get(url, stream=True)
+     response.raise_for_status()
+
+     total_size_in_bytes = int(response.headers.get("content-length", 0))
+     progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
+
+     with open(filename, "wb") as f:
+         for chunk in response.iter_content(chunk_size=8192):
+             progress_bar.update(len(chunk))
+             f.write(chunk)
+
+     progress_bar.close()
+
+
+ def load_ssl_model(ckpt_path="wav2vec_small.pt"):
+     filepath = os.path.join(os.path.dirname(__file__), ckpt_path)
+     if not os.path.exists(filepath):
+         download_file(WAV2VEC_URL, filepath)
+     SSL_OUT_DIM = 768
+     model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([filepath])
+     ssl_model = model[0]
+     ssl_model.remove_pretraining_modules()
+     return SSL_model(ssl_model, SSL_OUT_DIM)
+
+
+ class BaselineLightningModule(pl.LightningModule):
+     def __init__(self, cfg):
+         super().__init__()
+         self.cfg = cfg
+         self.construct_model()
+         self.save_hyperparameters()
+
+     def construct_model(self):
+         self.feature_extractors = nn.ModuleList(
+             [load_ssl_model(ckpt_path="wav2vec_small.pt"), DomainEmbedding(3, 128)]
+         )
+         output_dim = sum(feature_extractor.get_output_dim() for feature_extractor in self.feature_extractors)
+         output_layers = [LDConditioner(judge_dim=128, num_judges=3000, input_dim=output_dim)]
+         output_dim = output_layers[-1].get_output_dim()
+         output_layers.append(
+             Projection(hidden_dim=2048, activation=torch.nn.ReLU(), range_clipping=False, input_dim=output_dim)
+         )
+
+         self.output_layers = nn.ModuleList(output_layers)
+
+     def forward(self, inputs):
+         outputs = {}
+         for feature_extractor in self.feature_extractors:
+             outputs.update(feature_extractor(inputs))
+         x = outputs
+         for output_layer in self.output_layers:
+             x = output_layer(x, inputs)
+         return x
+
+
+ class SSL_model(nn.Module):
+     def __init__(self, ssl_model, ssl_out_dim) -> None:
+         super(SSL_model, self).__init__()
+         self.ssl_model, self.ssl_out_dim = ssl_model, ssl_out_dim
+
+     def forward(self, batch):
+         wav = batch["wav"]
+         wav = wav.squeeze(1)  # [batches, audio_len]
+         res = self.ssl_model(wav, mask=False, features_only=True)
+         x = res["x"]
+         return {"ssl-feature": x}
+
+     def get_output_dim(self):
+         return self.ssl_out_dim
+
+
+ class DomainEmbedding(nn.Module):
+     def __init__(self, n_domains, domain_dim) -> None:
+         super().__init__()
+         self.embedding = nn.Embedding(n_domains, domain_dim)
+         self.output_dim = domain_dim
+
+     def forward(self, batch):
+         return {"domain-feature": self.embedding(batch["domains"])}
+
+     def get_output_dim(self):
+         return self.output_dim
+
+
+ class LDConditioner(nn.Module):
+     """
+     Conditions the SSL output on a listener (judge) embedding.
+     """
+
+     def __init__(self, input_dim, judge_dim, num_judges=None):
+         super().__init__()
+         self.input_dim = input_dim
+         self.judge_dim = judge_dim
+         self.num_judges = num_judges
+         assert num_judges is not None
+         self.judge_embedding = nn.Embedding(num_judges, self.judge_dim)
+         # concat [self.output_layer, phoneme features]
+
+         self.decoder_rnn = nn.LSTM(
+             input_size=self.input_dim + self.judge_dim,
+             hidden_size=512,
+             num_layers=1,
+             batch_first=True,
+             bidirectional=True,
+         )  # linear?
+         self.out_dim = self.decoder_rnn.hidden_size * 2
+
+     def get_output_dim(self):
+         return self.out_dim
+
+     def forward(self, x, batch):
+         judge_ids = batch["judge_id"]
+         if "phoneme-feature" in x.keys():
+             concatenated_feature = torch.cat(
+                 (x["ssl-feature"], x["phoneme-feature"].unsqueeze(1).expand(-1, x["ssl-feature"].size(1), -1)), dim=2
+             )
+         else:
+             concatenated_feature = x["ssl-feature"]
+         if "domain-feature" in x.keys():
+             concatenated_feature = torch.cat(
+                 (concatenated_feature, x["domain-feature"].unsqueeze(1).expand(-1, concatenated_feature.size(1), -1)),
+                 dim=2,
+             )
+         if judge_ids is not None:
+             concatenated_feature = torch.cat(
+                 (
+                     concatenated_feature,
+                     self.judge_embedding(judge_ids).unsqueeze(1).expand(-1, concatenated_feature.size(1), -1),
+                 ),
+                 dim=2,
+             )
+         decoder_output, (h, c) = self.decoder_rnn(concatenated_feature)
+         return decoder_output
+
+
+ class Projection(nn.Module):
+     def __init__(self, input_dim, hidden_dim, activation, range_clipping=False):
+         super(Projection, self).__init__()
+         self.range_clipping = range_clipping
+         output_dim = 1
+         if range_clipping:
+             self.proj = nn.Tanh()
+
+         self.net = nn.Sequential(
+             nn.Linear(input_dim, hidden_dim), activation, nn.Dropout(0.3), nn.Linear(hidden_dim, output_dim),
+         )
+         self.output_dim = output_dim
+
+     def forward(self, x, batch):
+         output = self.net(x)
+
+         # range clipping
+         if self.range_clipping:
+             return self.proj(output) * 2.0 + 3
+         else:
+             return output
+
+     def get_output_dim(self):
+         return self.output_dim
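
For reference, a minimal usage sketch of UTMOSScore (the file name sample.wav is a placeholder; resampling to 16 kHz mirrors what metrics/infer.py does before scoring):

import torch
import torchaudio

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
scorer = UTMOSScore(device=device)

wav, sr = torchaudio.load("sample.wav")  # placeholder path; shape [channels, time]
wav = torchaudio.functional.resample(wav.to(device), orig_freq=sr, new_freq=16000)

mos = scorer.score(wav)  # 2D input is scored as a single clip
print(mos.item())        # predicted MOS on roughly a 1-5 scale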
metrics/__pycache__/UTMOS.cpython-310.pyc ADDED
Binary file (7.97 kB).
 
metrics/__pycache__/periodicity.cpython-310.pyc ADDED
Binary file (2.73 kB).
 
metrics/infer.py ADDED
@@ -0,0 +1,116 @@
+ # Evaluate various metrics
+ import os
+ import glob
+ from UTMOS import UTMOSScore
+ from periodicity import calculate_periodicity_metrics
+ import torchaudio
+ from pesq import pesq
+ import numpy as np
+ import torch
+ import math
+ from pystoi import stoi
+
+ device = torch.device("cuda:0")
+
+ # For LJSpeech: change the paths, the data-loading logic, and the STOI sample rate.
+
+ def main():
+     prepath = "./Result/Minicodec/infer/dac_nq4_all"
+     rawpath = "./Data/libritts/test-clean"
+     # rawpath = "./Data/LJSpeech-1.1/wavs"
+     preaudio = os.listdir(prepath)
+     rawaudio = []
+
+     UTMOS = UTMOSScore(device=device)
+
+     # libritts: rebuild speaker/chapter subdirectories from the file name
+     for i in range(len(preaudio)):
+         id1 = preaudio[i].split("_")[0]
+         id2 = preaudio[i].split("_")[1]
+         rawaudio.append(rawpath + "/" + id1 + "/" + id2 + "/" + preaudio[i])
+
+     # # ljspeech
+     # for i in range(len(preaudio)):
+     #     rawaudio.append(rawpath + "/" + preaudio[i])
+
+     utmos_sumgt = 0
+     utmos_sumencodec = 0
+     pesq_sumpre = 0
+     f1score_sumpre = 0
+     stoi_sumpre = []
+     f1score_filt = 0
+
+     for i in range(len(preaudio)):
+         print(i)
+         rawwav, rawwav_sr = torchaudio.load(rawaudio[i])
+         prewav, prewav_sr = torchaudio.load(prepath + "/" + preaudio[i])
+         # breakpoint()
+         rawwav = rawwav.to(device)
+         prewav = prewav.to(device)
+         # print(rawwav.size(), prewav.size())
+         # breakpoint()
+         rawwav_16k = torchaudio.functional.resample(rawwav, orig_freq=rawwav_sr, new_freq=16000)  # resampling is required for UTMOS
+         prewav_16k = torchaudio.functional.resample(prewav, orig_freq=prewav_sr, new_freq=16000)
+
+
+         # 1. UTMOS
+         print("****UTMOS_raw", i, UTMOS.score(rawwav_16k.unsqueeze(1))[0].item())
+         print("****UTMOS_encodec", i, UTMOS.score(prewav_16k.unsqueeze(1))[0].item())
+         utmos_sumgt += UTMOS.score(rawwav_16k.unsqueeze(1))[0].item()
+         utmos_sumencodec += UTMOS.score(prewav_16k.unsqueeze(1))[0].item()
+
+
+         # breakpoint()
+
+         # 2. PESQ
+         min_len = min(rawwav_16k.size()[1], prewav_16k.size()[1])
+         rawwav_16k_pesq = rawwav_16k[:, :min_len].squeeze(0)
+         prewav_16k_pesq = prewav_16k[:, :min_len].squeeze(0)
+         pesq_score = pesq(16000, rawwav_16k_pesq.cpu().numpy(), prewav_16k_pesq.cpu().numpy(), "wb", on_error=1)
+         print("****PESQ", i, pesq_score)
+         pesq_sumpre += pesq_score
+         # breakpoint()
+
+         # 3. F1 score (voiced/unvoiced)
+         min_len = min(rawwav_16k.size()[1], prewav_16k.size()[1])
+         rawwav_16k_f1score = rawwav_16k[:, :min_len]
+         prewav_16k_f1score = prewav_16k[:, :min_len]
+         periodicity_loss, pitch_loss, f1_score = calculate_periodicity_metrics(rawwav_16k_f1score, prewav_16k_f1score)
+         print("****f1", periodicity_loss, pitch_loss, f1_score, f1score_sumpre)
+         if math.isnan(f1_score):
+             f1score_filt += 1
+             print("*****", f1score_filt)
+         else:
+             f1score_sumpre += f1_score
+         # breakpoint()
+
+
+         # 4. STOI
+         # # For LJSpeech resampled to 24 kHz:
+         # rawwav_24k = torchaudio.functional.resample(rawwav, orig_freq=rawwav_sr, new_freq=24000)
+         # min_len = min(rawwav_24k.size()[1], prewav.size()[1])
+         # rawwav_stoi = rawwav_24k[:, :min_len].squeeze(0)
+         # prewav_stoi = prewav[:, :min_len].squeeze(0)
+         # tmp_stoi = stoi(rawwav_stoi.cpu(), prewav_stoi.cpu(), 24000, extended=False)
+         # print("****stoi", tmp_stoi)
+         # stoi_sumpre.append(tmp_stoi)
+         # # breakpoint()
+
+         # For LibriTTS, whose native sample rate is 24 kHz:
+         min_len = min(rawwav.size()[1], prewav.size()[1])
+         rawwav_stoi = rawwav[:, :min_len].squeeze(0)
+         prewav_stoi = prewav[:, :min_len].squeeze(0)
+         tmp_stoi = stoi(rawwav_stoi.cpu(), prewav_stoi.cpu(), rawwav_sr, extended=False)
+         print("****stoi", tmp_stoi)
+         stoi_sumpre.append(tmp_stoi)
+
+     print("*************UTMOS_raw", utmos_sumgt, utmos_sumgt / len(preaudio))
+     print("*************UTMOS_encodec", utmos_sumencodec, utmos_sumencodec / len(preaudio))
+     print("*************PESQ:", pesq_sumpre, pesq_sumpre / len(preaudio))
+     print("*************F1_score:", f1score_sumpre, f1score_sumpre / (len(preaudio) - f1score_filt), f1score_filt)
+     print("*************STOI:", np.mean(stoi_sumpre))
+
+
+
+ if __name__ == "__main__":
+     main()
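
The LibriTTS pairing loop above assumes synthesized file names of the form speaker_chapter_utterance_segment.wav, so the speaker and chapter IDs can be recovered from the name itself. A hypothetical example of the mapping it performs:

# Hypothetical file name following the assumed LibriTTS convention
name = "1089_134686_000001_000001.wav"
spk, chap = name.split("_")[0], name.split("_")[1]
print("./Data/libritts/test-clean" + "/" + spk + "/" + chap + "/" + name)
# -> ./Data/libritts/test-clean/1089/134686/1089_134686_000001_000001.wav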
metrics/periodicity.py ADDED
@@ -0,0 +1,105 @@
+ import librosa
+ import numpy as np
+ import torch
+ import torchaudio
+ import torchcrepe
+ from torchcrepe.loudness import REF_DB
+
+ SILENCE_THRESHOLD = -60
+ UNVOICED_THRESHOLD = 0.21
+
+ """
+ Periodicity metrics adapted from https://github.com/descriptinc/cargan
+ """
+
+
+ def predict_pitch(
+     audio: torch.Tensor, silence_threshold: float = SILENCE_THRESHOLD, unvoiced_threshold: float = UNVOICED_THRESHOLD
+ ):
+     """
+     Predicts pitch and periodicity for the given audio.
+
+     Args:
+         audio (Tensor): The audio waveform.
+         silence_threshold (float): The threshold for silence detection.
+         unvoiced_threshold (float): The threshold for unvoiced detection.
+
+     Returns:
+         pitch (ndarray): The predicted pitch.
+         periodicity (ndarray): The predicted periodicity.
+     """
+     # torchcrepe inference
+     pitch, periodicity = torchcrepe.predict(
+         audio,
+         fmin=50.0,
+         fmax=550,
+         sample_rate=torchcrepe.SAMPLE_RATE,
+         model="full",
+         return_periodicity=True,
+         device=audio.device,
+         pad=False,
+     )
+     pitch = pitch.cpu().numpy()
+     periodicity = periodicity.cpu().numpy()
+
+     # Calculate dB-scaled spectrogram and set low-energy frames to unvoiced
+     hop_length = torchcrepe.SAMPLE_RATE // 100  # default CREPE hop
+     stft = torchaudio.functional.spectrogram(
+         audio,
+         window=torch.hann_window(torchcrepe.WINDOW_SIZE, device=audio.device),
+         n_fft=torchcrepe.WINDOW_SIZE,
+         hop_length=hop_length,
+         win_length=torchcrepe.WINDOW_SIZE,
+         power=2,
+         normalized=False,
+         pad=0,
+         center=False,
+     )
+
+     # Perceptual weighting
+     freqs = librosa.fft_frequencies(sr=torchcrepe.SAMPLE_RATE, n_fft=torchcrepe.WINDOW_SIZE)
+     perceptual_stft = librosa.perceptual_weighting(stft.cpu().numpy(), freqs) - REF_DB
+     silence = perceptual_stft.mean(axis=1) < silence_threshold
+
+     periodicity[silence] = 0
+     pitch[periodicity < unvoiced_threshold] = torchcrepe.UNVOICED
+
+     return pitch, periodicity
+
+
+ def calculate_periodicity_metrics(y: torch.Tensor, y_hat: torch.Tensor):
+     """
+     Calculates periodicity metrics for the predicted and true audio data.
+
+     Args:
+         y (Tensor): The true audio data.
+         y_hat (Tensor): The predicted audio data.
+
+     Returns:
+         periodicity_loss (float): The periodicity loss.
+         pitch_loss (float): The pitch loss.
+         f1 (float): The F1 score for voiced/unvoiced classification.
+     """
+     true_pitch, true_periodicity = predict_pitch(y)
+     pred_pitch, pred_periodicity = predict_pitch(y_hat)
+
+     true_voiced = ~np.isnan(true_pitch)
+     pred_voiced = ~np.isnan(pred_pitch)
+
+     # Periodicity RMSE, averaged over clips
+     periodicity_loss = np.sqrt(((pred_periodicity - true_periodicity) ** 2).mean(axis=1)).mean()
+
+     # Pitch RMSE in cents over frames voiced in both signals
+     voiced = true_voiced & pred_voiced
+     difference_cents = 1200 * (np.log2(true_pitch[voiced]) - np.log2(pred_pitch[voiced]))
+     pitch_loss = np.sqrt((difference_cents ** 2).mean())
+
+     # voiced/unvoiced precision and recall
+     true_positives = (true_voiced & pred_voiced).sum()
+     false_positives = (~true_voiced & pred_voiced).sum()
+     false_negatives = (true_voiced & ~pred_voiced).sum()
+
+     precision = true_positives / (true_positives + false_positives)
+     recall = true_positives / (true_positives + false_negatives)
+     f1 = 2 * precision * recall / (precision + recall)
+
+     return periodicity_loss, pitch_loss, f1
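
For reference, a minimal usage sketch of calculate_periodicity_metrics (file names are placeholders; both signals are resampled to torchcrepe's 16 kHz operating rate and trimmed to a common length, as metrics/infer.py does):

import torch
import torchaudio

y, sr = torchaudio.load("reference.wav")          # placeholder paths
y_hat, sr_hat = torchaudio.load("reconstruction.wav")
y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=16000)
y_hat = torchaudio.functional.resample(y_hat, orig_freq=sr_hat, new_freq=16000)

n = min(y.size(1), y_hat.size(1))                 # trim to a common length
periodicity_loss, pitch_loss, f1 = calculate_periodicity_metrics(y[:, :n], y_hat[:, :n])
print(periodicity_loss, pitch_loss, f1)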