|
|
import time |
|
|
import matplotlib as mpl |
|
|
mpl.use('Agg') |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.parallel |
|
|
import torch.optim |
|
|
from torch.autograd import Variable |
|
|
from torch.cuda.amp import autocast as autocast |
|
|
|
|
|
from model.model_sbert_gref import * |
|
|
from dataset.data_loader import * |
|
|
from utils.losses import * |
|
|
from utils.parsing_metrics import * |
|
|
from utils.utils import * |
|
|
from utils.utils import dice_loss, sigmoid_focal_loss |
|
|
|
|
|
# Probe once, at import time, whether PyTorch can see a CUDA device, and echo
# the result to stdout so it shows up at the top of the run output.
use_cuda = torch.cuda.is_available()
print("use_cuda, ", use_cuda)
|
|
|
|
|
|
|
|
def return_mask(emb_distance, verb_mask=None, rows_to_filter=None, cols_to_filter=None):
    """Build positive / negative pair masks for a square similarity matrix.

    Args:
        emb_distance: 2-D tensor; only its shape, dtype and device are used.
            The last dimension ``N`` is taken as the number of embeddings.
        verb_mask: optional per-sample flags (tensor, array or sequence).
            * ``None``: only the diagonal is marked positive.
            * ``N < len(verb_mask)``: the embeddings are treated as
              consecutive pairs ``(0, 1), (2, 3), ...`` and each pair is
              marked positive.
            * otherwise: the flags are walked in order; a flag equal to 1 at
              index ``i`` marks ``(i, i + 1)`` as a positive pair and skips
              the partner, any other value advances by one.
        rows_to_filter, cols_to_filter: optional, equally long index
            collections. For every ``(row, col)`` the 2x2 block starting at
            ``(2 * row, 2 * col)`` is removed from the negative mask. Both
            must be given for the filter to apply.

    Returns:
        ``(positive_mask, negative_mask)``, both shaped and typed like
        ``emb_distance`` and holding 0/1 values. The diagonal is always
        positive; negatives are everything that is neither positive nor
        filtered out.
    """
    _, num_emb = emb_distance.shape

    positive_mask = torch.zeros_like(emb_distance)
    positive_mask.fill_diagonal_(1)

    if verb_mask is None:
        # No pairing information supplied: only self-similarity is positive.
        pass
    elif num_emb < len(verb_mask):
        # Embeddings were already filtered down to paired samples, so every
        # even index is directly followed by its partner.
        first = torch.arange(0, num_emb - 1, 2, device=emb_distance.device)
        positive_mask[first, first + 1] = 1
        positive_mask[first + 1, first] = 1
    else:
        # Unfiltered layout: flagged samples are followed by their partner,
        # unflagged samples stand alone. Read the flags once on the host so
        # the walk does not index a (possibly GPU) tensor element by element.
        flags = verb_mask.tolist() if hasattr(verb_mask, "tolist") else list(verb_mask)
        i = 0
        while i < num_emb:
            # The bounds guard keeps a trailing flagged sample without a
            # partner from indexing past the matrix edge.
            if flags[i] == 1 and i + 1 < num_emb:
                positive_mask[i, i + 1] = 1
                positive_mask[i + 1, i] = 1
                i += 2
            else:
                i += 1

    negative_mask = torch.ones_like(emb_distance) - positive_mask

    if rows_to_filter is not None and cols_to_filter is not None:
        device = negative_mask.device
        rows = torch.as_tensor(rows_to_filter, dtype=torch.long, device=device).reshape(-1)
        cols = torch.as_tensor(cols_to_filter, dtype=torch.long, device=device).reshape(-1)
        # Pair the indices up the way zip() would (truncate to the shorter).
        count = min(rows.numel(), cols.numel())
        rows = rows[:count] * 2
        cols = cols[:count] * 2
        # Clear the whole 2x2 block belonging to each (row, col) pair.
        negative_mask[rows, cols] = 0
        negative_mask[rows, cols + 1] = 0
        negative_mask[rows + 1, cols] = 0
        negative_mask[rows + 1, cols + 1] = 0

    return positive_mask, negative_mask
|
|
|
|
|
|
|
|
def UniAngularLogitContrastLoss(total_fq, verb_mask, rows_to_filter, cols_to_filter, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
    """Angular-margin contrastive loss over spatially pooled feature maps.

    Each selected feature map is averaged over its two trailing (spatial)
    axes to give one embedding per sample. Pairwise cosine similarities are
    turned into angles, a margin is subtracted from the positive pairs, and
    a softmax-style ratio of positive mass to (positive + negative) mass is
    minimised per row.

    Args:
        total_fq: 4-D tensor unpacked as ``(_, C, H, W)``.
        verb_mask: boolean selector over the first axis of ``total_fq``;
            also forwarded to ``return_mask`` to locate positive pairs.
        rows_to_filter, cols_to_filter: forwarded to ``return_mask`` to drop
            selected pairs from the negatives (may be ``None``).
        alpha: unused; kept for interface compatibility.
        verbonly: if true, only the samples selected by ``verb_mask`` are
            embedded, and their count must be even.
        m: margin, expressed in degrees (converted with 57.2958 deg/rad).
        tau: temperature dividing the angular logits.
        args: unused; kept for interface compatibility.

    Returns:
        ``(angular_loss, B_)``: the scalar mean loss and the number of
        embeddings that took part in it.
    """
    _, num_channels, _, _ = total_fq.shape

    if verbonly:
        selected = total_fq[verb_mask]
        emb = torch.mean(selected, dim=(-1, -2)).reshape(selected.shape[0], num_channels)
        assert emb.shape[0] % 2 == 0, f"Embedding count {emb.shape[0]} is not divisible by 2."
    else:
        # Pool over both spatial axes, like the verb-only path. Averaging
        # over the last axis alone leaves a 3-D tensor that the pairwise
        # expansion below cannot handle.
        emb = torch.mean(total_fq, dim=(-1, -2))

    B_ = emb.shape[0]
    # (B_, B_, C) views of every ordered embedding pair; expand() avoids
    # materialising the copies that repeat() would make.
    emb_i = emb.unsqueeze(1).expand(-1, B_, -1)
    emb_j = emb.unsqueeze(0).expand(B_, -1, -1)

    sim_matrix = torch.nn.functional.cosine_similarity(emb_i, emb_j, dim=-1, eps=1e-6)
    # Keep acos away from +-1, where its derivative is unbounded.
    sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)

    margin_in_radians = m / 57.2958
    # Maps similarity 1 -> +pi/2 and similarity -1 -> -pi/2.
    theta_matrix = (torch.pi / 2) - torch.acos(sim_matrix)

    positive_mask, negative_mask = return_mask(sim_matrix, verb_mask, rows_to_filter, cols_to_filter)

    # Subtract the margin from positive entries only (mask values are 0/1).
    logits = (theta_matrix - margin_in_radians * positive_mask) / tau

    # -log(sum_pos exp / sum_(pos+neg) exp), evaluated with logsumexp so a
    # small temperature cannot overflow exp() into inf/nan. Entries outside
    # a mask are set to -inf and therefore contribute exp(-inf) = 0.
    pos_sel = positive_mask.bool()
    kept_sel = pos_sel | negative_mask.bool()
    neg_inf = float("-inf")
    log_pos = torch.logsumexp(logits.masked_fill(~pos_sel, neg_inf), dim=-1)
    log_total = torch.logsumexp(logits.masked_fill(~kept_sel, neg_inf), dim=-1)

    positive_loss = log_total - log_pos
    angular_loss = positive_loss.mean()

    return angular_loss, B_
|
|
|
|
|
|
|
|
def train_epoch(rank, args, train_loader, model, optimizer, epoch, scaler, logger):
    """Run one mixed-precision training epoch and return the mean loss.

    For every batch, each sample whose ``params['pos_type']`` entry equals
    ``'hardpos'`` is duplicated with its alternative sentence
    (``params['hp_word_id']`` / ``params['hp_word_mask']``), so the model
    sees the original and the alternative back to back. The segmentation
    loss (dice + sigmoid focal) is computed on the original sentences only;
    when metric learning is enabled and the batch holds more than one such
    sample, ``UniAngularLogitContrastLoss`` is added with weight
    ``args.metric_loss_weight``.

    Args:
        rank: CUDA device index for this process; progress is printed and
            logged only when it is 0.
        args: namespace providing ``metric_loss_weight``, ``filter_thres``,
            ``metric_learning``, ``seg_thresh``, ``margin_value``,
            ``temperature`` and ``print_freq``.
        train_loader: yields ``(imgs, word_id, word_mask, bbox, seg_map,
            params)``; ``bbox`` is not used here.
        model: called as ``model(image, word_id, word_mask)`` and expected
            to return ``(mask_logits, metric_tensors)``.
        optimizer: optimizer stepped through ``scaler``.
        epoch: epoch number, used for logging only.
        scaler: gradient scaler used for the backward pass and the step.
        logger: object with an ``info`` method.

    Returns:
        The running average of the total loss over the epoch.
    """
    print('train at epoch %d'%epoch)
    batch_time = AverageMeter()
    losses = AverageMeter()
    dice_losses = AverageMeter()
    sigmoid_focal_losses = AverageMeter()
    cos_losses = AverageMeter()  # despite the name, this tracks the batch IoU
    model.train()
    end = time.time()

    mlw = args.metric_loss_weight
    filter_thres = args.filter_thres
    metric_learning = args.metric_learning

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, params) in enumerate(train_loader):
        B = imgs.size(0)

        hp_word_id = params['hp_word_id']
        hp_word_mask = params['hp_word_mask']
        # Same device as every other tensor in this step.
        hp_bert_embs = params['hardpos_emb'].cuda(rank, non_blocking=True).squeeze(1)
        pos_type = np.array(params['pos_type'])

        pos_mask = torch.tensor(np.where(pos_type == 'hardpos', 1, 0))
        hardpos_flags = pos_mask.tolist()

        # verb_masks: True for both members of an (original, alternative) pair.
        # cl_masks:   True for original sentences, False for the alternatives.
        verb_masks = []
        cl_masks = []
        images = []
        targets = []
        sentences_ = []
        sentences_masked_ = []

        for idx in range(B):
            sentences_.append(word_id[idx])
            sentences_masked_.append(word_mask[idx])
            images.append(imgs[idx])
            targets.append(seg_map[idx])

            if hardpos_flags[idx]:
                verb_masks.extend([1, 1])
                cl_masks.extend([1, 0])
                # The alternative sentence reuses the same image and target.
                sentences_.append(hp_word_id[idx])
                sentences_masked_.append(hp_word_mask[idx])
                images.append(imgs[idx])
                targets.append(seg_map[idx])
            else:
                verb_masks.append(0)
                cl_masks.append(1)

        image = torch.stack(images).cuda(rank, non_blocking=True)
        seg_map = torch.stack(targets).cuda(rank, non_blocking=True)
        word_id = torch.stack(sentences_).cuda(rank, non_blocking=True)
        word_mask = torch.stack(sentences_masked_).cuda(rank, non_blocking=True)
        verb_masks = torch.tensor(verb_masks, dtype=torch.bool).cuda(rank, non_blocking=True)
        cl_masks = torch.tensor(cl_masks, dtype=torch.bool).cuda(rank, non_blocking=True)

        # Reset per batch so the metric loss never sees indices left over
        # from an earlier iteration (or an undefined name on the first one).
        rows_to_filter, cols_to_filter = None, None
        if hp_bert_embs.numel() > 0:
            # Drop all-zero rows before normalising.
            # NOTE(review): presumably zero rows are placeholders for samples
            # without an alternative sentence - confirm against the dataset.
            nonzero_rows = ~torch.all(hp_bert_embs == 0, dim=1)
            hp_bert_embs = hp_bert_embs[nonzero_rows]

            norms = torch.norm(hp_bert_embs, dim=-1, keepdim=True)
            normed_embs = hp_bert_embs / norms
            cosime_sim = torch.mm(normed_embs, normed_embs.T)
            # Pairs whose sentence embeddings are this similar are excluded
            # from the negatives of the contrastive loss.
            rows_to_filter, cols_to_filter = torch.where(cosime_sim > filter_thres)

        # NOTE(review): the forward pass and every loss term are placed under
        # autocast, with the backward pass and optimizer step outside it.
        # Confirm this matches the intended autocast scope.
        with autocast():
            mask_out_all, metric_tensors = model(image, word_id, word_mask)

            # Segmentation is supervised on the original sentences only.
            mask_out = mask_out_all[cl_masks]
            seg_map_cl = seg_map[cl_masks]

            mask_out_np = mask_out.detach().cpu().numpy()
            seg_map_np = seg_map_cl.cpu().numpy()
            seg_iou = cal_seg_iou_loss(seg_map_np, mask_out_np, args.seg_thresh)

            dice_loss_ = dice_loss(mask_out, seg_map_cl)
            sigmoid_focal_loss_ = sigmoid_focal_loss(mask_out, seg_map_cl)

            dice_weight, focal_weight = 1.0, 1.0
            loss = (dice_weight * dice_loss_) + (focal_weight * sigmoid_focal_loss_)

            if metric_learning and sum(hardpos_flags) > 1:
                metric_loss, _ = UniAngularLogitContrastLoss(
                    metric_tensors, verb_masks, rows_to_filter, cols_to_filter,
                    m=args.margin_value, tau=args.temperature, verbonly=True, args=args)
                loss += mlw * metric_loss

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Meters are weighted by the loader batch size, not the expanded one.
        losses.update(loss.item(), B)
        dice_losses.update(dice_loss_.item(), B)
        sigmoid_focal_losses.update(sigmoid_focal_loss_.item(), B)
        cos_losses.update(seg_iou.mean().item(), B)

        batch_time.update(time.time() - end)
        end = time.time()

        if rank == 0 and batch_idx % args.print_freq == 0:
            print_str = 'Epoch: [{0}][{1}/{2}]\t' \
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                'dice_losses {dice_losses.val:.4f} ({dice_losses.avg:.4f})\t' \
                'sigmoid_focal_losses {sigmoid_focal_losses.val:.4f} ({sigmoid_focal_losses.avg:.4f})\t' \
                'IoU {cos_loss.val:.4f} ({cos_loss.avg:.4f})\t' \
                .format(epoch, batch_idx, len(train_loader), batch_time=batch_time, loss=losses, dice_losses=dice_losses, sigmoid_focal_losses=sigmoid_focal_losses, cos_loss=cos_losses)
            print(print_str)
            logger.info(print_str)

    return losses.avg
|
|
|
|
|
def validate_epoch(args, val_loader, model, logger, mode='val'):
    """Evaluate the model on ``val_loader`` and report segmentation metrics.

    For each batch the prediction for sample 0 is passed through a sigmoid,
    resized to ``args.size`` x ``args.size``, cropped to the window given by
    ``dh`` / ``dw`` and the scaled target size, and resized to the target
    mask's resolution before being scored with ``cal_seg_iou2``.

    NOTE(review): only index 0 of every batch is scored while the meters are
    weighted by ``imgs.size(0)``, so this looks like it expects a loader with
    batch size 1 - confirm.

    Args:
        args: namespace providing ``size`` and ``seg_thresh``.
        val_loader: yields ``(imgs, word_id, word_mask, bbox, seg_map, ratio,
            dw, dh, im_id, phrase, draw_img)``; only ``imgs``, ``word_id``,
            ``word_mask``, ``seg_map``, ``ratio``, ``dw`` and ``dh`` are used.
        model: called as ``model(image, word_id, word_mask)``; the first
            element of its return value is taken as the mask logits.
        logger: object with an ``info`` method.
        mode: unused; kept for interface compatibility.

    Returns:
        ``(mean_iou, overall_iou, prec)`` where ``mean_iou`` is the running
        average of the per-sample IoU, ``overall_iou`` is accumulated
        intersection over accumulated union, and ``prec`` maps each threshold
        in ``np.arange(0.5, 1, 0.05)`` to its precision meter.
    """
    print('begin test')
    batch_time = AverageMeter()
    miou_seg = AverageMeter()

    thresholds = np.arange(0.5, 1, 0.05)
    prec = dict()
    for thresh in thresholds:
        prec[thresh] = AverageMeter()

    model.eval()
    end = time.time()

    total_intersection = 0.0
    total_union = 0.0

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, ratio, dw, dh, im_id, phrase, draw_img) in enumerate(val_loader):
        imgs = imgs.cuda(0)
        word_id = word_id.cuda(0)
        word_mask = word_mask.cuda(0)
        seg_map = seg_map.cuda(0)

        with torch.no_grad():
            mask_out, _ = model(imgs, word_id, word_mask)
            mask_out = mask_out.sigmoid()

        # Crop window inside the network-sized prediction: the target size
        # scaled by ``ratio``, offset by ``dh`` / ``dw``.
        ih = seg_map.shape[-2]
        iw = seg_map.shape[-1]
        nh = int(ih * ratio)
        nw = int(iw * ratio)
        top, bottom = int(dh[0]), nh + int(dh[0])
        left, right = int(dw[0]), nw + int(dw[0])
        new_shape = (iw, ih)  # cv2.resize expects (width, height)

        seg_map_np = seg_map[0, :, :, :].data.cpu().numpy().transpose(1, 2, 0)
        seg_map_np = cv2.resize(seg_map_np, new_shape, interpolation=cv2.INTER_CUBIC)

        # Prediction: resize to the network input size, crop the window, then
        # resize to the target mask's resolution.
        mask_out = mask_out[0].data.cpu().numpy().transpose(1, 2, 0)
        mask_out = cv2.resize(mask_out, (args.size, args.size))
        mask_out_np = mask_out[top:bottom, left:right]
        mask_out_np = cv2.resize(mask_out_np, new_shape)

        seg_iou, seg_prec, inter_sum, union_sum = cal_seg_iou2(seg_map_np, mask_out_np, args.seg_thresh)

        miou_seg.update(seg_iou, imgs.size(0))
        total_intersection += inter_sum
        total_union += union_sum

        for thresh in thresholds:
            prec[thresh].update(seg_prec[thresh], imgs.size(0))

        batch_time.update(time.time() - end)
        end = time.time()
        if batch_idx % 1000 == 0:
            print_str = '[{0}/{1}]\t' \
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                'seg_iu {seg.val:.4f} ({seg.avg:.4f})\t' \
                .format(batch_idx, len(val_loader), batch_time=batch_time, seg=miou_seg)
            print(print_str)
            logger.info(print_str)

    # The epsilon keeps an empty loader from dividing by zero.
    overall_iou = (total_intersection + 1e-10) / (total_union + 1e-10)

    print("Mean IoU:", miou_seg.avg)
    print("Overall IoU:", overall_iou)
    logger.info("Mean IoU: %.4f" % miou_seg.avg)
    logger.info("Overall IoU: %.4f" % overall_iou)

    for thresh in thresholds:
        print("prec@%f: %f"%(thresh,float(prec[thresh].avg)))
        logger.info("prec@%f:%f"%(thresh,float(prec[thresh].avg)))

    return miou_seg.avg, overall_iou, prec
|
|
|