| import os |
| import torch |
| import numpy as np |
| import scipy.stats |
| from scipy.signal import butter, sosfilt |
|
|
| from pesq import pesq |
| from pystoi import stoi |
|
|
|
|
def si_sdr_components(s_hat, s, n):
    """Split an estimate into target, noise and artifact components.

    The estimate is projected onto the clean signal and onto the noise
    signal separately; whatever is left over is treated as artifacts.

    Args:
        s_hat: Estimated signal (assumed 1-D array, same length as ``s``
            and ``n`` -- TODO confirm with callers).
        s: Reference (clean) signal.
        n: Noise signal.

    Returns:
        Tuple ``(s_target, e_noise, e_art)``.
    """

    def _scaled_projection(reference):
        # Least-squares scale of `reference` that best explains `s_hat`.
        scale = np.dot(s_hat, reference) / np.linalg.norm(reference)**2
        return scale * reference

    s_target = _scaled_projection(s)
    e_noise = _scaled_projection(n)

    # Residual not explained by either the target or the noise projection.
    e_art = s_hat - s_target - e_noise

    return s_target, e_noise, e_art
|
|
def energy_ratios(s_hat, s, n):
    """Compute SI-SDR, SI-SIR and SI-SAR (in dB) for an estimate.

    Args:
        s_hat: Estimated signal.
        s: Reference (clean) signal.
        n: Noise signal.

    Returns:
        Tuple ``(si_sdr, si_sir, si_sar)``, each being ten times the
        base-10 log of the target energy over the respective error energy.
    """
    s_target, e_noise, e_art = si_sdr_components(s_hat, s, n)

    # The target energy is the numerator of all three ratios.
    target_energy = np.linalg.norm(s_target)**2

    distortion_energy = np.linalg.norm(e_noise + e_art)**2
    interference_energy = np.linalg.norm(e_noise)**2
    artifact_energy = np.linalg.norm(e_art)**2

    si_sdr = 10*np.log10(target_energy / distortion_energy)
    si_sir = 10*np.log10(target_energy / interference_energy)
    si_sar = 10*np.log10(target_energy / artifact_energy)

    return si_sdr, si_sir, si_sar
|
|
def mean_conf_int(data, confidence=0.95):
    """Return the mean of ``data`` and the half-width of its confidence interval.

    The half-width is the standard error of the mean scaled by the
    two-sided Student-t quantile with ``len(data) - 1`` degrees of freedom.

    Args:
        data: Sequence of numeric samples.
        confidence: Confidence level of the interval (default 0.95).

    Returns:
        Tuple ``(mean, half_width)``.
    """
    samples = np.array(data) * 1.0
    dof = len(samples) - 1
    mean = np.mean(samples)
    std_err = scipy.stats.sem(samples)
    t_quantile = scipy.stats.t.ppf((1 + confidence) / 2., dof)
    half_width = std_err * t_quantile
    return mean, half_width
|
|
class Method():
    """Container collecting per-metric result values for one method.

    Attributes:
        name: Identifier of the method.
        base_dir: Directory associated with the method (stored as given;
            not accessed by this class).
        metrics: Dict mapping each metric name to a list of recorded values.
    """

    def __init__(self, name, base_dir, metrics):
        """Create an empty value list for every metric name.

        Args:
            name: Identifier of the method.
            base_dir: Directory associated with the method.
            metrics: Iterable of metric names (any iterable is accepted,
                not only indexable sequences).
        """
        self.name = name
        self.base_dir = base_dir
        # One independent list per metric so appends never alias each other.
        self.metrics = {metric: [] for metric in metrics}

    def append(self, matric, value):
        """Record ``value`` under the metric named ``matric``.

        Raises:
            KeyError: If ``matric`` was not given at construction time.
        """
        self.metrics[matric].append(value)

    def get_mean_ci(self, metric):
        """Return ``mean_conf_int`` of all values recorded for ``metric``."""
        return mean_conf_int(np.array(self.metrics[metric]))
|
|
def hp_filter(signal, cut_off=80, order=10, sr=16000):
    """Apply a Butterworth high-pass filter to ``signal``.

    Args:
        signal: Input samples passed to ``sosfilt``.
        cut_off: Cut-off frequency, in the same unit as ``sr``.
        order: Order of the Butterworth filter.
        sr: Sampling rate of ``signal``.

    Returns:
        The filtered signal.
    """
    # Cut-off expressed as a fraction of the Nyquist frequency (sr / 2).
    normalized_cutoff = cut_off / sr * 2
    sections = butter(order, normalized_cutoff, btype='hp', output='sos')
    return sosfilt(sections, signal)
|
|
def si_sdr(s, s_hat):
    """Compute the scale-invariant signal-to-distortion ratio in dB.

    Args:
        s: Reference signal.
        s_hat: Estimated signal.

    Returns:
        Ten times the base-10 log of the scaled-reference energy over the
        energy of the residual between the scaled reference and ``s_hat``.
    """
    # Optimal scaling of the reference towards the estimate.
    alpha = np.dot(s_hat, s)/np.linalg.norm(s)**2
    scaled_ref = alpha*s
    target_energy = np.linalg.norm(scaled_ref)**2
    residual_energy = np.linalg.norm(scaled_ref - s_hat)**2
    return 10*np.log10(target_energy/residual_energy)
|
|
def snr_dB(s,n):
    """Return the signal-to-noise ratio in dB.

    Args:
        s: Signal samples.
        n: Noise samples.

    Returns:
        Ten times the base-10 log of the ratio between the mean squared
        value of ``s`` and the mean squared value of ``n``.
    """
    signal_power = 1/len(s)*np.sum(s**2)
    noise_power = 1/len(n)*np.sum(n**2)
    ratio = signal_power/noise_power
    return 10*np.log10(ratio)
|
|
def pad_spec(Y, mode="zero_pad"):
    """Pad dimension 3 of ``Y`` on the right up to the next multiple of 64.

    Args:
        Y: Tensor with at least four dimensions (its size along dimension 3
            is read; presumably that is the time axis -- TODO confirm).
        mode: One of ``"zero_pad"``, ``"reflection"`` or ``"replication"``.

    Returns:
        The padded tensor; no padding is added when the size along
        dimension 3 is already a multiple of 64.

    Raises:
        NotImplementedError: If ``mode`` is not one of the supported values.
    """
    remainder = Y.size(3) % 64
    num_pad = 64 - remainder if remainder != 0 else 0

    pad_layers = {
        "zero_pad": torch.nn.ZeroPad2d,
        "reflection": torch.nn.ReflectionPad2d,
        "replication": torch.nn.ReplicationPad2d,
    }
    if mode not in pad_layers:
        raise NotImplementedError("This function hasn't been implemented yet.")

    # Padding tuple is (left, right, top, bottom): only the right side grows.
    pad2d = pad_layers[mode]((0, num_pad, 0, 0))
    return pad2d(Y)
|
|
def ensure_dir(file_path):
    """Create the directory ``file_path`` (including parents) if it is missing.

    Despite the parameter name, the path itself is treated as the directory
    to create. Nothing happens when the path already exists.

    Args:
        file_path: Path of the directory to create.
    """
    directory = file_path
    if not os.path.exists(directory):
        # exist_ok=True closes the race where another process creates the
        # directory between the existence check and the makedirs call.
        os.makedirs(directory, exist_ok=True)
|
|
|
|
def print_metrics(x, y, x_hat_list, labels, sr=16000):
    """Print PESQ, ESTOI and SI-SDR for the mixture and for each estimate.

    Args:
        x: Reference signal every other signal is scored against.
        y: Mixture signal, reported on the first output line.
        x_hat_list: Estimates to evaluate, one output line each.
        labels: Display labels, indexed in parallel with ``x_hat_list``.
        sr: Sampling rate passed to ``pesq`` and ``stoi``.
    """
    _si_sdr_mix = si_sdr(x, y)
    _pesq_mix = pesq(sr, x, y, 'wb')
    _estoi_mix = stoi(x, y, sr, extended=True)
    print(f'Mixture: PESQ: {_pesq_mix:.2f}, ESTOI: {_estoi_mix:.2f}, SI-SDR: {_si_sdr_mix:.2f}')
    for i, x_hat in enumerate(x_hat_list):
        _si_sdr = si_sdr(x, x_hat)
        _pesq = pesq(sr, x, x_hat, 'wb')
        _estoi = stoi(x, x_hat, sr, extended=True)
        # Same "PESQ:" label as the mixture line so every row reads alike.
        print(f'{labels[i]}: PESQ: {_pesq:.2f}, ESTOI: {_estoi:.2f}, SI-SDR: {_si_sdr:.2f}')
|
|
def mean_std(data):
    """Return the mean and standard deviation of ``data``, ignoring NaNs.

    Args:
        data: Array-like of numeric values; lists and tuples are accepted
            as well as NumPy arrays.

    Returns:
        Tuple ``(mean, std)`` computed over the non-NaN entries.
    """
    # Convert first: boolean-mask indexing below needs an ndarray.
    values = np.asarray(data)
    values = values[~np.isnan(values)]
    mean = np.mean(values)
    std = np.std(values)
    return mean, std
|
|
def print_mean_std(data, decimal=2):
    """Format the NaN-ignoring mean and standard deviation as ``'mean ± std'``.

    Despite its name, this function returns the string; it does not print.

    Args:
        data: Array-like of numeric values.
        decimal: Number of digits after the decimal point. Any non-negative
            integer is accepted.

    Returns:
        String of the form ``'<mean> ± <std>'``.
    """
    values = np.array(data)
    values = values[~np.isnan(values)]
    mean = np.mean(values)
    std = np.std(values)
    # Nested format spec: one code path for every precision, so no value of
    # `decimal` can leave the result unassigned.
    return f'{mean:.{decimal}f} ± {std:.{decimal}f}'
|
|
def set_torch_cuda_arch_list():
    """Set ``TORCH_CUDA_ARCH_LIST`` from the compute capabilities of all GPUs.

    When CUDA is unavailable, a message is printed and the environment is
    left untouched. Otherwise the capability of every visible device is
    joined with ``;`` (one entry per device, in device order), written to
    the environment variable, and printed.
    """
    if not torch.cuda.is_available():
        print("CUDA is not available. No GPUs found.")
        return

    capabilities = [
        torch.cuda.get_device_capability(device_index)
        for device_index in range(torch.cuda.device_count())
    ]
    cc_string = ";".join(f"{major}.{minor}" for major, minor in capabilities)

    os.environ['TORCH_CUDA_ARCH_LIST'] = cc_string
    print(f"Set TORCH_CUDA_ARCH_LIST to: {cc_string}")