|
|
import torch |
|
|
import torch.nn as nn |
|
|
from .layers import Decoder |
|
|
import torch.nn.functional as F |
|
|
from bert.modeling_bert import BertModel |
|
|
|
|
|
def dice_loss(inputs, targets):
    """
    Soft DICE loss between predicted logits and binary targets.

    The logits are passed through a sigmoid, both tensors are flattened
    from dimension 1 onward (dimension 0 is kept as the per-example axis),
    and one DICE score is computed per example before averaging.

    Args:
        inputs: A float tensor of raw (pre-sigmoid) predictions.
        targets: A float tensor with the same shape as inputs holding the
                 binary label of every element (0 = negative class,
                 1 = positive class).
    Returns:
        A scalar tensor: the mean of ``1 - dice`` over dimension 0.
    """
    probs = inputs.sigmoid().flatten(1)
    labels = targets.flatten(1)

    overlap = 2 * (probs * labels).sum(1)
    total = probs.sum(-1) + labels.sum(-1)

    # The +1 on both sides smooths the ratio so an empty prediction with an
    # empty target yields a loss of 0 rather than a 0/0 division.
    per_example = 1 - (overlap + 1) / (total + 1)
    return per_example.mean()
|
|
|
|
|
def sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2):
    """
    Sigmoid focal loss from RetinaNet: https://arxiv.org/abs/1708.02002.

    Args:
        inputs: A float tensor of raw (pre-sigmoid) predictions.
        targets: A float tensor with the same shape as inputs holding the
                 binary label of every element (0 = negative class,
                 1 = positive class).
        alpha: Weighting factor that balances positive vs negative
               elements. Default = 0.25; any negative value disables
               the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) that balances
               easy vs hard elements. With gamma = 0 the factor is 1.
    Returns:
        A scalar tensor: the mean of the element-wise focal loss.
    """
    probabilities = inputs.sigmoid()
    bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

    # Probability the model assigns to the correct class of each element.
    p_correct = probabilities * targets + (1 - probabilities) * (1 - targets)
    focal = bce * ((1 - p_correct) ** gamma)

    # Negative alpha: skip the class-balance weighting.
    if alpha < 0:
        return focal.mean()

    class_weight = alpha * targets + (1 - alpha) * (1 - targets)
    return (class_weight * focal).mean()
|
|
|
|
|
|
|
|
class CGFormer(nn.Module):
    """Segmentation model combining a visual backbone, a BERT text encoder and a decoder.

    ``forward`` encodes the text, hands the text features to both the
    backbone and the decoder, and upsamples the decoder prediction to the
    spatial size of the input ``x``. In training mode it also returns the
    loss, which adds weighted dice terms on the decoder's auxiliary maps to
    the dice + cross-entropy loss of the final prediction.
    """

    # Weights applied, in order, to the dice loss of each auxiliary map
    # returned by the decoder (paired with the maps via ``zip``, so extra
    # maps or extra weights are ignored).
    AUX_LOSS_WEIGHTS = (0.001, 0.01, 0.1)

    def __init__(self, backbone, args):
        """Build the model.

        Args:
            backbone: Module called as ``backbone(x, l_feats, l_mask)``; it
                must return four feature maps.
            args: Configuration object; passed whole to ``Decoder`` and read
                for ``args.bert`` (the pretrained BERT name or path).
        """
        super().__init__()
        self.backbone = backbone
        self.decoder = Decoder(args)
        self.text_encoder = BertModel.from_pretrained(args.bert)
        # forward() only reads element [0] of the encoder output, so the
        # pooler is dropped.
        self.text_encoder.pooler = None

    @staticmethod
    def _mask_at_size(mask, size):
        """Return ``mask`` (detached) at spatial ``size``, nearest-resampling only if needed."""
        if tuple(mask.shape[-2:]) == tuple(size):
            return mask.detach()
        return F.interpolate(mask, size, mode='nearest').detach()

    def forward(self, x, text, l_mask, mask=None):
        """Run the model and, in training mode, compute the loss.

        Args:
            x: Image tensor; its last two dimensions define the output size.
            text: Token tensor fed to the text encoder.
            l_mask: Attention mask for ``text``; a trailing singleton
                dimension is appended before it is passed on.
            mask: Ground-truth mask, required in training mode. A channel
                dimension is inserted at position 1 and it is cast to float.

        Returns:
            Training mode: ``(pred.detach(), mask, loss)`` where ``mask`` is
            the float mask with the added channel dimension.
            Eval mode: ``(pred.detach(), maps)``.

        Raises:
            ValueError: If the module is in training mode and ``mask`` is None.
        """
        if self.training and mask is None:
            raise ValueError("mask must be provided when the model is in training mode")

        input_shape = x.shape[-2:]

        # Element [0] of the encoder output is used as the per-token features
        # (presumably the last hidden state -- confirm against BertModel).
        l_feats = self.text_encoder(text, attention_mask=l_mask)[0]
        l_feats = l_feats.permute(0, 2, 1)  # swap the last two axes
        l_mask = l_mask.unsqueeze(dim=-1)

        features = self.backbone(x, l_feats, l_mask)
        x_c1, x_c2, x_c3, x_c4 = features
        # The decoder receives the backbone features in reverse order.
        pred, maps = self.decoder([x_c4, x_c3, x_c2, x_c1], l_feats, l_mask)
        pred = F.interpolate(pred, input_shape, mode='bilinear', align_corners=True)

        if not self.training:
            return pred.detach(), maps

        loss = 0.
        mask = mask.unsqueeze(1).float()
        for m, lam in zip(maps, self.AUX_LOSS_WEIGHTS):
            m = m[:, 1].unsqueeze(1)  # channel 1 of the map, kept as a 1-channel tensor
            # Always derive the target from the full-resolution mask. The
            # previous code only assigned it when the sizes differed, so a map
            # already at mask resolution hit an unbound (or stale) variable.
            target = self._mask_at_size(mask, m.shape[-2:])
            loss += dice_loss(m, target) * lam

        # alpha=-1 disables class weighting and gamma=0 makes the modulating
        # factor 1, so the focal term here is plain mean binary cross-entropy.
        loss += dice_loss(pred, mask) + sigmoid_focal_loss(pred, mask, alpha=-1, gamma=0)
        return pred.detach(), mask, loss