"""Training and validation loops for the mask-prediction model, including an
angular-margin contrastive loss over pooled metric feature maps."""
import math
import time

import matplotlib as mpl

mpl.use('Agg')

import numpy as np
import torch
import torch.nn.parallel
import torch.optim
from torch.autograd import Variable
from torch.cuda.amp import autocast as autocast

from model.model_sbert_gref import *
from dataset.data_loader import *
from utils.losses import *
from utils.parsing_metrics import *
from utils.utils import *
from utils.utils import dice_loss, sigmoid_focal_loss

use_cuda = torch.cuda.is_available()
print("use_cuda, ", use_cuda)


def return_mask(emb_distance, rows_to_filter=None, cols_to_filter=None):
    """Build positive/negative masks for a square pairwise similarity matrix.

    The positive mask is the identity (each sample is its own positive). The
    negative mask is its complement, with every ``(row, col)`` pair listed in
    ``rows_to_filter`` / ``cols_to_filter`` additionally set to 0, so those
    pairs count as neither positive nor negative.

    Args:
        emb_distance: 2-D tensor; only its shape, dtype and device are used.
        rows_to_filter, cols_to_filter: equal-length index sequences or 1-D
            integer tensors of pairs to drop from the negatives, or ``None``.

    Returns:
        ``(positive_mask, negative_mask)``, both shaped/typed like
        ``emb_distance``.
    """
    positive_mask = torch.zeros_like(emb_distance)
    positive_mask.fill_diagonal_(1)
    negative_mask = torch.ones_like(emb_distance) - positive_mask

    if rows_to_filter is not None and cols_to_filter is not None:
        # One vectorised scatter instead of a Python loop issuing one indexing
        # kernel per pair. Indices are moved onto the mask's device.
        device = negative_mask.device
        rows = torch.as_tensor(rows_to_filter, dtype=torch.long, device=device)
        cols = torch.as_tensor(cols_to_filter, dtype=torch.long, device=device)
        negative_mask[rows, cols] = 0

    return positive_mask, negative_mask


def UniAngularLogitContrastLoss(total_fq, rows_to_filter, cols_to_filter, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
    """Angular-margin contrastive loss over global-average-pooled features.

    Each map in ``total_fq`` (B, C, H, W) is averaged over H and W into a
    C-dim embedding. For every pair (i, j) the clamped cosine similarity is
    turned into the angle ``pi/2 - acos(cos)`` (in [-pi/2, pi/2]). An additive
    margin of ``m`` degrees is subtracted on the diagonal (positives), and all
    angles are divided by ``tau``. The per-row loss is
    ``-log(sum_pos exp / sum_{pos+neg} exp)``, averaged over the batch. Pairs
    given by ``rows_to_filter`` / ``cols_to_filter`` are excluded from the
    negatives (see ``return_mask``).

    Args:
        total_fq: feature tensor of shape (B, C, H, W).
        rows_to_filter, cols_to_filter: index pairs to exclude, or ``None``.
        alpha, verbonly, args: accepted for call compatibility; unused.
        m: angular margin, in degrees, applied to positive pairs.
        tau: temperature dividing the angular logits.

    Returns:
        Scalar float32 loss tensor.
    """
    batch, channels = total_fq.shape[:2]
    # Promote to float32: under AMP the features may be float16, and
    # exp(angle / tau) (up to ~exp(31) for tau=0.05) overflows float16.
    emb = total_fq.float().mean(dim=(-1, -2)).reshape(batch, channels)

    # (B, B) cosine similarity on expanded *views* (no repeated copies).
    sim_matrix = torch.nn.functional.cosine_similarity(
        emb.unsqueeze(1).expand(batch, batch, channels),
        emb.unsqueeze(0).expand(batch, batch, channels),
        dim=-1,
        eps=1e-6,
    )
    # Keep acos away from +/-1, where its gradient is infinite.
    sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)
    theta_matrix = (torch.pi / 2) - torch.acos(sim_matrix)

    positive_mask, negative_mask = return_mask(sim_matrix, rows_to_filter, cols_to_filter)
    pos = positive_mask.bool()
    valid = pos | negative_mask.bool()

    margin_in_radians = math.radians(m)
    logits = torch.where(pos, theta_matrix - margin_in_radians, theta_matrix) / tau

    # Same quantity as -log(exp_pos / (exp_pos + exp_neg)), but via logsumexp
    # so large logits cannot overflow exp() into inf/inf = NaN. Each row always
    # keeps its diagonal positive, so both reductions stay finite.
    neg_inf = float('-inf')
    log_total = torch.logsumexp(logits.masked_fill(~valid, neg_inf), dim=-1)
    log_pos = torch.logsumexp(logits.masked_fill(~pos, neg_inf), dim=-1)
    return (log_total - log_pos).mean()


def train_epoch(rank, args, train_loader, model, optimizer, epoch, scaler, logger):
    """Run one AMP training epoch and return the average total loss.

    Per batch, the model forward runs under ``autocast``. Dice + sigmoid focal
    losses are then computed on a float32 copy of the mask output. When
    ``args.metric_learning`` is set, the angular contrastive loss on the
    model's metric tensors is added, weighted by ``args.metric_loss_weight``.
    Sample pairs whose hard-positive embeddings have cosine similarity above
    ``args.filter_thres`` are excluded from the contrastive negatives.

    Args:
        rank: CUDA device index; only rank 0 prints/logs progress.
        args: namespace providing metric_loss_weight, filter_thres,
            metric_learning, margin_value, temperature, seg_thresh and
            print_freq.
        train_loader: yields ``(imgs, word_id, word_mask, bbox, seg_map,
            params)``; ``params['hardpos_emb']`` is read when metric learning
            is enabled.
        model: returns ``(mask_out, metric_tensors)``.
        optimizer: stepped through ``scaler``.
        epoch: epoch index, used for logging only.
        scaler: AMP gradient scaler (``scale`` / ``step`` / ``update``).
        logger: receives the progress lines.

    Returns:
        ``losses.avg``, the batch-size-weighted mean of the total loss.
    """
    print('train at epoch %d' % epoch)
    batch_time = AverageMeter()
    losses = AverageMeter()
    dice_losses = AverageMeter()
    sigmoid_focal_losses = AverageMeter()
    seg_ious = AverageMeter()  # reported as "IoU" in the progress line

    model.train()
    end = time.time()

    # Settings for the verb-centric angular contrastive loss.
    metric_weight = args.metric_loss_weight
    filter_thres = args.filter_thres
    metric_learning = args.metric_learning

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, params) in enumerate(train_loader):
        batch_size = imgs.size(0)
        imgs = imgs.cuda(rank, non_blocking=True)
        word_id = word_id.cuda(rank, non_blocking=True)
        word_mask = word_mask.cuda(rank, non_blocking=True)
        seg_map = seg_map.cuda(rank, non_blocking=True)

        # Reset every batch: a batch without hard positives must not reuse the
        # previous batch's indices (or hit an unbound name on the first batch).
        rows_to_filter, cols_to_filter = None, None
        if metric_learning:
            hp_bert_embs = params['hardpos_emb'].cuda(rank, non_blocking=True).squeeze(1)
            if hp_bert_embs.numel() > 0:
                normed_embs = torch.nn.functional.normalize(hp_bert_embs, dim=-1)
                cosine_sim = torch.mm(normed_embs, normed_embs.T)
                rows_to_filter, cols_to_filter = torch.where(cosine_sim > filter_thres)

        with autocast():
            mask_out, metric_tensors = model(imgs, word_id, word_mask)

        # Segmentation losses in float32 for numerical stability under AMP.
        mask_out = mask_out.float()

        # Training IoU for logging only. NOTE(review): this thresholds the raw
        # model output (no sigmoid, unlike validate_epoch) — confirm that
        # cal_seg_iou_loss expects logits.
        mask_out_np = mask_out.detach().cpu().numpy()
        seg_map_np = seg_map.cpu().numpy()
        seg_iou = cal_seg_iou_loss(seg_map_np, mask_out_np, args.seg_thresh)

        dice_loss_ = dice_loss(mask_out, seg_map)
        sigmoid_focal_loss_ = sigmoid_focal_loss(mask_out, seg_map)
        dice_weight, focal_weight = 1.0, 1.0
        loss = dice_weight * dice_loss_ + focal_weight * sigmoid_focal_loss_

        if metric_learning:
            metric_loss = UniAngularLogitContrastLoss(
                metric_tensors, rows_to_filter, cols_to_filter,
                m=args.margin_value, tau=args.temperature, verbonly=True, args=args,
            )
            loss = loss + metric_weight * metric_loss

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        losses.update(loss.item(), batch_size)
        dice_losses.update(dice_loss_.item(), batch_size)
        sigmoid_focal_losses.update(sigmoid_focal_loss_.item(), batch_size)
        seg_ious.update(seg_iou.mean().item(), batch_size)

        batch_time.update(time.time() - end)
        end = time.time()

        if rank == 0 and batch_idx % args.print_freq == 0:
            print_str = (
                'Epoch: [{0}][{1}/{2}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'dice_losses {dice_losses.val:.4f} ({dice_losses.avg:.4f})\t'
                'sigmoid_focal_losses {sigmoid_focal_losses.val:.4f} ({sigmoid_focal_losses.avg:.4f})\t'
                'IoU {cos_loss.val:.4f} ({cos_loss.avg:.4f})\t'
            ).format(
                epoch, batch_idx, len(train_loader),
                batch_time=batch_time, loss=losses, dice_losses=dice_losses,
                sigmoid_focal_losses=sigmoid_focal_losses, cos_loss=seg_ious,
            )
            print(print_str)
            logger.info(print_str)

    return losses.avg


def validate_epoch(args, val_loader, model, logger, mode='val'):
    """Evaluate segmentation quality on ``val_loader``.

    For each batch, the sigmoid mask prediction is resized to
    ``args.size`` x ``args.size``. The rows/columns ``[dh, dh + nh)`` and
    ``[dw, dw + nw)`` are cropped out, and the crop is resized back to the
    ground-truth size. Per-sample IoU, precision at thresholds 0.5..0.95, and
    the running intersection/union are taken from ``cal_seg_iou2``.

    NOTE(review): only sample 0 of each batch is evaluated, yet the meters are
    weighted by ``imgs.size(0)`` — this assumes a batch size of 1.

    Args:
        args: namespace providing ``size`` and ``seg_thresh``.
        val_loader: yields ``(imgs, word_id, word_mask, bbox, seg_map, ratio,
            dw, dh, im_id, phrase, draw_img)``.
        model: returns ``(mask_out, _)``.
        logger: receives progress and summary lines.
        mode: unused; kept for call compatibility.

    Returns:
        ``(mean_iou, overall_iou, prec)`` where ``prec`` maps each threshold to
        its ``AverageMeter``.
    """
    print('begin test')
    batch_time = AverageMeter()
    miou_seg = AverageMeter()
    thresholds = np.arange(0.5, 1, 0.05)
    prec = {thresh: AverageMeter() for thresh in thresholds}

    model.eval()
    end = time.time()
    total_intersection = 0.0
    total_union = 0.0

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, ratio, dw, dh, im_id, phrase, draw_img) in enumerate(val_loader):
        imgs = imgs.cuda(0)
        word_id = word_id.cuda(0)
        word_mask = word_mask.cuda(0)

        with torch.no_grad():
            mask_out, _ = model(imgs, word_id, word_mask)
            mask_out = mask_out.sigmoid()

        # Map the padded prediction back to the ground-truth frame.
        ih = seg_map.shape[-2]
        iw = seg_map.shape[-1]
        nh = int(ih * ratio)
        nw = int(iw * ratio)
        top, bottom = int(dh[0]), nh + int(dh[0])
        left, right = int(dw[0]), nw + int(dw[0])
        new_shape = (iw, ih)  # cv2 dsize order is (width, height)

        # seg_map is only needed on the host; no GPU round-trip.
        seg_map_np = seg_map[0, :, :, :].cpu().numpy().transpose(1, 2, 0)
        seg_map_np = cv2.resize(seg_map_np, new_shape, interpolation=cv2.INTER_CUBIC)

        mask_np = mask_out[0].cpu().numpy().transpose(1, 2, 0)
        mask_np = cv2.resize(mask_np, (args.size, args.size))
        mask_np = mask_np[top:bottom, left:right]
        mask_np = cv2.resize(mask_np, new_shape)

        seg_iou, seg_prec, inter_sum, union_sum = cal_seg_iou2(seg_map_np, mask_np, args.seg_thresh)
        miou_seg.update(seg_iou, imgs.size(0))
        total_intersection += inter_sum
        total_union += union_sum
        for thresh in thresholds:
            prec[thresh].update(seg_prec[thresh], imgs.size(0))

        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % 1000 == 0:
            print_str = (
                '[{0}/{1}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'seg_iu {seg.val:.4f} ({seg.avg:.4f})\t'
            ).format(batch_idx, len(val_loader), batch_time=batch_time, seg=miou_seg)
            print(print_str)
            logger.info(print_str)

    overall_iou = (total_intersection + 1e-10) / (total_union + 1e-10)
    print("Mean IoU:", miou_seg.avg)
    print("Overall IoU:", overall_iou)
    logger.info("Mean IoU: %.4f" % miou_seg.avg)
    logger.info("Overall IoU: %.4f" % overall_iou)
    for thresh in thresholds:
        print("prec@%f: %f" % (thresh, float(prec[thresh].avg)))
        logger.info("prec@%f:%f" % (thresh, float(prec[thresh].avg)))
    return miou_seg.avg, overall_iou, prec