import torch
import torch.nn as nn
import torch.nn.functional as F

from bert.modeling_bert import BertModel

from .layers import Decoder
from .layers_v2 import Decoder_v2


def dice_loss(inputs, targets):
    """Compute the DICE loss, similar to generalized IoU for masks.

    Args:
        inputs: A float tensor of arbitrary shape holding per-element logits
            (a sigmoid is applied here).
        targets: A float tensor with the same shape as ``inputs`` holding the
            binary label of each element (0 = negative, 1 = positive).

    Returns:
        Scalar tensor: ``1 - (2 * intersection + 1) / (sum_pred + sum_target + 1)``
        computed per example (everything after dim 0 is flattened) and then
        averaged. The ``+ 1`` terms keep empty masks from dividing by zero.
    """
    inputs = inputs.sigmoid().flatten(1)
    targets = targets.flatten(1)
    numerator = 2 * (inputs * targets).sum(1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.mean()


def sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2):
    """Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs: A float tensor of arbitrary shape holding per-element logits.
        targets: A float tensor with the same shape as ``inputs`` holding the
            binary label of each element (0 = negative, 1 = positive).
        alpha: Weighting factor in range (0, 1) balancing positive vs negative
            examples. Default 0.25; pass a negative value to disable weighting.
        gamma: Exponent of the modulating factor ``(1 - p_t)`` balancing easy vs
            hard examples. Default 2; ``gamma=0`` with ``alpha < 0`` reduces to
            plain binary cross-entropy.

    Returns:
        Scalar tensor: the mean of the per-element loss.
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss.mean()


class CGFormer_sbert(nn.Module):
    """CGFormer segmentation model with a BERT text encoder.

    During training it adds an angular contrastive ("metric") loss that pulls
    each original sample towards its hard-positive (hp) counterpart in the
    batch. The metric loss is applied on pooled visual features.
    """

    def __init__(self, backbone, args):
        super().__init__()
        self.backbone = backbone
        self.mixup_lasttwo = args.mixup_lasttwo
        # Decoder_v2 additionally returns the fused feature used for the metric loss.
        self.decoder = Decoder_v2(args) if self.mixup_lasttwo else Decoder(args)
        self.text_encoder = BertModel.from_pretrained(args.bert)
        self.text_encoder.pooler = None
        self.args = args
        # Cosine-similarity threshold between hp text embeddings above which
        # pairs are not used as negatives; a falsy value disables filtering.
        self.filter_th = args.filter_threshold

    # image, text, l_mask, target, hardpos, hp_emb
    def forward(self, x, text, l_mask, mask=None, hp_bert_embs=None):
        """Run the model and, in training mode, compute the total loss.

        Args:
            x: Image batch; its last two dims give the output resolution.
            text: Token ids for ``self.text_encoder``.
            l_mask: Attention mask for ``text``.
            mask: Ground-truth mask (training only); a channel dim is added here.
            hp_bert_embs: Training only. One row per original sample; an all-zero
                row means the sample has no hard positive.
                NOTE(review): the batch presumably stores each (orig, hp) pair
                at consecutive positions. This is inferred from the masks built
                below and the pairing in ``return_mask``; confirm with the
                data loader.

        Returns:
            Training: ``(pred.detach(), mask, loss)``.
            Eval: ``(pred.detach(), maps)``.
        """
        rows_to_filter, cols_to_filter = None, None
        if self.training:
            # verb_masks: entries taking part in the metric loss.
            # cl_masks: entries taking part in the segmentation loss.
            verb_masks, cl_masks = [], []
            for hp_emb in hp_bert_embs:
                if not torch.all(hp_emb == 0):
                    # A hard positive exists: (orig, hp) both enter the metric
                    # loss; only the original enters the segmentation loss.
                    verb_masks.extend([1, 1])
                    cl_masks.extend([1, 0])
                else:
                    verb_masks.append(0)
                    cl_masks.append(1)

            # Find hp pairs that are too similar to be treated as negatives.
            if hp_bert_embs.numel() > 0 and self.filter_th:
                hp_mask = ~torch.all(hp_bert_embs == 0, dim=1)
                hp_bert_embs = hp_bert_embs[hp_mask]
                norms = torch.norm(hp_bert_embs, dim=-1, keepdim=True)
                normed_embs = hp_bert_embs / norms
                cosine_sim = torch.mm(normed_embs, normed_embs.T)
                rows_to_filter, cols_to_filter = torch.where(cosine_sim > self.filter_th)
        else:
            verb_masks = [0] * len(text)
            cl_masks = [1] * len(text)
        verb_masks = torch.tensor(verb_masks, dtype=torch.bool, device=x.device)
        cl_masks = torch.tensor(cl_masks, dtype=torch.bool, device=x.device)

        input_shape = x.shape[-2:]
        l_feats = self.text_encoder(text, attention_mask=l_mask)[0]  # (B, N_l, 768)
        l_feats = l_feats.permute(0, 2, 1)  # (B, 768, N_l) to make Conv1d happy
        l_mask = l_mask.unsqueeze(dim=-1)  # (B, N_l, 1)

        x_c1, x_c2, x_c3, x_c4 = self.backbone(x, l_feats, l_mask)
        if self.mixup_lasttwo:
            pred, maps, fq_fuse = self.decoder([x_c4, x_c3, x_c2, x_c1], l_feats, l_mask)
            metric_tensor = F.adaptive_avg_pool2d(fq_fuse, (1, 1)).view(
                fq_fuse.shape[0], fq_fuse.shape[1]
            )
        else:
            pred, maps = self.decoder([x_c4, x_c3, x_c2, x_c1], l_feats, l_mask)
            metric_tensor = F.adaptive_avg_pool2d(x_c4, (1, 1)).view(x_c4.size(0), -1)

        pred = F.interpolate(pred, input_shape, mode='bilinear', align_corners=True)

        if not self.training:
            return pred.detach(), maps

        loss = 0.
        mask = mask.unsqueeze(1).float()
        # Deep supervision on the intermediate decoder maps (channel 1).
        for m, lam in zip(maps, [0.001, 0.01, 0.1]):
            m = m[:, 1].unsqueeze(1)
            if m.shape[-2:] != mask.shape[-2:]:
                mask_ = F.interpolate(mask, m.shape[-2:], mode='nearest').detach()
            else:
                # Same resolution: use the target directly. Previously mask_ was
                # left undefined, or stale from an earlier, differently sized map.
                mask_ = mask
            loss += dice_loss(m[cl_masks], mask_[cl_masks]) * lam
        loss += dice_loss(pred[cl_masks], mask[cl_masks]) + sigmoid_focal_loss(
            pred[cl_masks], mask[cl_masks], alpha=-1, gamma=0
        )

        metric_loss = 0.
        # Gate on actual hard positives: an all-zero hp tensor still has
        # numel() > 0 when filtering is off, which used to yield a NaN loss.
        if verb_masks.any():
            metric_loss = self.compute_metric_loss(
                metric_tensor, verb_masks, rows_to_filter, cols_to_filter, self.args
            )
        loss += metric_loss * self.args.metric_loss_weight

        return pred.detach(), mask, loss

    def compute_metric_loss(self, metric_tensor, positive_verbs, rows_to_filter, cols_to_filter, args):
        """Dispatch to the metric loss selected by ``args.loss_option``.

        Raises:
            ValueError: if the option is unsupported by this model.
        """
        if args.loss_option == "ACL_verbonly":
            raise ValueError("ACL_verbonly is not supported in CGFormer")
        if args.loss_option == "ACE_verbonly":
            return self.UniAngularLogitContrastLoss(
                metric_tensor, positive_verbs, rows_to_filter, cols_to_filter,
                m=args.margin_value, tau=args.temperature, verbonly=True, args=args,
            )
        raise ValueError(f"Unsupported loss_option for CGFormer: {args.loss_option!r}")

    def return_mask(self, emb_distance, verb_mask=None, rows_to_filter=None, cols_to_filter=None):
        """Build positive/negative masks for a square (B_, B_) similarity matrix.

        Positives are the diagonal plus each (orig, hp) pair. Negatives are
        every other entry, except the pair blocks whose hp text embeddings
        were flagged as too similar (``rows_to_filter``/``cols_to_filter``
        index hp samples, so each maps to rows/cols 2*k and 2*k + 1).

        Returns:
            ``(positive_mask, negative_mask)``, float 0/1 tensors shaped like
            ``emb_distance``.
        """
        _, B_ = emb_distance.shape
        positive_mask = torch.zeros_like(emb_distance)
        positive_mask.fill_diagonal_(1)  # every embedding is its own positive

        if B_ < len(verb_mask):
            # Only verb entries were kept, so pairs sit at (2i, 2i + 1).
            for i in range(B_ // 2):
                positive_mask[2 * i, 2 * i + 1] = 1
                positive_mask[2 * i + 1, 2 * i] = 1
        else:
            # Full batch: walk it, pairing each flagged entry with the next one.
            i = 0
            while i < B_:
                if verb_mask[i] == 1:
                    positive_mask[i, i + 1] = 1
                    positive_mask[i + 1, i] = 1
                    i += 2
                else:
                    i += 1

        negative_mask = torch.ones_like(emb_distance) - positive_mask
        if rows_to_filter is not None and cols_to_filter is not None:
            rows = torch.as_tensor(rows_to_filter, dtype=torch.long, device=negative_mask.device) * 2
            cols = torch.as_tensor(cols_to_filter, dtype=torch.long, device=negative_mask.device) * 2
            # Clear the whole 2x2 block of each too-similar hp pair.
            for dr in (0, 1):
                for dc in (0, 1):
                    negative_mask[rows + dr, cols + dc] = 0

        return positive_mask, negative_mask

    def UniAngularLogitContrastLoss(self, total_fq, verb_mask, rows_to_filter, cols_to_filter, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
        """Angular-margin contrastive loss over (orig, hp) embedding pairs.

        The logit for a pair is ``(pi/2 - acos(cos_sim)) / tau``. Positive
        pairs have the margin ``m`` (in degrees) subtracted before scaling.
        The per-row loss is ``-log(sum_pos exp / sum_(pos+neg) exp)``,
        computed with ``logsumexp`` so small ``tau`` cannot overflow.

        Args:
            total_fq: (B, C) pooled features.
            verb_mask: Bool tensor of length B selecting entries in (orig, hp) pairs.
            rows_to_filter, cols_to_filter: Optional hp-pair indices excluded
                from the negatives (see ``return_mask``).
            alpha, args: Unused; kept for interface compatibility.
            verbonly: Use only the entries selected by ``verb_mask``.
            m: Angular margin in degrees.
            tau: Temperature.

        Returns:
            Scalar tensor, the mean loss over the selected rows.
        """
        if verbonly:
            emb = total_fq[verb_mask]
            assert emb.shape[0] % 2 == 0, f"Embedding count {emb.shape[0]} is not divisible by 2."
        else:
            # NOTE(review): this produces a 1-D tensor, which the pairwise
            # similarity below cannot handle; the path appears unused
            # (compute_metric_loss always passes verbonly=True).
            emb = torch.mean(total_fq, dim=-1)

        # Pairwise cosine similarity: x.y / (max(|x|, eps) * max(|y|, eps)).
        normed = F.normalize(emb, dim=-1, eps=1e-6)
        sim_matrix = normed @ normed.T  # (B_, B_)
        sim_matrix = torch.clamp(sim_matrix, min=-0.999, max=0.999)  # keep acos finite

        margin_in_radians = m / 57.2958  # Convert degrees to radians
        theta_matrix = (torch.pi / 2) - torch.acos(sim_matrix)

        positive_mask, negative_mask = self.return_mask(sim_matrix, verb_mask, rows_to_filter, cols_to_filter)
        pos = positive_mask.bool()
        valid = pos | negative_mask.bool()  # positives and negatives are disjoint

        theta_with_margin = torch.where(pos, theta_matrix - margin_in_radians, theta_matrix)
        logits = theta_with_margin / tau  # Scale with temperature

        neg_inf = float('-inf')
        # The diagonal is always positive, so no row is all -inf.
        log_pos = torch.logsumexp(logits.masked_fill(~pos, neg_inf), dim=-1)
        log_total = torch.logsumexp(logits.masked_fill(~valid, neg_inf), dim=-1)
        positive_loss = log_total - log_pos
        return positive_loss.mean()