Upload sgmse/util/inference.py
Browse files- sgmse/util/inference.py +64 -0
sgmse/util/inference.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torchaudio import load
|
| 3 |
+
|
| 4 |
+
from pesq import pesq
|
| 5 |
+
from pystoi import stoi
|
| 6 |
+
|
| 7 |
+
from .other import si_sdr, pad_spec
|
| 8 |
+
|
| 9 |
+
# Settings
|
| 10 |
+
sr = 16000
|
| 11 |
+
snr = 0.5
|
| 12 |
+
N = 30
|
| 13 |
+
corrector_steps = 1
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def evaluate_model(model, num_eval_files):
|
| 17 |
+
|
| 18 |
+
clean_files = model.data_module.valid_set.clean_files
|
| 19 |
+
noisy_files = model.data_module.valid_set.noisy_files
|
| 20 |
+
|
| 21 |
+
# Select test files uniformly accros validation files
|
| 22 |
+
total_num_files = len(clean_files)
|
| 23 |
+
indices = torch.linspace(0, total_num_files-1, num_eval_files, dtype=torch.int)
|
| 24 |
+
clean_files = list(clean_files[i] for i in indices)
|
| 25 |
+
noisy_files = list(noisy_files[i] for i in indices)
|
| 26 |
+
|
| 27 |
+
_pesq = 0
|
| 28 |
+
_si_sdr = 0
|
| 29 |
+
_estoi = 0
|
| 30 |
+
# iterate over files
|
| 31 |
+
for (clean_file, noisy_file) in zip(clean_files, noisy_files):
|
| 32 |
+
# Load wavs
|
| 33 |
+
x, _ = load(clean_file)
|
| 34 |
+
y, _ = load(noisy_file)
|
| 35 |
+
T_orig = x.size(1)
|
| 36 |
+
|
| 37 |
+
# Normalize per utterance
|
| 38 |
+
norm_factor = y.abs().max()
|
| 39 |
+
y = y / norm_factor
|
| 40 |
+
|
| 41 |
+
# Prepare DNN input
|
| 42 |
+
Y = torch.unsqueeze(model._forward_transform(model._stft(y.cuda())), 0)
|
| 43 |
+
Y = pad_spec(Y)
|
| 44 |
+
y = y * norm_factor
|
| 45 |
+
|
| 46 |
+
# Reverse sampling
|
| 47 |
+
sampler = model.get_pc_sampler(
|
| 48 |
+
'reverse_diffusion', 'ald', Y.cuda(), N=N,
|
| 49 |
+
corrector_steps=corrector_steps, snr=snr)
|
| 50 |
+
sample, _ = sampler()
|
| 51 |
+
|
| 52 |
+
x_hat = model.to_audio(sample.squeeze(), T_orig)
|
| 53 |
+
x_hat = x_hat * norm_factor
|
| 54 |
+
|
| 55 |
+
x_hat = x_hat.squeeze().cpu().numpy()
|
| 56 |
+
x = x.squeeze().cpu().numpy()
|
| 57 |
+
y = y.squeeze().cpu().numpy()
|
| 58 |
+
|
| 59 |
+
_si_sdr += si_sdr(x, x_hat)
|
| 60 |
+
_pesq += pesq(sr, x, x_hat, 'wb')
|
| 61 |
+
_estoi += stoi(x, x_hat, sr, extended=True)
|
| 62 |
+
|
| 63 |
+
return _pesq/num_eval_files, _si_sdr/num_eval_files, _estoi/num_eval_files
|
| 64 |
+
|