# Voiceblock demo: Attempt 8 (commit 957e2dc, by ALeLacheur)
import torch
import torch.nn.functional as F
from src.loss.auxiliary import AuxiliaryLoss
################################################################################
# Control signal losses, for regularizing time-varying controls
################################################################################
class ControlSignalLoss(AuxiliaryLoss):
    """
    Compute losses to regularize time-varying control signals.

    Inputs are 3-D tensors laid out as ``(n_batch, time, channels)``. With
    ``transpose=True`` the input is instead read as
    ``(n_batch, channels, time)`` and permuted before the loss is computed.
    Every variant returns one value per batch item, i.e. shape ``(n_batch,)``.
    ``reduction`` is forwarded unchanged to ``AuxiliaryLoss``.

    Loss variants, where ``d`` is the first-order difference along time:

    * ``'l2-slowness'``: sum of ``d**2``, divided by the number of
      difference elements ``(time - 1) * channels``.
    * ``'l1-slowness'``: sum of ``|d|``, normalized the same way.
    * ``'group-sparse-slowness'``: the squared sum over time of the
      per-step Euclidean norm of ``d`` across channels, normalized the
      same way.
    * ``'l1/2-group-sparsity'``: the sum over time of the per-step L1/2
      quasi-norm ``(sum_c |d|**0.5)**2``, normalized the same way.
    * ``'l2'`` / ``'l1'``: L2 / L1 norm of the whole signal per batch item
      (not normalized).
    """

    # accepted values for the `loss` constructor argument
    _VALID_LOSSES = (
        'l2-slowness',
        'l1-slowness',
        'group-sparse-slowness',
        'l1/2-group-sparsity',
        'l2',
        'l1',
    )

    def __init__(self,
                 reduction: str = 'none',
                 loss: str = 'group-sparse-slowness',
                 transpose: bool = False):
        """
        :param reduction: reduction mode, passed through to ``AuxiliaryLoss``
        :param loss: name of the loss variant (see class docstring)
        :param transpose: if True, inputs are ``(n_batch, channels, time)``
                          and are permuted to ``(n_batch, time, channels)``
        :raises ValueError: if ``loss`` is not a supported variant
        """
        super().__init__(reduction)

        # select loss variant; raise rather than assert so the check
        # survives `python -O`
        if loss not in self._VALID_LOSSES:
            raise ValueError(f'Invalid control-signal loss {loss}')

        self.loss = loss
        self.transpose = transpose

    def _compute_loss(self, x: torch.Tensor, x_ref: torch.Tensor = None):
        """
        Compute the selected loss on a batch of control signals.

        :param x: control signals, ``(n_batch, time, channels)``, or
                  ``(n_batch, channels, time)`` when ``transpose`` is set
        :param x_ref: unused; accepted for interface compatibility
        :return: per-example loss tensor of shape ``(n_batch,)``
        :raises ValueError: if ``x`` is not 3-dimensional
        """
        # require (n_batch, time, channels) representation
        if x.ndim != 3:
            raise ValueError(
                f'Expected a 3D control signal (n_batch, time, channels), '
                f'got shape {tuple(x.shape)}'
            )

        # if specified, flip time and channel dimensions
        if self.transpose:
            x = x.permute(0, 2, 1)

        # read dimensions AFTER the optional permute, so that `t` is always
        # the axis differenced below and the normalization matches it
        b, t, c = x.shape

        # whole-signal norms: no temporal differencing, no normalization
        if self.loss == 'l2':
            return x.norm(dim=(1, 2), p=2).reshape(b)
        if self.loss == 'l1':
            return x.norm(dim=(1, 2), p=1).reshape(b)

        # small constant: keeps sqrt / fractional powers away from zero
        eps = 1e-8

        # first-order temporal differences: (n_batch, time - 1, channels)
        d = torch.diff(x, dim=1)

        # normalize by the number of difference elements; a single-frame
        # (or channel-less) signal has none, so clamp to 1 and yield a zero
        # loss instead of dividing by zero
        scale = 1 / max((t - 1) * c, 1)

        if self.loss == 'l2-slowness':
            # squared differences summed over channels, then over time
            per_step = torch.sum(torch.square(d), dim=2, keepdim=True)
            loss = scale * torch.sum(per_step + eps, dim=1, keepdim=True)
        elif self.loss == 'l1-slowness':
            # absolute differences summed over channels, then over time
            per_step = torch.sum(torch.abs(d), dim=2, keepdim=True)
            loss = scale * torch.sum(per_step + eps, dim=1, keepdim=True)
        elif self.loss == 'group-sparse-slowness':
            # per-step Euclidean norm across channels (eps keeps sqrt
            # differentiable at zero), summed over time, then squared
            per_step = torch.sqrt(
                torch.sum(torch.square(d), dim=2, keepdim=True) + eps
            )
            loss = scale * torch.square(
                torch.sum(per_step, dim=1, keepdim=True)
            )
        elif self.loss == 'l1/2-group-sparsity':
            # per-step L1/2 quasi-norm across channels, summed over time;
            # eps is added outside abs() so the base of the fractional power
            # is strictly positive and its gradient stays finite
            per_step = torch.sum(
                (torch.abs(d) + eps) ** 0.5, dim=2, keepdim=True
            ) ** 2
            loss = scale * torch.sum(per_step, dim=1, keepdim=True)
        else:
            raise ValueError(f'Invalid control-signal loss {self.loss}')

        return loss.reshape(b)

    def set_reference(self, x_ref: torch.Tensor):
        """
        No-op: these losses are computed from the control signal alone
        (``_compute_loss`` ignores ``x_ref``), so no reference is stored.
        """