InPeerReview's picture
Upload 161 files
226675b verified
import numpy as np
import itertools
from typing import Any, Dict, List, Tuple, Union
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/detr.py
"""
MaskFormer criterion.
"""
import torch
import torch.nn.functional as F
from torch import nn
from rscd.losses.loss_util.criterion import SetCriterion
from rscd.losses.loss_util.matcher import HungarianMatcher
class Mask2formerLoss(nn.Module):
def __init__(self, class_weight=2.0,
dice_weight=5.0,
mask_weight=5.0,
no_object_weight=0.1,
dec_layers = 10,
num_classes = 1,
device="cuda:0"):
super(Mask2formerLoss, self).__init__()
self.device = device
self.class_weight = class_weight
self.dice_weight = dice_weight
self.mask_weight = mask_weight
self.no_object_weight = no_object_weight
self.dec_layers = dec_layers
self.num_classes = num_classes
def forward(self, preds, target):
# building criterion
matcher = HungarianMatcher(
cost_class=self.class_weight,
cost_mask=self.mask_weight,
cost_dice=self.dice_weight,
num_points=12544,
)
weight_dict = {"loss_ce": self.class_weight, "loss_mask": self.mask_weight, "loss_dice": self.dice_weight}
aux_weight_dict = {}
for i in range(self.dec_layers - 1):
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
weight_dict.update(aux_weight_dict)
losses = ["labels", "masks"]
criterion = SetCriterion(
num_classes=self.num_classes,
matcher=matcher,
weight_dict=weight_dict,
eos_coef=self.no_object_weight,
losses=losses,
num_points=12544,
oversample_ratio=3.0,
importance_sample_ratio=0.75,
device=torch.device(self.device)
)
preds["pred_masks"]= F.interpolate(
preds["pred_masks"],
scale_factor=(4, 4),
mode="bilinear",
align_corners=False,
)
for v in preds['aux_outputs']:
v['pred_masks'] = F.interpolate(
v["pred_masks"],
scale_factor=(4, 4),
mode="bilinear",
align_corners=False,
)
losses = criterion(preds, target)
weight_dict = criterion.weight_dict
loss_ce = 0.0
loss_dice = 0.0
loss_mask = 0.0
for k in list(losses.keys()):
if k in weight_dict:
losses[k] *= criterion.weight_dict[k]
if '_ce' in k:
loss_ce += losses[k]
elif '_dice' in k:
loss_dice += losses[k]
elif '_mask' in k:
loss_mask += losses[k]
else:
# remove this loss if not specified in `weight_dict`
losses.pop(k)
loss = loss_ce + loss_dice + loss_mask
return loss