trenden commited on
Commit
9b9a995
·
verified ·
1 Parent(s): 510a82f

Upload sgmse/util/other.py

Browse files
Files changed (1) hide show
  1. sgmse/util/other.py +141 -0
sgmse/util/other.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import scipy.stats
5
+ from scipy.signal import butter, sosfilt
6
+
7
+ from pesq import pesq
8
+ from pystoi import stoi
9
+
10
+
11
+ def si_sdr_components(s_hat, s, n):
12
+ # s_target
13
+ alpha_s = np.dot(s_hat, s) / np.linalg.norm(s)**2
14
+ s_target = alpha_s * s
15
+
16
+ # e_noise
17
+ alpha_n = np.dot(s_hat, n) / np.linalg.norm(n)**2
18
+ e_noise = alpha_n * n
19
+
20
+ # e_art
21
+ e_art = s_hat - s_target - e_noise
22
+
23
+ return s_target, e_noise, e_art
24
+
25
+ def energy_ratios(s_hat, s, n):
26
+ s_target, e_noise, e_art = si_sdr_components(s_hat, s, n)
27
+
28
+ si_sdr = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_noise + e_art)**2)
29
+ si_sir = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_noise)**2)
30
+ si_sar = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_art)**2)
31
+
32
+ return si_sdr, si_sir, si_sar
33
+
34
+ def mean_conf_int(data, confidence=0.95):
35
+ a = 1.0 * np.array(data)
36
+ n = len(a)
37
+ m, se = np.mean(a), scipy.stats.sem(a)
38
+ h = se * scipy.stats.t.ppf((1 + confidence) / 2., n-1)
39
+ return m, h
40
+
41
+ class Method():
42
+ def __init__(self, name, base_dir, metrics):
43
+ self.name = name
44
+ self.base_dir = base_dir
45
+ self.metrics = {}
46
+
47
+ for i in range(len(metrics)):
48
+ metric = metrics[i]
49
+ value = []
50
+ self.metrics[metric] = value
51
+
52
+ def append(self, matric, value):
53
+ self.metrics[matric].append(value)
54
+
55
+ def get_mean_ci(self, metric):
56
+ return mean_conf_int(np.array(self.metrics[metric]))
57
+
58
+ def hp_filter(signal, cut_off=80, order=10, sr=16000):
59
+ factor = cut_off /sr * 2
60
+ sos = butter(order, factor, 'hp', output='sos')
61
+ filtered = sosfilt(sos, signal)
62
+ return filtered
63
+
64
+ def si_sdr(s, s_hat):
65
+ alpha = np.dot(s_hat, s)/np.linalg.norm(s)**2
66
+ sdr = 10*np.log10(np.linalg.norm(alpha*s)**2/np.linalg.norm(
67
+ alpha*s - s_hat)**2)
68
+ return sdr
69
+
70
+ def snr_dB(s,n):
71
+ s_power = 1/len(s)*np.sum(s**2)
72
+ n_power = 1/len(n)*np.sum(n**2)
73
+ snr_dB = 10*np.log10(s_power/n_power)
74
+ return snr_dB
75
+
76
+ def pad_spec(Y, mode="zero_pad"):
77
+ T = Y.size(3)
78
+ if T%64 !=0:
79
+ num_pad = 64-T%64
80
+ else:
81
+ num_pad = 0
82
+ if mode == "zero_pad":
83
+ pad2d = torch.nn.ZeroPad2d((0, num_pad, 0,0))
84
+ elif mode == "reflection":
85
+ pad2d = torch.nn.ReflectionPad2d((0, num_pad, 0,0))
86
+ elif mode == "replication":
87
+ pad2d = torch.nn.ReplicationPad2d((0, num_pad, 0,0))
88
+ else:
89
+ raise NotImplementedError("This function hasn't been implemented yet.")
90
+ return pad2d(Y)
91
+
92
+ def ensure_dir(file_path):
93
+ directory = file_path
94
+ if not os.path.exists(directory):
95
+ os.makedirs(directory)
96
+
97
+
98
+ def print_metrics(x, y, x_hat_list, labels, sr=16000):
99
+ _si_sdr_mix = si_sdr(x, y)
100
+ _pesq_mix = pesq(sr, x, y, 'wb')
101
+ _estoi_mix = stoi(x, y, sr, extended=True)
102
+ print(f'Mixture: PESQ: {_pesq_mix:.2f}, ESTOI: {_estoi_mix:.2f}, SI-SDR: {_si_sdr_mix:.2f}')
103
+ for i, x_hat in enumerate(x_hat_list):
104
+ _si_sdr = si_sdr(x, x_hat)
105
+ _pesq = pesq(sr, x, x_hat, 'wb')
106
+ _estoi = stoi(x, x_hat, sr, extended=True)
107
+ print(f'{labels[i]}: {_pesq:.2f}, ESTOI: {_estoi:.2f}, SI-SDR: {_si_sdr:.2f}')
108
+
109
+ def mean_std(data):
110
+ data = data[~np.isnan(data)]
111
+ mean = np.mean(data)
112
+ std = np.std(data)
113
+ return mean, std
114
+
115
+ def print_mean_std(data, decimal=2):
116
+ data = np.array(data)
117
+ data = data[~np.isnan(data)]
118
+ mean = np.mean(data)
119
+ std = np.std(data)
120
+ if decimal == 2:
121
+ string = f'{mean:.2f} ± {std:.2f}'
122
+ elif decimal == 1:
123
+ string = f'{mean:.1f} ± {std:.1f}'
124
+ return string
125
+
126
+ def set_torch_cuda_arch_list():
127
+ if not torch.cuda.is_available():
128
+ print("CUDA is not available. No GPUs found.")
129
+ return
130
+
131
+ num_gpus = torch.cuda.device_count()
132
+ compute_capabilities = []
133
+
134
+ for i in range(num_gpus):
135
+ cc_major, cc_minor = torch.cuda.get_device_capability(i)
136
+ cc = f"{cc_major}.{cc_minor}"
137
+ compute_capabilities.append(cc)
138
+
139
+ cc_string = ";".join(compute_capabilities)
140
+ os.environ['TORCH_CUDA_ARCH_LIST'] = cc_string
141
+ print(f"Set TORCH_CUDA_ARCH_LIST to: {cc_string}")