| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import torch |
| import torch.nn as nn |
| from einops import rearrange |
|
|
# Public API of this module: star-imports expose only the loss class.
__all__ = ["PixelLoss"]
|
|
|
|
class PixelLoss(nn.Module):
    """
    Pixel-wise loss between two batches of images.

    The element-wise loss (MSE or L1) is averaged over the channel and
    spatial dimensions of every image, then over the images belonging to
    each sample, and finally over the batch.
    """

    def __init__(self, option: str = 'mse'):
        """
        Args:
            option: Element-wise loss to use, either ``'mse'`` or ``'l1'``.

        Raises:
            NotImplementedError: If ``option`` is not a supported loss.
        """
        super().__init__()
        self.loss_fn = self._build_from_option(option)

    @staticmethod
    def _build_from_option(option: str, reduction: str = 'none') -> nn.Module:
        """
        Build the torch loss module named by ``option``.

        ``reduction`` defaults to ``'none'`` so the module returns
        element-wise losses and ``forward`` performs the reduction itself.

        Raises:
            NotImplementedError: If ``option`` is not ``'mse'`` or ``'l1'``.
        """
        if option == 'mse':
            return nn.MSELoss(reduction=reduction)
        if option == 'l1':
            return nn.L1Loss(reduction=reduction)
        raise NotImplementedError(f'Unknown pixel loss option: {option}')

    @staticmethod
    def _as_multi_view(t: torch.Tensor, name: str) -> torch.Tensor:
        """Return ``t`` as [N, M, C, H, W]; a 4D [N, C, H, W] tensor gets M = 1."""
        if t.dim() == 4:
            return t.unsqueeze(1)
        if t.dim() != 5:
            # ValueError matches what the old 5-tuple shape unpack raised.
            raise ValueError(
                f'PixelLoss expects {name} of shape [N, M, C, H, W] or '
                f'[N, C, H, W], got {tuple(t.shape)}'
            )
        return t

    # NOTE(review): torch.compile is applied unconditionally when the class
    # is defined; on platforms where torch.compile is unsupported this may
    # fail at import time -- confirm this is acceptable for all targets.
    @torch.compile
    def forward(self, x, y, conf_sigma=None, only_sym_conf=False):
        """
        Compute the mean pixel loss between two batches of images.

        Images are assumed to be channel first.

        Args:
            x: Predictions, shape [N, M, C, H, W] (N samples, M images each).
               A 4D tensor [N, C, H, W] is accepted and treated as M = 1.
            y: Targets in the same layout as ``x`` (4D is also accepted).
               ``y`` is broadcast against ``x`` element-wise, e.g. a
               [N, 1, C, H, W] target is compared with every one of the M
               images of its sample.
            conf_sigma: Currently ignored; accepted for call-site
                compatibility only.
            only_sym_conf: Currently ignored; accepted for call-site
                compatibility only.

        Returns:
            Scalar tensor. The loss is averaged over C, H, W of each image,
            then over the M images of each sample, then over the batch.

        Raises:
            ValueError: If ``x`` or ``y`` is neither 4D nor 5D.
        """
        x = self._as_multi_view(x, 'x')
        y = self._as_multi_view(y, 'y')

        # loss_fn uses reduction='none', so this is element-wise and keeps
        # the [N, M, C, H, W] layout; no flatten/unflatten round-trip needed.
        pixel_loss = self.loss_fn(x, y)
        image_loss = pixel_loss.mean(dim=(2, 3, 4))  # [N, M]
        batch_loss = image_loss.mean(dim=1)          # [N]
        return batch_loss.mean()
|
|