Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn.functional as F | |
| from src.loss.auxiliary import AuxiliaryLoss | |
| ################################################################################ | |
| # Control signal losses, for regularizing time-varying controls | |
| ################################################################################ | |
class ControlSignalLoss(AuxiliaryLoss):
    """
    Compute losses to regularize time-varying control signals.

    Supported variants (selected with ``loss``); ``dx`` denotes the
    frame-to-frame differences of the signal along the time axis:

    * ``'l2-slowness'``  -- sum of squared ``dx``, divided by the number of
      difference entries ``(t - 1) * c``.
    * ``'l1-slowness'``  -- same, with absolute ``dx``.
    * ``'group-sparse-slowness'`` -- square of the summed per-frame L2 norms
      of ``dx``, divided by ``(t - 1) * c``.
    * ``'l1/2-group-sparsity'`` -- sum over frames of the per-frame L1/2
      quasi-norm ``(sum_c |dx|^0.5)^2``, divided by ``(t - 1) * c``.
    * ``'l2'`` / ``'l1'`` -- unnormalized L2 / L1 norm of each batch item's
      whole signal.

    The four difference-based variants add a small constant (1e-8) per frame
    for numerical stability, so their values carry a tiny positive offset.
    Every variant returns one value per batch item, shape ``(n_batch,)``.
    """

    # Loss variants accepted by the ``loss`` constructor argument.
    _VALID_LOSSES = (
        'l2-slowness',
        'l1-slowness',
        'group-sparse-slowness',
        'l1/2-group-sparsity',
        'l2',
        'l1',
    )

    # Added to per-frame terms so sqrt / fractional powers stay finite and
    # differentiable when a frame-to-frame difference is exactly zero.
    _EPS = 1e-8

    def __init__(self,
                 reduction: str = 'none',
                 loss: str = 'group-sparse-slowness',
                 transpose: bool = False
                 ):
        """
        Args:
            reduction: forwarded unchanged to ``AuxiliaryLoss.__init__``
                (presumably selects how per-item losses are aggregated;
                confirm against the base class).
            loss: name of the loss variant; one of ``_VALID_LOSSES``.
            transpose: if True, swap dimensions 1 and 2 of the input before
                computing the loss, e.g. to accept input laid out as
                ``(n_batch, channels, time)``.

        Raises:
            ValueError: if ``loss`` is not a supported variant.
        """
        super().__init__(reduction)

        # select loss variant; explicit raise (not ``assert``) so the check
        # still runs under ``python -O``
        if loss not in self._VALID_LOSSES:
            raise ValueError(f'Invalid control-signal loss {loss}')

        self.loss = loss
        self.transpose = transpose

    def _compute_loss(self, x: torch.Tensor, x_ref: torch.Tensor = None):
        """
        Compute the configured loss on a control signal.

        Args:
            x: 3-D control signal, ``(n_batch, time, channels)``; with
                ``transpose`` set, dimensions 1 and 2 are swapped first.
            x_ref: ignored by every variant; accepted for interface
                compatibility.

        Returns:
            Tensor of shape ``(n_batch,)`` with one loss value per item.

        Raises:
            ValueError: if ``x`` is not 3-D or ``self.loss`` is unknown.
        """
        # require (n_batch, time, channels) representation
        if x.ndim != 3:
            raise ValueError(
                'Expected a 3-D (n_batch, time, channels) control signal, '
                f'got shape {tuple(x.shape)}'
            )

        # if specified, flip time and channel dimensions
        if self.transpose:
            x = x.permute(0, 2, 1)

        # Read the shape *after* the optional transpose so that ``t`` is the
        # length of the axis differentiated below; reading it beforehand
        # swapped t and c in the normalization whenever ``transpose=True``.
        b, t, c = x.shape

        # Plain norms over all non-batch entries (not normalized by size).
        if self.loss == 'l2':
            return x.norm(dim=(1, 2), p=2).reshape(b)
        if self.loss == 'l1':
            return x.norm(dim=(1, 2), p=1).reshape(b)

        # Remaining variants act on frame-to-frame differences.
        dx = torch.diff(x, dim=1)  # (b, t - 1, c)

        # Average over the (t - 1) * c difference entries. For a single-frame
        # signal the count is 0 and every sum below is empty, so clamp the
        # divisor to 1: the loss is then 0 instead of a ZeroDivisionError.
        scale = 1 / max((t - 1) * c, 1)
        eps = self._EPS

        if self.loss == 'l2-slowness':
            per_frame = dx.square().sum(dim=2) + eps
            loss = scale * per_frame.sum(dim=1)
        elif self.loss == 'l1-slowness':
            per_frame = dx.abs().sum(dim=2) + eps
            loss = scale * per_frame.sum(dim=1)
        elif self.loss == 'group-sparse-slowness':
            # per-frame L2 norm of the difference vector, summed over time,
            # then squared
            per_frame = torch.sqrt(dx.square().sum(dim=2) + eps)
            loss = scale * per_frame.sum(dim=1).square()
        elif self.loss == 'l1/2-group-sparsity':
            # per-frame L1/2 quasi-norm (sum_c |dx|^0.5)^2. eps goes *outside*
            # abs so the sqrt base is always >= eps; inside abs, a difference
            # of exactly -eps produced sqrt(0) and a NaN gradient.
            per_frame = ((dx.abs() + eps) ** 0.5).sum(dim=2) ** 2
            loss = scale * per_frame.sum(dim=1)
        else:
            # only reachable if ``self.loss`` was modified after construction
            raise ValueError(f'Invalid control-signal loss {self.loss}')

        return loss

    def set_reference(self, x_ref: torch.Tensor):
        """No-op: none of these losses uses a reference signal."""
        pass