haoxiangsnr's picture
Upload folder using huggingface_hub
50de2e0 verified
import torch
from audiozen.acoustics.audio_feature import stft
from audiozen.loss import freq_MAE, mag_MAE
class SNRLoss(torch.nn.Module):
    """Negative soft-thresholded SNR loss (in dB) between an estimate and a reference.

    For every signal, taken along the last dimension, the loss is::

        10 * log10(||s - x||^2 + tao * ||s||^2 + eps) - 10 * log10(||s||^2 + eps)

    The ``tao * ||s||^2`` term floors the error energy relative to the
    reference energy, so the loss cannot decrease without bound once the
    estimate is already very close to the reference.

    Args:
        tao: Weight of the reference energy added to the error energy.
    """

    def __init__(self, tao=1e-3):
        super().__init__()
        self.tao = tao
        # Keeps both logarithms finite for all-zero signals.
        self.eps = 1e-8

    def l2norm(self, mat, keepdim=False):
        """Return the L2 norm of ``mat`` along its last dimension."""
        return torch.norm(mat, dim=-1, keepdim=keepdim)

    def _neg_snr(self, x, s):
        """Return the un-reduced loss for estimate ``x`` and reference ``s``."""
        ref_energy = self.l2norm(s) ** 2
        err_energy = self.l2norm(s - x) ** 2
        return 10 * torch.log10(err_energy + self.tao * ref_energy + self.eps) - 10 * torch.log10(
            ref_energy + self.eps
        )

    def forward(self, x, s, length_list=None):
        """Compute the mean loss over the batch.

        Args:
            x: Estimated signals.
            s: Reference signals; must have the same shape as ``x``.
            length_list: Optional iterable of valid lengths. Item ``i`` is
                evaluated on ``x[i, :length_list[i]]`` only, so padding
                samples do not contribute.

        Returns:
            A scalar tensor: the loss averaged over signals.

        Raises:
            RuntimeError: If ``x`` and ``s`` have different shapes.
            ValueError: If ``length_list`` is given but empty.
        """
        if x.shape != s.shape:
            raise RuntimeError(f"Dimension mismatch when calculate si-snr, {x.shape} vs {s.shape}")

        if length_list is None:
            return self._neg_snr(x, s).mean()

        lengths = list(length_list)
        if not lengths:
            raise ValueError("length_list must contain at least one length")

        per_item = [self._neg_snr(x[i, :length], s[i, :length]) for i, length in enumerate(lengths)]
        # Average over utterances so the result has the same scale as the
        # un-masked path; summing would make the loss grow with batch size.
        return torch.stack(per_item).mean()
class FreqLoss(torch.nn.Module):
    """Sum of ``freq_MAE`` and ``mag_MAE`` between an estimate and a reference.

    Both terms are imported from ``audiozen.loss``; their exact definitions
    are not visible in this file. They are always called with the reference
    first and the estimate second.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, s, length_list=None):
        """Compute the frequency-domain loss.

        Args:
            x: Estimated signals.
            s: Reference signals; must have the same shape as ``x``.
            length_list: Optional iterable of valid lengths. Item ``i`` is
                evaluated on ``x[i, :length_list[i]]`` only, so padding
                samples do not contribute.

        Returns:
            Without ``length_list``: ``freq_MAE(s, x) + mag_MAE(s, x)`` as
            returned by those helpers. With ``length_list``: the per-item
            losses averaged over the items.

        Raises:
            RuntimeError: If ``x`` and ``s`` have different shapes.
            ValueError: If ``length_list`` is given but empty.
        """
        if x.shape != s.shape:
            raise RuntimeError(f"Dimension mismatch when calculate si-snr, {x.shape} vs {s.shape}")

        if length_list is None:
            return freq_MAE(s, x) + mag_MAE(s, x)

        lengths = list(length_list)
        if not lengths:
            raise ValueError("length_list must contain at least one length")

        total = 0.0
        for i, length in enumerate(lengths):
            x_i = x[i, :length]
            s_i = s[i, :length]
            total = total + freq_MAE(s_i, x_i) + mag_MAE(s_i, x_i)

        # Divide by the number of items so the loss does not scale with
        # batch size. `.mean()` is kept in case the helpers return a
        # non-scalar tensor.
        return total.mean() / len(lengths)
class MultiResolutionL1SpecLoss(torch.nn.Module):
    """Time-domain L1 loss combined with a multi-resolution spectral-magnitude L1 loss.

    Args:
        window_sz: STFT window sizes, also used as ``n_fft``. Defaults to
            ``[512, 1024, 2048, 256, 128]``.
        hop_sz: STFT hop sizes, paired element-wise with ``window_sz``.
            Defaults to ``[256, 512, 1024, 128, 64]``.
        eps: Small constant that stabilises the scaling-factor division.

    Raises:
        ValueError: If ``window_sz`` is empty or its length differs from
            that of ``hop_sz``.
    """

    _DEFAULT_WINDOW_SZ = (512, 1024, 2048, 256, 128)
    _DEFAULT_HOP_SZ = (256, 512, 1024, 128, 64)

    def __init__(self, window_sz: list = None, hop_sz: list = None, eps=1e-8):
        super().__init__()
        self.eps = eps
        # Copy into fresh lists so instances never share (or leak mutations
        # into) the defaults or the caller's own list objects.
        self.window_sz = list(self._DEFAULT_WINDOW_SZ if window_sz is None else window_sz)
        self.hop_sz = list(self._DEFAULT_HOP_SZ if hop_sz is None else hop_sz)

        if not self.window_sz:
            raise ValueError("window_sz must contain at least one resolution")
        if len(self.window_sz) != len(self.hop_sz):
            # zip() would silently drop the unmatched entries while the
            # average below still divides by len(window_sz).
            raise ValueError(
                f"window_sz and hop_sz must have the same length, "
                f"got {len(self.window_sz)} and {len(self.hop_sz)}"
            )

    def forward(self, est, ref):
        """Return the scalar loss averaged over the batch.

        Args:
            est: Estimated waveforms, ``[B, T]``.
            ref: Reference waveforms with the same shape as ``est``.
        """
        # est: [B, T]
        # Least-squares gain that best maps `est` onto `ref`; only the
        # time-domain term uses the rescaled estimate.
        scaling_factor = torch.sum(est * ref, -1, keepdim=True) / (torch.sum(est**2, -1, keepdim=True) + self.eps)
        time_domain_loss = torch.mean((est * scaling_factor - ref).abs(), dim=-1)

        spectral_loss = torch.zeros_like(time_domain_loss)
        for win, hop in zip(self.window_sz, self.hop_sz):
            # The first value returned by `stft` is used as the magnitude.
            est_mag, *_ = stft(est, n_fft=win, hop_length=hop, win_length=win, window=None)  # [B, F, T]
            ref_mag, *_ = stft(ref, n_fft=win, hop_length=hop, win_length=win, window=None)
            spectral_loss += torch.mean((est_mag - ref_mag).abs(), dim=(1, 2))

        # Equal weighting of the time-domain term and the mean spectral term.
        loss = 0.5 * time_domain_loss + 0.5 * spectral_loss / len(self.window_sz)
        return loss.mean()
if __name__ == "__main__":
    import numpy as np

    # Smoke test: evaluate the multi-resolution loss on two random
    # batches of shape (2, 16000) drawn from NumPy's uniform generator.
    est = torch.tensor(np.random.rand(2, 16000), dtype=torch.float32)
    ref = torch.tensor(np.random.rand(2, 16000), dtype=torch.float32)

    criterion = MultiResolutionL1SpecLoss()
    print(criterion(est, ref))