|
|
import time |
|
|
import matplotlib as mpl |
|
|
mpl.use('Agg') |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.parallel |
|
|
import torch.optim |
|
|
from torch.autograd import Variable |
|
|
from torch.cuda.amp import autocast as autocast |
|
|
|
|
|
from model.model_sbert_gref import * |
|
|
from dataset.data_loader import * |
|
|
from utils.losses import * |
|
|
from utils.parsing_metrics import * |
|
|
from utils.utils import * |
|
|
from utils.utils import dice_loss, sigmoid_focal_loss |
|
|
|
|
|
# Module-level GPU-availability flag, reported once at import time.
# NOTE(review): not referenced elsewhere in this view; possibly imported by
# other modules — confirm before removing.
use_cuda = torch.cuda.is_available()
print("use_cuda, ", use_cuda)
|
|
|
|
|
|
|
|
def return_mask(emb_distance, rows_to_filter=None, cols_to_filter=None):
    """Build positive/negative pair masks for a square (B, B) similarity matrix.

    Positives are the diagonal (each sample with itself); negatives are all
    off-diagonal entries, minus any (row, col) pairs explicitly filtered out
    (e.g. near-duplicate captions that should not act as negatives).

    Args:
        emb_distance: square (B, B) tensor; only its shape/device/dtype are used.
        rows_to_filter, cols_to_filter: optional index sequences (tensors or
            lists) of equal length; entry (rows[i], cols[i]) is zeroed in the
            negative mask. Both must be given for filtering to apply.

    Returns:
        (positive_mask, negative_mask): float tensors shaped like emb_distance.
    """
    positive_mask = torch.zeros_like(emb_distance)
    positive_mask.fill_diagonal_(1)
    # Fresh tensor already — the original's extra .clone() was redundant.
    negative_mask = torch.ones_like(emb_distance) - positive_mask

    if rows_to_filter is not None and cols_to_filter is not None:
        # Vectorized fancy-index assignment replaces the original
        # per-pair Python loop (O(k) tensor ops -> one kernel).
        negative_mask[rows_to_filter, cols_to_filter] = 0

    return positive_mask, negative_mask
|
|
|
|
|
|
|
|
def UniAngularLogitContrastLoss(total_fq, rows_to_filter, cols_to_filter, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
    """Angular-margin InfoNCE-style contrastive loss over pooled feature maps.

    Each (C, H, W) feature map in the batch is global-average-pooled to a
    C-dim embedding; pairwise cosine similarities are converted to
    complementary angles, an additive angular margin is applied to positive
    (diagonal) pairs, and a temperature-scaled softmax cross-entropy over
    positives vs. negatives is returned.

    Args:
        total_fq: (B, C, H, W) feature tensor.
        rows_to_filter, cols_to_filter: index pairs excluded from the
            negative set (forwarded to return_mask); may be None.
        alpha, verbonly, args: kept for caller compatibility; unused here.
        m: additive angular margin in DEGREES applied to positive pairs.
        tau: softmax temperature.

    Returns:
        Scalar loss tensor (mean over the batch).
    """
    # Single shape unpack (original read B twice from total_fq.shape).
    B, C, H, W = total_fq.shape

    # Global average pool over the spatial dims -> (B, C) embeddings.
    emb = torch.mean(total_fq, dim=(-1, -2)).reshape(B, C)

    # Cosine similarity via broadcasting: (B, 1, C) x (1, B, C) -> (B, B).
    # Avoids materializing the two repeated (B, B, C) tensors the original built.
    sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
    sim_matrix = sim(emb.unsqueeze(1), emb.unsqueeze(0)).reshape(B, B)
    # Clamp away +/-1 so acos stays finite and differentiable.
    sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)

    # Exact degrees->radians (original used the truncated constant 57.2958).
    margin_in_radians = m * torch.pi / 180.0
    # Complementary angle: grows with similarity, in (-pi/2, pi/2).
    theta_matrix = (torch.pi / 2) - torch.acos(sim_matrix)

    positive_mask, negative_mask = return_mask(sim_matrix, rows_to_filter, cols_to_filter)

    # Subtract the margin from positive (diagonal) angles only, making the
    # positive logit harder to dominate (standard additive angular margin).
    theta_with_margin = theta_matrix.clone()
    theta_with_margin[positive_mask.bool()] -= margin_in_radians
    logits = theta_with_margin / tau

    # NOTE(review): with tau=0.05 the logits can reach ~31, so exp() is near
    # the float32 ceiling; would overflow under fp16 — confirm autocast keeps
    # exp in fp32, or switch to a logsumexp formulation.
    exp_logits = torch.exp(logits)
    pos_exp_logits = (exp_logits * positive_mask).sum(dim=-1)
    neg_exp_logits = (exp_logits * negative_mask).sum(dim=-1)
    total_exp_logits = pos_exp_logits + neg_exp_logits

    # -log( pos / (pos + neg) ), averaged over the batch.
    positive_loss = -torch.log(pos_exp_logits / total_exp_logits)
    angular_loss = positive_loss.mean()

    return angular_loss
|
|
|
|
|
def train_epoch(rank, args, train_loader, model, optimizer, epoch, scaler, logger):
    """Train the model for one epoch under mixed precision (AMP).

    Args:
        rank: GPU index / process rank; batch tensors are moved to this device.
        args: config namespace; reads metric_loss_weight, metric_mode,
            filter_thres, metric_learning, margin_value, temperature,
            seg_thresh, print_freq.
        train_loader: yields (imgs, word_id, word_mask, bbox, seg_map, params);
            params['hardpos_emb'] holds sentence embeddings of hard positives
            (may be empty for a batch).
        model: callable returning (mask_out, metric_tensors).
        optimizer: wrapped by `scaler` (torch.cuda.amp.GradScaler).
        epoch: current epoch index (logging only).
        scaler: GradScaler driving the AMP backward/step.
        logger: logging.Logger for periodic progress lines.

    Returns:
        Average total loss over the epoch.
    """
    print('train at epoch %d'%epoch)
    batch_time = AverageMeter()
    losses = AverageMeter()
    dice_losses = AverageMeter()
    sigmoid_focal_losses = AverageMeter()
    cos_losses = AverageMeter()  # NOTE: tracks training seg IoU, not a cosine loss
    model.train()
    end = time.time()

    mlw = args.metric_loss_weight
    metric_mode = args.metric_mode  # read for parity with config; unused below
    filter_thres = args.filter_thres
    metric_learning = args.metric_learning

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, params) in enumerate(train_loader):
        B = imgs.size(0)
        # NOTE(review): moved to the *default* CUDA device, unlike the tensors
        # below which go to `rank` — presumably each process has already set
        # its device to `rank`; confirm, else pass `rank` here too.
        hp_bert_embs = params['hardpos_emb'].cuda(non_blocking=True).squeeze(1)

        # torch.autograd.Variable is a deprecated no-op since 0.4; the
        # original wrappers were removed with no behavior change.
        image = imgs.cuda(rank, non_blocking=True)
        word_id = word_id.cuda(rank, non_blocking=True)
        word_mask = word_mask.cuda(rank, non_blocking=True)
        seg_map = seg_map.cuda(rank, non_blocking=True)

        # BUG FIX: these were previously assigned only when hp_bert_embs was
        # non-empty, so an empty batch either raised NameError (first batch)
        # or silently reused stale filter indices from the previous batch.
        # Reset per batch; return_mask() treats None as "no filtering".
        rows_to_filter, cols_to_filter = None, None
        if hp_bert_embs.numel() > 0:
            # Pairs of hard-positive captions whose cosine similarity exceeds
            # filter_thres are near-duplicates and must not act as negatives.
            norms = torch.norm(hp_bert_embs, dim=-1, keepdim=True)
            normed_embs = hp_bert_embs / norms
            cosine_sim = torch.mm(normed_embs, normed_embs.T)
            rows_to_filter, cols_to_filter = torch.where(cosine_sim > filter_thres)

        with autocast():
            mask_out, metric_tensors = model(image, word_id, word_mask)

            # Training IoU on detached numpy copies — logging only, no grads.
            mask_out_np = mask_out.data.cpu().numpy()
            seg_map_np = seg_map.cpu().numpy()
            seg_iou = cal_seg_iou_loss(seg_map_np, mask_out_np, args.seg_thresh)

            dice_loss_ = dice_loss(mask_out, seg_map)
            sigmoid_focal_loss_ = sigmoid_focal_loss(mask_out, seg_map)

            dice_weight, focal_weight = 1.0, 1.0
            loss = (dice_weight * dice_loss_) + (focal_weight * sigmoid_focal_loss_)

            if metric_learning:
                metric_weight = mlw
                metric_loss = UniAngularLogitContrastLoss(metric_tensors, rows_to_filter, cols_to_filter, m=args.margin_value, tau=args.temperature, verbonly=True, args=args)
                loss += metric_weight * metric_loss

        # Standard AMP step: scale -> backward -> step -> update scale.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        losses.update(loss.item(), B)
        dice_losses.update(dice_loss_.item(), B)
        sigmoid_focal_losses.update(sigmoid_focal_loss_.item(), B)
        cos_losses.update(seg_iou.mean().item(), B)

        batch_time.update(time.time() - end)
        end = time.time()

        # Only rank 0 logs, every print_freq batches.
        if rank == 0 and batch_idx % args.print_freq == 0:
            print_str = 'Epoch: [{0}][{1}/{2}]\t' \
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                'dice_losses {dice_losses.val:.4f} ({dice_losses.avg:.4f})\t' \
                'sigmoid_focal_losses {sigmoid_focal_losses.val:.4f} ({sigmoid_focal_losses.avg:.4f})\t' \
                'IoU {cos_loss.val:.4f} ({cos_loss.avg:.4f})\t' \
                .format(epoch, batch_idx, len(train_loader), batch_time=batch_time, loss=losses, dice_losses=dice_losses, sigmoid_focal_losses=sigmoid_focal_losses, cos_loss=cos_losses)
            print(print_str)
            logger.info(print_str)

    return losses.avg
|
|
|
|
|
def validate_epoch(args, val_loader, model, logger, mode='val'):
    """Evaluate segmentation quality over a validation/test loader.

    Computes mean IoU (per-sample average), overall IoU (pooled
    intersection/union), and precision at IoU thresholds 0.5..0.95.

    Args:
        args: config namespace; reads args.size (network input side) and
            args.seg_thresh (binarization threshold).
        val_loader: yields (imgs, word_id, word_mask, bbox, seg_map, ratio,
            dw, dh, im_id, phrase, draw_img). NOTE(review): the loop indexes
            element 0 everywhere — assumes batch size 1; confirm loader config.
        model: callable returning (mask_out, _) given (image, word_id, word_mask).
        logger: logging.Logger for progress and final metrics.
        mode: unused here; kept for caller compatibility.

    Returns:
        (mean IoU, overall IoU, dict of threshold -> AverageMeter).
    """
    print('begin test')
    batch_time = AverageMeter()
    miou = AverageMeter()       # NOTE(review): never updated in this function
    miou_seg = AverageMeter()   # per-sample segmentation IoU

    # Precision meters at IoU thresholds 0.50, 0.55, ..., 0.95.
    prec=dict()
    thresholds = np.arange(0.5, 1, 0.05)

    for thresh in thresholds:
        prec[thresh]= AverageMeter()

    model.eval()
    end = time.time()
    idx = 0

    t_all = []  # per-batch forward-pass timings (collected, not reported here)
    total_intersection = 0.0
    total_union = 0.0

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, ratio, dw, dh, im_id, phrase, draw_img) in enumerate(val_loader):

        # Validation is hard-coded to GPU 0 (single-device evaluation).
        imgs = imgs.cuda(0)
        word_id = word_id.cuda(0)
        word_mask = word_mask.cuda(0)
        seg_map = seg_map.cuda(0)
        # Variable() is a deprecated no-op kept as-is for byte-identical code.
        image = Variable(imgs)
        word_id = Variable(word_id)
        word_mask = Variable(word_mask)
        seg_map = Variable(seg_map)

        t1 = time.time()
        with torch.no_grad():
            mask_out, _ = model(image, word_id, word_mask)
            mask_out = mask_out.sigmoid()

        t2 = time.time()
        t_all.append(t2-t1)

        # Undo the letterbox: (ih, iw) is the ground-truth map size; (dh, dw)
        # are the padding offsets and `ratio` the resize factor used by the
        # loader. NOTE(review): assumes seg_map is already at the original
        # image resolution — confirm against the dataset code.
        ih = seg_map.shape[-2]
        iw = seg_map.shape[-1]
        nh = int(ih * ratio)
        nw = int(iw * ratio)
        top, bottom = int(dh[0]), nh + int(dh[0])
        left, right = int(dw[0]), nw + int(dw[0])
        ratio = float(ratio)
        new_shape = (iw, ih)  # cv2.resize takes (width, height)

        # Ground truth to HWC numpy, resized to the original resolution.
        seg_map_np = seg_map[0,:,:,:].data.cpu().numpy().transpose(1,2,0)
        seg_map_np = cv2.resize(seg_map_np, new_shape, interpolation=cv2.INTER_CUBIC)
        # Crop the padded borders off the network-input image and resize back.
        img_np = imgs[0,:,top:bottom,left:right].data.cpu().numpy().transpose(1,2,0)
        img_np = cv2.resize(img_np, new_shape, interpolation=cv2.INTER_CUBIC)

        # NOTE(review): img_np is rebuilt as a CUDA tensor but never used
        # afterwards — likely leftover from a visualization path.
        img_np = Variable(torch.from_numpy(img_np.transpose(2,0,1)).cuda().unsqueeze(0))

        # Prediction: upsample to network input size, crop padding, resize to
        # the original resolution so it aligns with seg_map_np.
        mask_out = mask_out[0].data.cpu().numpy().transpose(1,2,0)
        mask_out = cv2.resize(mask_out, (args.size, args.size))
        mask_out_np = mask_out[top:bottom, left:right]
        mask_out_np = cv2.resize(mask_out_np, new_shape)

        # Per-sample IoU, per-threshold precision hits, and raw pixel counts.
        seg_iou, seg_prec, inter_sum, union_sum = cal_seg_iou2(seg_map_np, mask_out_np, args.seg_thresh)

        miou_seg.update(seg_iou, imgs.size(0))
        total_intersection += inter_sum
        total_union += union_sum

        for thresh in thresholds:
            prec[thresh].update(seg_prec[thresh], imgs.size(0))

        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % 1000 == 0:
            print_str = '[{0}/{1}]\t' \
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                'seg_iu {seg.val:.4f} ({seg.avg:.4f})\t' \
                .format( \
                batch_idx, len(val_loader), batch_time=batch_time, seg=miou_seg)
            print(print_str)
            logger.info(print_str)
        idx = idx + 1

    # Overall IoU pools pixels across the whole split; epsilon guards an
    # empty split / all-empty masks.
    overall_iou = (total_intersection + 1e-10) / (total_union + 1e-10)

    print("Mean IoU:", miou_seg.avg)
    print("Overall IoU:", overall_iou)
    logger.info("Mean IoU: %.4f" % miou_seg.avg)
    logger.info("Overall IoU: %.4f" % overall_iou)

    for thresh in thresholds:
        print("prec@%f: %f"%(thresh,float(prec[thresh].avg)))
        logger.info("prec@%f:%f"%(thresh,float(prec[thresh].avg)))

    return miou_seg.avg, overall_iou, prec
|
|
|
|
|
|