|
|
import torch |
|
|
import torch.nn as nn |
|
|
from .layers import Decoder |
|
|
from .layers_v2 import Decoder_v2 |
|
|
import torch.nn.functional as F |
|
|
from bert.modeling_bert import BertModel |
|
|
|
|
|
def dice_loss(inputs, targets):
    """
    Soft DICE loss between predicted logits and binary targets.

    The logits are passed through a sigmoid, then both tensors are flattened
    from dimension 1 onwards so that one DICE score is produced per entry of
    the leading dimension. A smoothing constant of 1 is added to both the
    numerator and the denominator, which keeps the ratio finite when both
    prediction and target are empty.

    Args:
        inputs: A float tensor of raw (pre-sigmoid) predictions.
        targets: A float tensor with the same shape as ``inputs`` holding the
                 binary label of each element (0 negative, 1 positive).

    Returns:
        Scalar tensor: ``1 - dice`` averaged over the leading dimension.
    """
    probs = inputs.sigmoid().flatten(1)
    labels = targets.flatten(1)

    overlap = (probs * labels).sum(1)
    total_mass = probs.sum(-1) + labels.sum(-1)

    per_sample = 1 - (2 * overlap + 1) / (total_mass + 1)
    return per_sample.mean()
|
|
|
|
|
def sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2):
    """
    Focal loss on sigmoid logits (RetinaNet, https://arxiv.org/abs/1708.02002).

    Args:
        inputs: A float tensor of raw (pre-sigmoid) predictions.
        targets: A float tensor with the same shape as ``inputs`` holding the
                 binary label of each element (0 negative, 1 positive).
        alpha: Weight applied to positive elements; negatives get
               ``1 - alpha``. Default = 0.25. Any negative value disables
               the class weighting.
        gamma: Exponent of the modulating factor ``(1 - p_t)`` that
               down-weights easy examples. With ``gamma=0`` and a negative
               ``alpha`` the result is plain binary cross-entropy.

    Returns:
        Scalar tensor: the loss averaged over every element.
    """
    probabilities = inputs.sigmoid()
    bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

    # p_t is the probability the model assigns to the true class.
    p_t = probabilities * targets + (1 - probabilities) * (1 - targets)
    focal = bce * ((1 - p_t) ** gamma)

    if alpha >= 0:
        class_weight = alpha * targets + (1 - alpha) * (1 - targets)
        focal = class_weight * focal

    return focal.mean()
|
|
|
|
|
|
|
|
class CGFormer_RCC_sbert(nn.Module):
    """Segmentation model with an auxiliary contrastive ("metric") loss.

    The module wires together a visual ``backbone`` that is called with the
    image and the text features, a BERT text encoder, and a decoder
    (``Decoder_v2`` when ``args.mixup_lasttwo`` is truthy, ``Decoder``
    otherwise). In training mode ``forward`` returns a combined loss made of
    dice/BCE segmentation terms plus an angular contrastive term computed on a
    pooled feature vector per sample.
    """

    def __init__(self, backbone, args):
        """Build the sub-modules.

        Args:
            backbone: Visual backbone; called as ``backbone(x, l_feats, l_mask)``
                and expected to return four feature maps.
            args: Configuration object. Attributes read here:
                ``mixup_lasttwo``, ``bert``, ``use_projections``,
                ``filter_threshold``. ``forward`` and the loss helpers also
                read ``metric_loss_weight``, ``loss_option``, ``margin_value``
                and ``temperature``.
        """
        super().__init__()
        self.backbone = backbone
        self.mixup_lasttwo = args.mixup_lasttwo
        if self.mixup_lasttwo:
            self.decoder = Decoder_v2(args)
        else:
            self.decoder = Decoder(args)

        self.text_encoder = BertModel.from_pretrained(args.bert)
        # Only element [0] of the encoder output is consumed in forward(),
        # so the pooler head is dropped.
        self.text_encoder.pooler = None

        self.args = args
        if self.args.use_projections:
            # Projects the concatenated pooled x_c3/x_c4 features.
            # NOTE(review): 1536 presumably equals C(x_c3) + C(x_c4) of the
            # backbone in use -- confirm against the backbone config.
            self.projection_1 = nn.Linear(1536, 1024, bias=True)
        else:
            self.projection_1 = None

        self.use_projections = args.use_projections
        self.filter_th = args.filter_threshold

    def forward(self, x, text, l_mask, mask=None, hp_bert_embs=None):
        """Run the model; in training mode also compute the loss.

        Args:
            x: Image batch; its last two dimensions define the output size.
            text: Token ids passed to the BERT encoder.
            l_mask: Attention mask for ``text``.
            mask: Ground-truth segmentation mask. Required in training mode.
            hp_bert_embs: Optional 2-D tensor of sentence embeddings, one row
                per sample. Pairs whose cosine similarity exceeds
                ``self.filter_th`` are excluded from the negatives of the
                metric loss. When ``None`` or empty the metric loss is skipped.

        Returns:
            Training: ``(pred.detach(), mask, loss)`` where ``mask`` has been
            unsqueezed at dim 1 and cast to float.
            Evaluation: ``(pred.detach(), maps)``.

        Raises:
            ValueError: If called in training mode without ``mask``.
        """
        if self.training and mask is None:
            raise ValueError("mask is required in training mode")

        rows_to_filter, cols_to_filter = None, None
        has_hp_embs = hp_bert_embs is not None and hp_bert_embs.numel() > 0

        if self.training and has_hp_embs:
            # F.normalize guards against zero-norm rows (plain division would
            # produce NaNs).
            normed_embs = F.normalize(hp_bert_embs, dim=-1)
            cosine_sim = torch.mm(normed_embs, normed_embs.T)
            rows_to_filter, cols_to_filter = torch.where(cosine_sim > self.filter_th)

        input_shape = x.shape[-2:]
        l_feats = self.text_encoder(text, attention_mask=l_mask)[0]
        l_feats = l_feats.permute(0, 2, 1)
        l_mask = l_mask.unsqueeze(dim=-1)

        features = self.backbone(x, l_feats, l_mask)
        x_c1, x_c2, x_c3, x_c4 = features

        metric_tensor = None
        if self.mixup_lasttwo:
            pred, maps, fq_fuse = self.decoder([x_c4, x_c3, x_c2, x_c1], l_feats, l_mask)
            metric_tensor = F.adaptive_avg_pool2d(fq_fuse, (1, 1)).view(fq_fuse.shape[0], fq_fuse.shape[1])
        else:
            pred, maps = self.decoder([x_c4, x_c3, x_c2, x_c1], l_feats, l_mask)
            if self.training:
                if self.use_projections:
                    x_c3_proj = F.adaptive_avg_pool2d(x_c3, (1, 1)).view(x_c3.size(0), -1)
                    x_c4_proj = F.adaptive_avg_pool2d(x_c4, (1, 1)).view(x_c4.size(0), -1)
                    metric_tensor = torch.cat((x_c3_proj, x_c4_proj), dim=1)
                    metric_tensor = self.projection_1(metric_tensor)
                else:
                    metric_tensor = F.adaptive_avg_pool2d(x_c4, (1, 1)).view(x_c4.size(0), -1)

        pred = F.interpolate(pred, input_shape, mode='bilinear', align_corners=True)

        if not self.training:
            return pred.detach(), maps

        loss = 0.
        mask = mask.unsqueeze(1).float()
        # Auxiliary dice losses on the intermediate maps, with increasing weight.
        for m, lam in zip(maps, [0.001, 0.01, 0.1]):
            m = m[:, 1].unsqueeze(1)
            if m.shape[-2:] != mask.shape[-2:]:
                target = F.interpolate(mask, m.shape[-2:], mode='nearest').detach()
            else:
                # Same resolution: use the mask directly instead of relying on
                # a target left over from a previous iteration (or none at all).
                target = mask
            loss += dice_loss(m, target) * lam
        # alpha=-1, gamma=0 turns the focal loss into plain mean BCE.
        loss += dice_loss(pred, mask) + sigmoid_focal_loss(pred, mask, alpha=-1, gamma=0)

        metric_loss = 0.
        if has_hp_embs:
            metric_loss = self.compute_metric_loss(metric_tensor, rows_to_filter, cols_to_filter, self.args)
        loss += metric_loss * self.args.metric_loss_weight

        return pred.detach(), mask, loss

    def compute_metric_loss(self, metric_tensor, rows_to_filter, cols_to_filter, args):
        """Dispatch to the metric loss selected by ``args.loss_option``.

        Args:
            metric_tensor: 2-D tensor with one pooled feature row per sample.
            rows_to_filter: Row indices of pairs to drop from the negatives,
                or ``None``.
            cols_to_filter: Column indices matching ``rows_to_filter``,
                or ``None``.
            args: Object providing ``loss_option``, ``margin_value`` and
                ``temperature``.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: If ``args.loss_option`` is ``"ACL_verbonly"`` (not
                supported here) or any other unrecognised value.
        """
        if args.loss_option == "ACL_verbonly":
            raise ValueError("ACL_verbonly is not supported in CGFormer")
        if args.loss_option == "ACE_verbonly":
            return self.UniAngularLogitContrastLoss(
                metric_tensor,
                rows_to_filter,
                cols_to_filter,
                m=args.margin_value,
                tau=args.temperature,
                verbonly=True,
                args=args,
            )
        raise ValueError(f"Unsupported loss_option: {args.loss_option!r}")

    def return_mask(self, emb_distance, rows_to_filter=None, cols_to_filter=None):
        """Build the positive / negative pair masks for a square matrix.

        Args:
            emb_distance: Square 2-D tensor; only its shape, dtype and device
                are used.
            rows_to_filter: Optional row indices of entries to remove from the
                negative mask.
            cols_to_filter: Optional column indices, paired element-wise with
                ``rows_to_filter``.

        Returns:
            ``(positive_mask, negative_mask)``: the identity matrix, and its
            complement with the filtered entries additionally set to 0.
        """
        positive_mask = torch.zeros_like(emb_distance)
        positive_mask.fill_diagonal_(1)
        negative_mask = torch.ones_like(emb_distance) - positive_mask

        if rows_to_filter is not None and cols_to_filter is not None:
            # One vectorised assignment instead of a Python loop over
            # individual tensor elements.
            rows = torch.as_tensor(rows_to_filter, dtype=torch.long, device=emb_distance.device)
            cols = torch.as_tensor(cols_to_filter, dtype=torch.long, device=emb_distance.device)
            negative_mask[rows, cols] = 0
        return positive_mask, negative_mask

    def UniAngularLogitContrastLoss(self, total_fq, rows_to_filter, cols_to_filter, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
        """Angular-margin contrastive loss over the samples of a batch.

        Each sample is its own positive (diagonal of the similarity matrix);
        every other sample is a negative unless its ``(row, col)`` pair is
        listed in ``rows_to_filter`` / ``cols_to_filter``.

        Args:
            total_fq: 2-D tensor, one feature row per sample.
            rows_to_filter: Row indices of pairs to exclude from the
                negatives, or ``None``.
            cols_to_filter: Column indices matching ``rows_to_filter``,
                or ``None``.
            alpha: Unused; kept for interface compatibility.
            verbonly: Unused; kept for interface compatibility.
            m: Angular margin in degrees, subtracted from the positive angles.
            tau: Temperature dividing the angular logits.
            args: Unused; kept for interface compatibility.

        Returns:
            Scalar tensor: mean over samples of
            ``-log(pos_exp / (pos_exp + neg_exp))``.

        Raises:
            ValueError: If ``total_fq`` is not 2-D.
        """
        if total_fq.dim() != 2:
            raise ValueError(f"total_fq must be 2-D, got {total_fq.dim()} dimensions")

        # NOTE(review): averaging over dim=1 collapses each sample to a single
        # scalar, so the cosine similarity below can only be about +1, -1 or 0.
        # Confirm this is intended rather than comparing the full feature rows.
        emb = torch.mean(total_fq, dim=1, keepdim=True)

        batch = emb.shape[0]
        # expand() creates broadcast views; no (B, B, 1) copies are materialised.
        emb_i = emb.unsqueeze(1).expand(batch, batch, -1)
        emb_j = emb.unsqueeze(0).expand(batch, batch, -1)

        sim_matrix = F.cosine_similarity(emb_i, emb_j, dim=-1, eps=1e-6).reshape(batch, batch)
        # Keep acos away from +/-1 where its gradient is unbounded.
        sim_matrix = torch.clamp(sim_matrix, min=-0.999, max=0.999)

        # 57.2958 is the number of degrees per radian.
        margin_in_radians = m / 57.2958

        # Map similarity to an angle-based score in [-pi/2, pi/2]; higher is more similar.
        theta_matrix = (torch.pi / 2) - torch.acos(sim_matrix)

        positive_mask, negative_mask = self.return_mask(sim_matrix, rows_to_filter, cols_to_filter)

        # The margin is applied to the positive (diagonal) entries only.
        theta_with_margin = theta_matrix - margin_in_radians * positive_mask

        exp_logits = torch.exp(theta_with_margin / tau)
        pos_exp_logits = (exp_logits * positive_mask).sum(dim=-1)
        neg_exp_logits = (exp_logits * negative_mask).sum(dim=-1)
        total_exp_logits = pos_exp_logits + neg_exp_logits

        positive_loss = -torch.log(pos_exp_logits / total_exp_logits)
        angular_loss = positive_loss.mean()

        return angular_loss