# MRaCL / CGFormer / model / segmenter.py
# Uploaded by dianecy
# Commit message: Upload folder using huggingface_hub
# Commit: ea1014e (verified)
import torch
import torch.nn as nn
from .layers import Decoder
import torch.nn.functional as F
from bert.modeling_bert import BertModel
def dice_loss(inputs, targets):
    """
    Soft DICE loss between predicted logits and binary masks, averaged over
    the leading (per-example) dimension.

    Args:
        inputs: A float tensor of raw logits with at least two dimensions.
                A sigmoid is applied inside this function.
        targets: A float tensor with the same shape as inputs holding the
                 binary label of every element (0 = negative, 1 = positive).
    Returns:
        Scalar tensor: the mean over examples of
        1 - (2 * sum(p * t) + 1) / (sum(p) + sum(t) + 1).
    """
    # Collapse everything after the first dimension so each row is one example.
    probs = inputs.sigmoid().flatten(1)
    labels = targets.flatten(1)

    overlap = (probs * labels).sum(dim=1)
    total = probs.sum(dim=-1) + labels.sum(dim=-1)

    # The +1 terms smooth the ratio and keep it finite when both the
    # prediction and the target are (almost) empty.
    dice = (2 * overlap + 1) / (total + 1)
    return (1 - dice).mean()
def sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2):
    """
    Focal loss on logits, as introduced for dense detection in RetinaNet
    (https://arxiv.org/abs/1708.02002), averaged over all elements.

    Args:
        inputs: A float tensor of arbitrary shape holding raw logits.
        targets: A float tensor with the same shape as inputs holding the
                 binary label of every element (0 = negative, 1 = positive).
        alpha: Weight given to positive elements; negatives receive
               (1 - alpha). Default = 0.25. Any negative value disables
               the class weighting.
        gamma: Exponent of the modulating factor (1 - p_t) that down-weights
               easy examples. With gamma = 0 the factor is 1 everywhere.
    Returns:
        Scalar loss tensor (mean over every element).
    """
    probs = inputs.sigmoid()
    bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

    # Probability the model assigns to the correct class of each element.
    true_class_prob = probs * targets + (1 - probs) * (1 - targets)
    focal = bce * ((1 - true_class_prob) ** gamma)

    if alpha < 0:
        # Negative alpha: no positive/negative re-balancing.
        return focal.mean()

    class_weight = alpha * targets + (1 - alpha) * (1 - targets)
    return (class_weight * focal).mean()
class CGFormer(nn.Module):
    """Segmentation model combining a BERT text encoder, a language-conditioned
    visual backbone and a decoder.

    The backbone is supplied by the caller; the decoder and the text encoder
    are built from ``args`` (``args.bert`` names the pretrained BERT weights).
    """

    # Weights of the DICE loss applied to the decoder's auxiliary maps, paired
    # with ``maps`` in the order the decoder returns them.
    _AUX_LOSS_WEIGHTS = (0.001, 0.01, 0.1)

    def __init__(self, backbone, args):
        """Store the backbone and build the decoder and text encoder.

        Args:
            backbone: Callable module invoked as ``backbone(x, l_feats, l_mask)``
                and expected to return four feature maps.
            args: Configuration object passed to ``Decoder``; ``args.bert`` is
                passed to ``BertModel.from_pretrained``.
        """
        super().__init__()
        self.backbone = backbone
        self.decoder = Decoder(args)
        self.text_encoder = BertModel.from_pretrained(args.bert)
        # forward() only reads the first output of the encoder, so the pooler
        # is discarded (presumably to avoid carrying unused parameters).
        self.text_encoder.pooler = None

    def forward(self, x, text, l_mask, mask=None):
        """Predict a segmentation map and, in training mode, its loss.

        Args:
            x: Image batch; its last two dimensions define the output size.
            text: Token ids fed to the text encoder.
            l_mask: Attention mask for ``text``; presumably (batch, N_l).
            mask: Ground-truth mask without a channel dimension (one is
                inserted at dim 1 here). Required in training mode.

        Returns:
            Training mode: ``(pred.detach(), mask, loss)`` where ``mask`` is
            the float ground truth with the added channel dimension.
            Eval mode: ``(pred.detach(), maps)`` where ``maps`` are the
            auxiliary maps returned by the decoder.

        Raises:
            ValueError: If the module is in training mode and ``mask`` is None.
        """
        if self.training and mask is None:
            # Fail before the (expensive) forward pass rather than with an
            # AttributeError on ``None`` afterwards.
            raise ValueError("`mask` is required when the model is in training mode")

        input_shape = x.shape[-2:]

        # First encoder output: per-token features, presumably (B, N_l, 768).
        l_feats = self.text_encoder(text, attention_mask=l_mask)[0]
        l_feats = l_feats.permute(0, 2, 1)  # (B, 768, N_l) to make Conv1d happy
        l_mask = l_mask.unsqueeze(dim=-1)  # (batch, N_l, 1)

        x_c1, x_c2, x_c3, x_c4 = self.backbone(x, l_feats, l_mask)
        # The decoder receives the backbone features in reverse order.
        pred, maps = self.decoder([x_c4, x_c3, x_c2, x_c1], l_feats, l_mask)
        pred = F.interpolate(pred, input_shape, mode='bilinear', align_corners=True)

        if not self.training:
            return pred.detach(), maps

        mask = mask.unsqueeze(1).float()
        loss = 0.
        for aux_map, weight in zip(maps, self._AUX_LOSS_WEIGHTS):
            # Channel 1 of each auxiliary map is supervised against the mask.
            aux_map = aux_map[:, 1].unsqueeze(1)
            # Resize the ground truth only when the resolutions differ; when
            # they already match, use the mask itself. (Previously the target
            # was left unbound or stale from the prior iteration in that case.)
            aux_target = mask
            if aux_map.shape[-2:] != mask.shape[-2:]:
                aux_target = F.interpolate(mask, aux_map.shape[-2:], mode='nearest')
            loss += dice_loss(aux_map, aux_target.detach()) * weight

        # alpha=-1, gamma=0 turns the focal loss into a plain mean BCE.
        loss += dice_loss(pred, mask) + sigmoid_focal_loss(pred, mask, alpha=-1, gamma=0)
        return pred.detach(), mask, loss