| from typing import Any, Dict, List, Tuple |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import Tensor, nn |
| from torch.nn import BCEWithLogitsLoss |
|
|
| from yolo.config.config import Config, LossConfig |
| from yolo.utils.bounding_box_utils import BoxMatcher, Vec2Box, calculate_iou |
| from yolo.utils.logger import logger |
|
|
|
|
class BCELoss(nn.Module):
    """Classification loss: element-wise BCE-with-logits, summed, then divided by ``cls_norm``."""

    def __init__(self) -> None:
        super().__init__()
        # Keep per-element losses; the reduction (sum / cls_norm) is done by hand in forward().
        self.bce = BCEWithLogitsLoss(reduction="none")

    def forward(self, predicts_cls: Tensor, targets_cls: Tensor, cls_norm: Tensor) -> Any:
        """Return ``sum(BCEWithLogits(predicts_cls, targets_cls)) / cls_norm``."""
        elementwise = self.bce(predicts_cls, targets_cls)
        total = elementwise.sum()
        return total / cls_norm
|
|
|
|
class BoxLoss(nn.Module):
    """CIoU box-regression loss over the anchors flagged in ``valid_masks``."""

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, predicts_bbox: Tensor, targets_bbox: Tensor, valid_masks: Tensor, box_norm: Tensor, cls_norm: Tensor
    ) -> Any:
        """Return ``sum((1 - ciou) * box_norm) / cls_norm`` over the selected anchors.

        ``valid_masks`` is broadcast across the last (4-coordinate) dimension so that the
        boolean indexing picks whole boxes; both gathered tensors are reshaped to (N, 4).
        """
        coord_mask = valid_masks.unsqueeze(-1).expand(-1, -1, 4)
        pred_boxes = predicts_bbox[coord_mask].view(-1, 4)
        true_boxes = targets_bbox[coord_mask].view(-1, 4)

        # NOTE(review): only the diagonal (prediction i vs. target i) is used. If calculate_iou
        # returns a full pairwise (N, N) matrix for 2-D inputs, this costs O(N^2) memory — confirm
        # and consider an element-wise IoU path.
        matched_iou = calculate_iou(pred_boxes, true_boxes, "ciou").diag()
        weighted = (1.0 - matched_iou) * box_norm
        return weighted.sum() / cls_norm
|
|
|
|
class DFLoss(nn.Module):
    """Distribution focal loss: two-bin interpolated cross-entropy on the discretized edge distances."""

    def __init__(self, vec2box: Vec2Box, reg_max: int) -> None:
        super().__init__()
        # Anchor grid divided by the per-anchor scaler, with a leading batch axis added.
        # NOTE(review): captured once here; if vec2box's anchor grid can change after construction
        # this copy would go stale — confirm.
        scaled_grid = vec2box.anchor_grid / vec2box.scaler[:, None]
        self.anchors_norm = scaled_grid.unsqueeze(0)
        self.reg_max = reg_max

    def forward(
        self, predicts_anc: Tensor, targets_bbox: Tensor, valid_masks: Tensor, box_norm: Tensor, cls_norm: Tensor
    ) -> Any:
        """Return ``sum(mean_over_4_edges(dfl) * box_norm) / cls_norm`` over the selected anchors."""
        edge_mask = valid_masks.unsqueeze(-1).expand(-1, -1, 4)

        # Presumably targets_bbox's last dim is (x1, y1, x2, y2) — TODO confirm. The distances are
        # anchor-to-first-corner and second-corner-to-anchor, clamped so the upper bin stays < reg_max.
        first_corner, second_corner = targets_bbox.chunk(2, -1)
        anchors = self.anchors_norm
        distances = torch.cat((anchors - first_corner, second_corner - anchors), -1)
        distances = distances.clamp(min=0, max=self.reg_max - 1.01)

        flat_targets = distances[edge_mask].view(-1)
        flat_logits = predicts_anc[edge_mask].view(-1, self.reg_max)

        # Split each continuous target between its two neighbouring integer bins.
        lower_bin = flat_targets.floor()
        upper_bin = lower_bin + 1
        lower_weight = upper_bin - flat_targets
        upper_weight = flat_targets - lower_bin

        lower_ce = F.cross_entropy(flat_logits, lower_bin.long(), reduction="none")
        upper_ce = F.cross_entropy(flat_logits, upper_bin.long(), reduction="none")
        per_edge = lower_ce * lower_weight + upper_ce * upper_weight
        per_box = per_edge.view(-1, 4).mean(-1)
        return (per_box * box_norm).sum() / cls_norm
|
|
|
|
class YOLOLoss:
    """Loss for one YOLO prediction head.

    Assigns targets to anchors with ``BoxMatcher``, then returns the CIoU box loss,
    the DFL loss and the BCE classification loss.
    """

    def __init__(self, loss_cfg: LossConfig, vec2box: Vec2Box, class_num: int = 80, reg_max: int = 16) -> None:
        self.class_num = class_num
        self.vec2box = vec2box

        self.cls = BCELoss()
        self.dfl = DFLoss(vec2box, reg_max)
        self.iou = BoxLoss()

        self.matcher = BoxMatcher(loss_cfg.matcher, self.class_num, vec2box, reg_max)

    def separate_anchor(self, anchors):
        """
        Separate aligned targets into class scores and bounding boxes.

        The last dimension of ``anchors`` is split into ``class_num`` class channels
        followed by 4 box channels. The box channels are divided by ``vec2box.scaler``,
        broadcast along the second dimension.

        Returns:
            (anchors_cls, anchors_box)
        """
        anchors_cls, anchors_box = torch.split(anchors, (self.class_num, 4), dim=-1)
        anchors_box = anchors_box / self.vec2box.scaler[None, :, None]
        return anchors_cls, anchors_box

    def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Compute ``(loss_iou, loss_dfl, loss_cls)`` for one head.

        ``predicts`` is unpacked as ``(predicts_cls, predicts_anc, predicts_box)``.
        """
        predicts_cls, predicts_anc, predicts_box = predicts
        # The matcher sees detached predictions, so the assignment is not part of the autograd graph.
        align_targets, valid_masks = self.matcher(targets, (predicts_cls.detach(), predicts_box.detach()))

        targets_cls, targets_bbox = self.separate_anchor(align_targets)
        predicts_box = predicts_box / self.vec2box.scaler[None, :, None]

        # Normalizer = total target class mass, floored at 1. Tensor.clamp keeps this on-device and
        # always a Tensor; Python max() would call bool() on the tensor (a host sync on CUDA) and
        # could return the plain int 1.
        cls_norm = targets_cls.sum().clamp(min=1)
        # Per-anchor weight for the box/DFL terms: class mass of each valid anchor.
        box_norm = targets_cls.sum(-1)[valid_masks]

        loss_cls = self.cls(predicts_cls, targets_cls, cls_norm)
        loss_iou = self.iou(predicts_box, targets_bbox, valid_masks, box_norm, cls_norm)
        loss_dfl = self.dfl(predicts_anc, targets_bbox, valid_masks, box_norm, cls_norm)

        return loss_iou, loss_dfl, loss_cls
|
|
|
|
class DualLoss:
    """Weighted sum of the auxiliary-head and main-head YOLO losses."""

    def __init__(self, cfg: Config, vec2box) -> None:
        loss_cfg = cfg.task.loss
        self.loss = YOLOLoss(loss_cfg, vec2box, class_num=cfg.dataset.class_num, reg_max=cfg.model.anchor.reg_max)

        # Multiplier applied to every auxiliary-head term.
        self.aux_rate = loss_cfg.aux

        # Per-objective weights, looked up by loss class name.
        objective = loss_cfg.objective
        self.iou_rate = objective["BoxLoss"]
        self.dfl_rate = objective["DFLoss"]
        self.cls_rate = objective["BCELoss"]

    def __call__(
        self, aux_predicts: List[Tensor], main_predicts: List[Tensor], targets: Tensor
    ) -> Tuple[Tensor, Dict[str, float]]:
        """Return ``(total_loss, loss_dict)``.

        Each objective contributes ``rate * (aux * aux_rate + main)``; ``loss_dict`` holds the
        detached scalar value of each weighted objective.
        """
        aux_iou, aux_dfl, aux_cls = self.loss(aux_predicts, targets)
        main_iou, main_dfl, main_cls = self.loss(main_predicts, targets)

        weighted_terms = []
        for rate, aux_term, main_term in (
            (self.iou_rate, aux_iou, main_iou),
            (self.dfl_rate, aux_dfl, main_dfl),
            (self.cls_rate, aux_cls, main_cls),
        ):
            weighted_terms.append(rate * (aux_term * self.aux_rate + main_term))

        loss_dict = {}
        for label, term in zip(["Box", "DFL", "BCE"], weighted_terms):
            loss_dict[f"Loss/{label}Loss"] = term.detach().item()
        return sum(weighted_terms), loss_dict
|
|
|
|
def create_loss_function(cfg: Config, vec2box) -> DualLoss:
    """Build the ``DualLoss`` criterion from ``cfg`` and ``vec2box`` and log that it was created."""
    criterion = DualLoss(cfg, vec2box)
    logger.info(":white_check_mark: Success load loss function")
    return criterion
|
|