trenden committed on
Commit
510a82f
·
verified ·
1 Parent(s): b5c4a53

Upload sgmse/util/inference.py

Browse files
Files changed (1) hide show
  1. sgmse/util/inference.py +64 -0
sgmse/util/inference.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchaudio import load
3
+
4
+ from pesq import pesq
5
+ from pystoi import stoi
6
+
7
+ from .other import si_sdr, pad_spec
8
+
9
+ # Settings: module-level defaults read by evaluate_model() below.
10
+ sr = 16000  # sampling rate handed to pesq() and stoi(); the rate reported by the audio loader is not checked against it
11
+ snr = 0.5  # forwarded as the `snr` keyword of model.get_pc_sampler (presumably the corrector step-size parameter -- TODO confirm)
12
+ N = 30  # forwarded as the `N` keyword of model.get_pc_sampler (presumably the number of reverse steps -- TODO confirm)
13
+ corrector_steps = 1  # forwarded as the `corrector_steps` keyword of model.get_pc_sampler
14
+
15
+
16
def evaluate_model(model, num_eval_files):
    """Enhance a subset of the validation set and return averaged metrics.

    ``num_eval_files`` (clean, noisy) pairs are picked at evenly spaced
    positions of ``model.data_module.valid_set``. Each noisy file is
    enhanced with the model's predictor-corrector sampler and compared
    against its clean reference.

    The module-level settings ``sr``, ``snr``, ``N`` and ``corrector_steps``
    are used as-is. The sampling rate reported by the audio loader is
    discarded, so the files are assumed to already be sampled at ``sr``.
    The noisy input is moved to the GPU with ``.cuda()``, so a CUDA device
    is required.

    Args:
        model: object exposing ``data_module.valid_set.clean_files`` /
            ``noisy_files``, ``_stft``, ``_forward_transform``,
            ``get_pc_sampler`` and ``to_audio``.
        num_eval_files: number of validation pairs to evaluate (>= 1).

    Returns:
        Tuple ``(pesq, si_sdr, estoi)`` with each metric averaged over the
        evaluated files.

    Raises:
        ValueError: if ``num_eval_files`` is smaller than 1 or the
            validation set contains no files.
    """
    if num_eval_files < 1:
        raise ValueError(
            f"num_eval_files must be at least 1, got {num_eval_files}"
        )

    valid_set = model.data_module.valid_set
    clean_files = valid_set.clean_files
    noisy_files = valid_set.noisy_files

    total_num_files = len(clean_files)
    if total_num_files == 0:
        raise ValueError("validation set contains no files to evaluate")

    # Select test files uniformly across the validation files. tolist()
    # yields plain Python ints, so indexing does not go through 0-dim tensors.
    indices = torch.linspace(
        0, total_num_files - 1, num_eval_files, dtype=torch.int
    ).tolist()
    clean_files = [clean_files[i] for i in indices]
    noisy_files = [noisy_files[i] for i in indices]

    _pesq = 0
    _si_sdr = 0
    _estoi = 0
    # Iterate over the selected (clean, noisy) pairs
    for clean_file, noisy_file in zip(clean_files, noisy_files):
        # Load wavs; the returned sampling rates are ignored (see docstring)
        x, _ = load(clean_file)
        y, _ = load(noisy_file)
        T_orig = x.size(1)

        # Normalize per utterance by the peak of the noisy signal
        norm_factor = y.abs().max()
        y = y / norm_factor

        # Prepare DNN input: STFT -> forward transform -> add batch dim -> pad
        Y = torch.unsqueeze(model._forward_transform(model._stft(y.cuda())), 0)
        Y = pad_spec(Y)

        # Reverse sampling
        sampler = model.get_pc_sampler(
            'reverse_diffusion', 'ald', Y.cuda(), N=N,
            corrector_steps=corrector_steps, snr=snr)
        sample, _ = sampler()

        # Back to the time domain, cut to the reference length, and undo
        # the per-utterance normalization
        x_hat = model.to_audio(sample.squeeze(), T_orig)
        x_hat = x_hat * norm_factor

        x_hat = x_hat.squeeze().cpu().numpy()
        x = x.squeeze().cpu().numpy()

        _si_sdr += si_sdr(x, x_hat)
        _pesq += pesq(sr, x, x_hat, 'wb')
        _estoi += stoi(x, x_hat, sr, extended=True)

    return _pesq/num_eval_files, _si_sdr/num_eval_files, _estoi/num_eval_files
64
+