| | from typing import Optional, Dict |
| | import torch.nn as nn |
| | import torch |
| | from .schema import LossConfiguration |
| |
|
| |
|
| | def dice_loss(input: torch.Tensor, |
| | target: torch.Tensor, |
| | loss_mask: torch.Tensor, |
| | class_weights: Optional[torch.Tensor | bool], |
| | smooth=1e-5): |
| | ''' |
| | :param input: (B, H, W, C) Logits for each class |
| | :param target: (B, H, W, C) Ground truth class labels in one_hot |
| | :param loss_mask: (B, H, W) Mask indicating valid regions of the image |
| | :param class_weights: (C) Weights for each class |
| | :param smooth: Smoothing factor to avoid division by zero, default 1.0 |
| | ''' |
| | |
| | if isinstance(class_weights, torch.Tensor): |
| | class_weights = class_weights.unsqueeze(0) |
| | elif class_weights is None or class_weights == False: |
| | class_weights = torch.ones( |
| | 1, target.size(-1), dtype=target.dtype, device=target.device) |
| | elif class_weights == True: |
| | class_weights = target.sum(1) |
| | class_weights = torch.reciprocal(target.mean(1) + 1e-3) |
| | class_weights = class_weights.clamp(min=1e-5) |
| | |
| | class_weights *= (target.sum(1) != 0).float() |
| | class_weights.requires_grad = False |
| |
|
| | intersect = (2 * input * target) |
| | intersect = (intersect) + smooth |
| |
|
| | union = (input + target) |
| | union = (union) + smooth |
| |
|
| | loss = 1 - (intersect / union) |
| | loss *= class_weights.unsqueeze(0).unsqueeze(0) |
| | loss = loss.sum(-1) / class_weights.sum() |
| | loss *= loss_mask |
| | loss = loss.sum() / loss_mask.sum() |
| |
|
| | return loss |
| |
|
| |
|
class EnhancedLoss(nn.Module):
    '''
    Segmentation loss combining a binary cross-entropy term (optionally focal and
    label-smoothed) with the per-pixel Dice loss computed by ``dice_loss``.

    ``forward`` returns a dict with ``'total'`` (the weighted sum) plus each
    enabled term (``'cross_entropy'`` and/or ``'dice'``).
    '''

    def __init__(
        self,
        cfg: "LossConfiguration",
    ):
        '''
        :param cfg: Loss configuration. Reads ``num_classes``, ``xent_weight``,
            ``focal_loss``, ``focal_loss_gamma``, ``dice_weight``, ``class_weights``,
            ``requires_frustrum``, ``requires_flood_mask`` and ``label_smoothing``.
        :raises ValueError: if both ``xent_weight`` and ``dice_weight`` are 0.
        '''
        super().__init__()
        self.num_classes = cfg.num_classes
        self.xent_weight = cfg.xent_weight
        self.focal = cfg.focal_loss
        self.focal_gamma = cfg.focal_loss_gamma
        self.dice_weight = cfg.dice_weight

        if self.xent_weight == 0. and self.dice_weight == 0.:
            raise ValueError(
                "At least one of xent_weight and dice_weight must be greater than 0.")

        if self.xent_weight > 0.:
            # reduction="none" keeps per-element losses so they can be masked.
            self.xent_loss = nn.BCEWithLogitsLoss(
                reduction="none"
            )

        if self.dice_weight > 0.:
            self.dice_loss = dice_loss

        self.class_weights: Optional[torch.Tensor | bool]
        configured_weights = cfg.class_weights
        if configured_weights is None or isinstance(configured_weights, bool):
            # Flags are passed straight through: dice_loss treats None/False as
            # uniform weights and True as inverse class frequency.
            self.class_weights = configured_weights
        else:
            # Explicit per-class weights: stored as a floating buffer so they
            # follow .to(device) and mix with float losses (an int list would
            # otherwise become an int64 tensor).
            self.register_buffer("class_weights", torch.tensor(
                configured_weights, dtype=torch.get_default_dtype()),
                persistent=False)

        self.requires_frustrum = cfg.requires_frustrum
        self.requires_flood_mask = cfg.requires_flood_mask
        self.label_smoothing = cfg.label_smoothing

    def forward(self, pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor]):
        '''
        Args:
            pred: Dict containing
                - logits: (B, C, H, W) Logits for each class
                  (read only when the cross-entropy term is enabled)
                - output: (B, C, H, W) Probabilities for each class
                  (read only when the Dice term is enabled)
                - valid_bev: Validity mask; its last entry along the trailing
                  axis is dropped before use (read only when ``requires_frustrum``)
            data: Dict containing
                - seg_masks: (B, H, W, C) Ground truth class labels, one-hot encoded
                - flood_masks: Non-zero where pixels are excluded from the loss;
                  must broadcast to (B, H, W) (read only when ``requires_flood_mask``)
        Returns:
            Dict with 'total' and the enabled terms 'cross_entropy' / 'dice'.
        '''
        loss: Dict[str, torch.Tensor] = {}
        labels: torch.Tensor = data['seg_masks']

        # Keep the mask floating-point so the masked products below never fall
        # back to integer/bool arithmetic.
        mask_dtype = labels.dtype if labels.is_floating_point() else torch.float32
        loss_mask = torch.ones(
            labels.shape[:3], device=labels.device, dtype=mask_dtype)

        if self.requires_frustrum:
            # NOTE(review): slicing off the last trailing entry presumably aligns
            # valid_bev with the (B, H, W) label grid — confirm against the model.
            frustrum_mask = pred["valid_bev"][..., :-1] != 0
            loss_mask = loss_mask * frustrum_mask.float()

        if self.requires_flood_mask:
            flood_mask = data["flood_masks"] == 0
            loss_mask = loss_mask * flood_mask.float()

        if self.xent_weight > 0.:
            logits = pred['logits'].permute(0, 2, 3, 1)  # -> (B, H, W, C)
            # BCEWithLogitsLoss needs floating targets.
            targets = labels.float()
            if self.label_smoothing > 0.:
                targets = (targets * (1 - self.label_smoothing)
                           + self.label_smoothing / self.num_classes)

            xent_loss = self.xent_loss(logits, targets)

            if self.focal:
                # For hard 0/1 targets exp(-BCE) is the predicted probability
                # of the true label, so (1 - pt) down-weights easy elements.
                pt = torch.exp(-xent_loss)
                xent_loss = (1 - pt) ** self.focal_gamma * xent_loss

            xent_loss = xent_loss * loss_mask.unsqueeze(-1)
            xent_loss = xent_loss.sum() / (loss_mask.sum() + 1e-5)
            loss['cross_entropy'] = xent_loss
            loss['total'] = xent_loss * self.xent_weight

        if self.dice_weight > 0.:
            probs = pred['output'].permute(0, 2, 3, 1)  # -> (B, H, W, C)
            dloss = self.dice_loss(
                probs, labels, loss_mask, self.class_weights)
            loss['dice'] = dloss

            weighted_dice = dloss * self.dice_weight
            if 'total' in loss:
                loss['total'] = loss['total'] + weighted_dice
            else:
                loss['total'] = weighted_dice

        return loss
| |
|