| import os
|
| import torch
|
| import numpy as np
|
| import scipy.stats
|
| from scipy.signal import butter, sosfilt
|
|
|
| from pesq import pesq
|
| from pystoi import stoi
|
|
|
|
|
def si_sdr_components(s_hat, s, n):
    """Decompose an estimate into target, noise and artifact components.

    Args:
        s_hat: Estimated signal. Presumably a 1-D array, so that ``np.dot``
            yields a scalar gain (TODO confirm with callers).
        s: Reference clean signal.
        n: Reference noise signal.

    Returns:
        ``(s_target, e_noise, e_art)``:
        - ``s_target`` is the scaled copy of ``s`` that best matches ``s_hat``.
        - ``e_noise`` is the scaled copy of ``n`` that best matches ``s_hat``.
        - ``e_art`` is the remainder ``s_hat - s_target - e_noise``.
    """
    def project(onto):
        # Least-squares gain of `onto` against s_hat, applied to `onto`.
        gain = np.dot(s_hat, onto) / np.linalg.norm(onto)**2
        return gain * onto

    s_target = project(s)
    e_noise = project(n)
    # Whatever is explained by neither reference counts as artifacts.
    e_art = s_hat - s_target - e_noise
    return s_target, e_noise, e_art
|
|
|
def energy_ratios(s_hat, s, n):
    """Return ``(si_sdr, si_sir, si_sar)`` in dB for the estimate ``s_hat``.

    ``s_hat`` is split by ``si_sdr_components`` into target, noise and
    artifact parts. Each ratio compares the target energy with a different
    error term:
    - SI-SDR uses noise + artifacts.
    - SI-SIR uses noise only.
    - SI-SAR uses artifacts only.
    """
    s_target, e_noise, e_art = si_sdr_components(s_hat, s, n)
    target_energy = np.linalg.norm(s_target)**2

    def to_db(error):
        return 10*np.log10(target_energy / np.linalg.norm(error)**2)

    # Tuple elements are evaluated left to right: SDR, SIR, SAR.
    return to_db(e_noise + e_art), to_db(e_noise), to_db(e_art)
|
|
|
def mean_conf_int(data, confidence=0.95):
    """Return the sample mean and the half-width of its Student-t confidence interval.

    Args:
        data: Array-like of samples; converted to a float array.
        confidence: Two-sided confidence level (default 0.95).

    Returns:
        ``(mean, half_width)``, so the interval is ``mean ± half_width``.
    """
    values = 1.0 * np.array(data)
    dof = len(values) - 1
    mean = np.mean(values)
    std_err = scipy.stats.sem(values)
    # Two-sided interval: critical value at the (1 + confidence) / 2 quantile.
    t_crit = scipy.stats.t.ppf((1 + confidence) / 2., dof)
    return mean, std_err * t_crit
|
|
|
class Method():
    """Accumulates metric values for one method under evaluation.

    Attributes:
        name: Display name of the method.
        base_dir: Directory associated with the method (stored; not used by
            this class).
        metrics: Dict mapping each metric name to the list of values
            appended so far.
    """

    def __init__(self, name, base_dir, metrics):
        """Register ``metrics`` (an iterable of names), each with an empty list."""
        self.name = name
        self.base_dir = base_dir
        # One independent list per metric. A comprehension replaces the index
        # loop and accepts any iterable of names, not just sequences.
        self.metrics = {metric: [] for metric in metrics}

    def append(self, matric, value):
        """Append ``value`` to the list of metric ``matric``.

        The parameter keeps its original (misspelled) name so keyword callers
        keep working. Raises KeyError if ``matric`` was not registered in
        ``__init__``.
        """
        self.metrics[matric].append(value)

    def get_mean_ci(self, metric):
        """Return the result of ``mean_conf_int`` on the values of ``metric``."""
        return mean_conf_int(np.array(self.metrics[metric]))
|
|
|
def hp_filter(signal, cut_off=80, order=10, sr=16000):
    """Apply a causal Butterworth high-pass filter to ``signal``.

    Args:
        signal: Samples to filter. Filtering runs along ``sosfilt``'s
            default axis.
        cut_off: Cut-off frequency in Hz.
        order: Butterworth filter order.
        sr: Sampling rate in Hz.

    Returns:
        The output of ``scipy.signal.sosfilt``. It is single-pass, so phase
        is not compensated.
    """
    # butter() expects the cut-off normalised to the Nyquist frequency.
    nyquist = sr / 2
    sections = butter(order, cut_off / nyquist, 'hp', output='sos')
    return sosfilt(sections, signal)
|
|
|
def si_sdr(s, s_hat):
    """Return the scale-invariant signal-to-distortion ratio of ``s_hat`` in dB.

    ``s`` is rescaled by the least-squares gain that best matches ``s_hat``.
    The energy of that scaled reference is compared with the energy of its
    difference from ``s_hat``.
    """
    scale = np.dot(s_hat, s) / np.linalg.norm(s)**2
    target = scale * s
    distortion = target - s_hat
    return 10*np.log10(np.linalg.norm(target)**2 / np.linalg.norm(distortion)**2)
|
|
|
def snr_dB(s,n):
    """Return the signal-to-noise ratio of ``s`` over ``n`` in dB.

    Each power is the mean squared sample value, so ``s`` and ``n`` may
    have different lengths.
    """
    def mean_power(x):
        return 1/len(x)*np.sum(x**2)

    return 10*np.log10(mean_power(s) / mean_power(n))
|
|
|
def pad_spec(Y, mode="zero_pad", multiple=64):
    """Right-pad dimension 3 of ``Y`` up to the next multiple of ``multiple``.

    Args:
        Y: Tensor with at least 4 dimensions. Dimension 3 is padded; this
            presumably is the time/frame axis of a (batch, channel, freq, time)
            spectrogram (confirm with callers).
        mode: One of "zero_pad", "reflection" or "replication". Selects the
            matching ``torch.nn.*Pad2d`` module.
        multiple: Length that dimension 3 must be divisible by. The default
            of 64 matches the previously hard-coded value.

    Returns:
        The padded tensor. Padding is always applied, even when it is 0.

    Raises:
        ValueError: if ``multiple`` is smaller than 1.
        NotImplementedError: if ``mode`` is not a supported padding mode.
    """
    if multiple < 1:
        raise ValueError(f"multiple must be a positive integer, got {multiple!r}")
    pad_layers = {
        "zero_pad": torch.nn.ZeroPad2d,
        "reflection": torch.nn.ReflectionPad2d,
        "replication": torch.nn.ReplicationPad2d,
    }
    T = Y.size(3)
    # Smallest non-negative amount that makes T divisible by `multiple`
    # (0 when it already is).
    num_pad = (-T) % multiple
    if mode not in pad_layers:
        raise NotImplementedError(f"Unsupported padding mode: {mode!r}")
    # (left, right, top, bottom): pad only the right side of the last dim.
    pad2d = pad_layers[mode]((0, num_pad, 0, 0))
    return pad2d(Y)
|
|
|
def ensure_dir(file_path):
    """Create the directory ``file_path`` (with any missing parents) if nothing exists there.

    Despite the parameter name, the path itself is the directory to create.
    If something already exists at the path, the call is a no-op.
    """
    directory = file_path
    try:
        os.makedirs(directory)
    except FileExistsError:
        # Already present, possibly created concurrently by another process.
        # Attempting first and catching the error avoids the race between a
        # separate exists() check and makedirs().
        pass
|
|
|
|
|
def print_metrics(x, y, x_hat_list, labels, sr=16000):
    """Print PESQ, ESTOI and SI-SDR for a mixture and for each estimate.

    All scores are computed against the clean reference ``x``.

    Args:
        x: Clean reference signal.
        y: Noisy mixture, scored as a baseline.
        x_hat_list: Estimated signals to score.
        labels: Display names, indexed in parallel with ``x_hat_list``.
        sr: Sampling rate in Hz passed to PESQ and STOI. PESQ runs in 'wb'
            mode, which presumably requires sr=16000 (confirm against the
            pesq package).
    """
    _si_sdr_mix = si_sdr(x, y)
    _pesq_mix = pesq(sr, x, y, 'wb')
    _estoi_mix = stoi(x, y, sr, extended=True)
    print(f'Mixture: PESQ: {_pesq_mix:.2f}, ESTOI: {_estoi_mix:.2f}, SI-SDR: {_si_sdr_mix:.2f}')
    for i, x_hat in enumerate(x_hat_list):
        _si_sdr = si_sdr(x, x_hat)
        _pesq = pesq(sr, x, x_hat, 'wb')
        _estoi = stoi(x, x_hat, sr, extended=True)
        # Same layout as the mixture line; the "PESQ:" label was missing here.
        print(f'{labels[i]}: PESQ: {_pesq:.2f}, ESTOI: {_estoi:.2f}, SI-SDR: {_si_sdr:.2f}')
|
|
|
def mean_std(data):
    """Return the mean and population standard deviation of ``data``, ignoring NaNs.

    Args:
        data: Array-like of numbers (list or ndarray). NaN entries are
            dropped first.

    Returns:
        ``(mean, std)`` from ``np.mean`` and ``np.std`` (ddof=0).
    """
    # Boolean-mask indexing needs an ndarray. asarray makes plain lists work
    # and is a no-op for arrays.
    data = np.asarray(data)
    data = data[~np.isnan(data)]
    mean = np.mean(data)
    std = np.std(data)
    return mean, std
|
|
|
def print_mean_std(data, decimal=2):
    """Format the NaN-ignoring mean and std of ``data`` as ``"mean ± std"``.

    Despite its name, this returns the string rather than printing it.

    Args:
        data: Array-like of numbers. NaN entries are dropped first.
        decimal: Number of digits after the decimal point (default 2). Any
            non-negative integer is accepted, not just 1 or 2.

    Returns:
        The formatted string, e.g. ``"2.00 ± 0.82"``.

    Raises:
        ValueError: if ``decimal`` is negative.
    """
    if decimal < 0:
        raise ValueError(f"decimal must be non-negative, got {decimal!r}")
    data = np.array(data)
    data = data[~np.isnan(data)]
    mean = np.mean(data)
    std = np.std(data)
    # A nested precision replaces the 1/2-only branches. Other values
    # previously left `string` unbound (UnboundLocalError).
    return f'{mean:.{decimal}f} ± {std:.{decimal}f}'
|
|
|
def set_torch_cuda_arch_list():
    """Set TORCH_CUDA_ARCH_LIST to the compute capabilities of the visible GPUs.

    Writes a ";"-joined list of "major.minor" strings into ``os.environ``
    and prints the result. If CUDA is unavailable, it only prints a message
    and leaves the environment untouched.
    """
    if torch.cuda.is_available():
        capabilities = [
            f"{major}.{minor}"
            for major, minor in (
                torch.cuda.get_device_capability(idx)
                for idx in range(torch.cuda.device_count())
            )
        ]
        arch_list = ";".join(capabilities)
        os.environ['TORCH_CUDA_ARCH_LIST'] = arch_list
        print(f"Set TORCH_CUDA_ARCH_LIST to: {arch_list}")
    else:
        print("CUDA is not available. No GPUs found.")