|
|
import torch |
|
|
from torch.nn.modules.loss import _Loss |
|
|
|
|
|
class MultiSrcNegSDR(_Loss):
    """Negative signal-to-distortion ratio averaged over sources and batch.

    Each estimated source is compared with the target at the same source
    index (no permutation search is performed here).

    Args:
        sdr_type: One of ``"snr"``, ``"sisdr"`` or ``"sdsdr"``.
            * ``"snr"``: the target is used as-is and the noise is
              ``ests - targets``.
            * ``"sisdr"``: the target is rescaled by projecting the estimate
              onto it, and the noise is ``ests - scaled_targets``.
            * ``"sdsdr"``: the target is rescaled as for ``"sisdr"`` but the
              noise is ``ests - targets``.
        zero_mean: If True, subtract the mean over the last dimension from
            both ``ests`` and ``targets`` before computing the ratio.
        take_log: If True, return the ratio in decibels
            (``10 * log10(ratio + EPS)``); otherwise return the linear ratio.
        EPS: Small constant added to denominators and inside the log to avoid
            division by zero and ``log10(0)``.
    """

    def __init__(self, sdr_type, zero_mean=True, take_log=True, EPS=1e-8):
        super().__init__()
        assert sdr_type in ["snr", "sisdr", "sdsdr"], (
            f"sdr_type must be one of 'snr', 'sisdr', 'sdsdr', got {sdr_type!r}"
        )
        self.sdr_type = sdr_type
        self.zero_mean = zero_mean
        self.take_log = take_log
        # Honor the caller-supplied epsilon (it used to be hard-coded to 1e-8).
        self.EPS = EPS

    def forward(self, ests, targets):
        """Compute the negated SDR between estimates and targets.

        Args:
            ests: Estimated sources, shape ``[batch, n_src, time]``.
            targets: Reference sources, same shape as ``ests``.

        Returns:
            Scalar tensor: the negated ratio, averaged first over the source
            dimension and then over the batch dimension.

        Raises:
            TypeError: If the two inputs differ in shape or are not 3-D.
        """
        if targets.size() != ests.size() or targets.ndim != 3:
            raise TypeError(
                f"Inputs must be of shape [batch, n_src, time], got {targets.size()} and {ests.size()} instead"
            )

        if self.zero_mean:
            # Remove the per-source DC offset along the time axis.
            mean_source = torch.mean(targets, dim=2, keepdim=True)
            mean_est = torch.mean(ests, dim=2, keepdim=True)
            targets = targets - mean_source
            ests = ests - mean_est

        if self.sdr_type in ["sisdr", "sdsdr"]:
            # Project each estimate onto its target:
            # <est, target> * target / ||target||^2, all of shape [batch, n_src, *].
            pair_wise_dot = torch.sum(ests * targets, dim=2, keepdim=True)
            s_target_energy = torch.sum(targets ** 2, dim=2, keepdim=True) + self.EPS
            scaled_targets = pair_wise_dot * targets / s_target_energy
        else:
            # Plain SNR: no rescaling of the reference.
            scaled_targets = targets

        if self.sdr_type in ["sdsdr", "snr"]:
            # Noise is measured against the unscaled target.
            e_noise = ests - targets
        else:
            # SI-SDR: noise is measured against the projected target.
            e_noise = ests - scaled_targets

        # Energy ratio per (batch, source) pair, shape [batch, n_src].
        pair_wise_sdr = torch.sum(scaled_targets ** 2, dim=2) / (
            torch.sum(e_noise ** 2, dim=2) + self.EPS
        )
        if self.take_log:
            pair_wise_sdr = 10 * torch.log10(pair_wise_sdr + self.EPS)

        # Negate so that minimizing the loss maximizes the SDR.
        return -torch.mean(pair_wise_sdr, dim=-1).mean(0)