# File size: 3,268 Bytes
# 226675b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import numpy as np
import itertools
from typing import Any, Dict, List, Tuple, Union



# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/detr.py
"""

MaskFormer criterion.

"""
import torch
import torch.nn.functional as F
from torch import nn

from rscd.losses.loss_util.criterion import SetCriterion
from rscd.losses.loss_util.matcher import HungarianMatcher

class Mask2formerLoss(nn.Module):
    """Weighted sum of the classification, dice and mask losses of a set criterion.

    On every call a ``HungarianMatcher`` and a ``SetCriterion`` are built from
    the configured weights, the predicted masks are upsampled by
    ``MASK_SCALE_FACTOR`` and the criterion output is reduced to one value.

    Args:
        class_weight: Used as the matcher's ``cost_class`` and as the weight
            of every ``loss_ce*`` term.
        dice_weight: Used as the matcher's ``cost_dice`` and as the weight of
            every ``loss_dice*`` term.
        mask_weight: Used as the matcher's ``cost_mask`` and as the weight of
            every ``loss_mask*`` term.
        no_object_weight: Forwarded to ``SetCriterion`` as ``eos_coef``.
        dec_layers: ``dec_layers - 1`` suffixed copies (``_0``, ``_1``, ...) of
            the three base weights are added for the auxiliary outputs.
        num_classes: Forwarded to ``SetCriterion`` as ``num_classes``.
        device: Wrapped in ``torch.device`` and forwarded to ``SetCriterion``.
    """

    # Passed as ``num_points`` to both the matcher and the criterion.
    NUM_POINTS = 12544
    # Passed to the criterion as ``oversample_ratio``.
    OVERSAMPLE_RATIO = 3.0
    # Passed to the criterion as ``importance_sample_ratio``.
    IMPORTANCE_SAMPLE_RATIO = 0.75
    # Spatial upsampling applied to every ``pred_masks`` tensor before the loss.
    MASK_SCALE_FACTOR = (4, 4)

    def __init__(self, class_weight=2.0,
                 dice_weight=5.0,
                 mask_weight=5.0,
                 no_object_weight=0.1,
                 dec_layers = 10,
                 num_classes = 1,
                 device="cuda:0"):
        super(Mask2formerLoss, self).__init__()
        self.device = device
        self.class_weight = class_weight
        self.dice_weight = dice_weight
        self.mask_weight = mask_weight
        self.no_object_weight = no_object_weight
        self.dec_layers = dec_layers
        self.num_classes = num_classes

    def _build_weight_dict(self):
        """Return the loss weights for the final output plus each auxiliary output."""
        base = {
            "loss_ce": self.class_weight,
            "loss_mask": self.mask_weight,
            "loss_dice": self.dice_weight,
        }
        weights = dict(base)
        for layer in range(self.dec_layers - 1):
            weights.update({f"{key}_{layer}": value for key, value in base.items()})
        return weights

    def _build_criterion(self):
        """Create the matcher and the ``SetCriterion`` used for one forward pass."""
        # NOTE(review): both objects are rebuilt on every call, as in the
        # original code; caching them in __init__ would register the criterion
        # as a submodule and change this module's state_dict, so it is left as is.
        matcher = HungarianMatcher(
            cost_class=self.class_weight,
            cost_mask=self.mask_weight,
            cost_dice=self.dice_weight,
            num_points=self.NUM_POINTS,
        )
        return SetCriterion(
            num_classes=self.num_classes,
            matcher=matcher,
            weight_dict=self._build_weight_dict(),
            eos_coef=self.no_object_weight,
            losses=["labels", "masks"],
            num_points=self.NUM_POINTS,
            oversample_ratio=self.OVERSAMPLE_RATIO,
            importance_sample_ratio=self.IMPORTANCE_SAMPLE_RATIO,
            device=torch.device(self.device),
        )

    def _upsample_masks(self, masks):
        """Bilinearly upsample a mask tensor by ``MASK_SCALE_FACTOR``."""
        return F.interpolate(
            masks,
            scale_factor=self.MASK_SCALE_FACTOR,
            mode="bilinear",
            align_corners=False,
        )

    def _upsample_predictions(self, preds):
        """Return a copy of ``preds`` whose ``pred_masks`` entries are upsampled.

        The caller's dict (and the dicts inside ``aux_outputs``) are left
        untouched, so the same predictions can be reused after the loss is
        computed without being upsampled a second time.
        """
        upsampled = dict(preds)
        upsampled["pred_masks"] = self._upsample_masks(preds["pred_masks"])
        # Models without auxiliary heads may omit this key.
        # NOTE(review): assumes SetCriterion itself tolerates a missing
        # "aux_outputs" entry -- confirm against its implementation.
        if "aux_outputs" in preds:
            upsampled["aux_outputs"] = [
                {**aux, "pred_masks": self._upsample_masks(aux["pred_masks"])}
                for aux in preds["aux_outputs"]
            ]
        return upsampled

    def forward(self, preds, target):
        """Compute the total weighted loss.

        Args:
            preds: Mapping with a ``"pred_masks"`` tensor and, optionally, an
                ``"aux_outputs"`` list of mappings that each hold their own
                ``"pred_masks"``. Mask tensors must be 4-D for the bilinear
                upsampling (presumably (batch, queries, H, W) -- TODO confirm).
                The mapping is not modified.
            target: Ground truth, passed unchanged to ``SetCriterion``.

        Returns:
            ``loss_ce + loss_dice + loss_mask`` where each term is the sum of
            the weighted criterion outputs whose key contains ``_ce``,
            ``_dice`` or ``_mask`` respectively. Criterion outputs without an
            entry in the weight dict are ignored.
        """
        criterion = self._build_criterion()
        losses = criterion(self._upsample_predictions(preds), target)
        weight_dict = criterion.weight_dict

        loss_ce = 0.0
        loss_dice = 0.0
        loss_mask = 0.0
        for name, value in losses.items():
            if name not in weight_dict:
                # Not configured in the weight dict: does not contribute.
                continue
            weighted = value * weight_dict[name]
            if '_ce' in name:
                loss_ce += weighted
            elif '_dice' in name:
                loss_dice += weighted
            elif '_mask' in name:
                loss_mask += weighted
        return loss_ce + loss_dice + loss_mask