|
|
import time |
|
|
import matplotlib as mpl |
|
|
mpl.use('Agg') |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.parallel |
|
|
import torch.optim |
|
|
from torch.autograd import Variable |
|
|
from torch.cuda.amp import autocast as autocast |
|
|
|
|
|
from model.model import * |
|
|
from dataset.data_loader import * |
|
|
from utils.losses import * |
|
|
from utils.parsing_metrics import * |
|
|
from utils.utils import * |
|
|
from utils.utils import dice_loss, sigmoid_focal_loss |
|
|
|
|
|
# Probe CUDA availability once at import time and echo the result to stdout.
use_cuda = torch.cuda.is_available()
print("use_cuda, ", use_cuda)
|
|
|
|
|
|
|
|
def train_epoch(rank, args, train_loader, model, optimizer, epoch, scaler, logger):
    """Run one mixed-precision training epoch and return the average loss.

    Args:
        rank: CUDA device index the batch tensors are moved to. Progress
            lines are emitted only when ``rank == 0``.
        args: Options object; ``args.seg_thresh`` and ``args.print_freq``
            are read.
        train_loader: Iterable yielding
            ``(imgs, word_id, word_mask, bbox, seg_map)`` batches. ``bbox``
            is not used here.
        model: Network called as ``model(imgs, word_id, word_mask)``. It is
            switched to training mode.
        optimizer: Optimizer that is stepped through ``scaler``.
        epoch: Epoch number, used only in progress messages.
        scaler: Gradient scaler; its ``scale``, ``step`` and ``update``
            methods are called once per batch.
        logger: Receives the same progress line that is printed.

    Returns:
        ``avg`` of the total-loss meter, which is updated once per batch
        with the batch size as weight.
    """
    print('train at epoch %d' % epoch)

    batch_time = AverageMeter()
    losses = AverageMeter()
    dice_losses = AverageMeter()
    sigmoid_focal_losses = AverageMeter()
    # Despite the name, this meter holds the mean segmentation IoU that is
    # shown as "IoU" in the progress line.
    cos_losses = AverageMeter()

    model.train()
    end = time.time()

    for batch_idx, (imgs, word_id, word_mask, _bbox, seg_map) in enumerate(train_loader):
        # Tensors are used directly: wrapping them in autograd.Variable has
        # been a no-op since Variable and Tensor were merged.
        imgs = imgs.cuda(rank, non_blocking=True)
        word_id = word_id.cuda(rank, non_blocking=True)
        word_mask = word_mask.cuda(rank, non_blocking=True)
        seg_map = seg_map.cuda(rank, non_blocking=True)
        batch_size = imgs.size(0)

        # Forward pass and loss computation run under autocast; the backward
        # pass and optimizer step stay outside, as recommended for AMP.
        with autocast():
            mask_out = model(imgs, word_id, word_mask)

            # IoU is computed on detached CPU copies purely for monitoring;
            # it does not contribute to the optimized loss.
            mask_out_np = mask_out.detach().cpu().numpy()
            seg_map_np = seg_map.cpu().numpy()
            seg_iou = cal_seg_iou_loss(seg_map_np, mask_out_np, args.seg_thresh)

            dice_loss_ = dice_loss(mask_out, seg_map)
            sigmoid_focal_loss_ = sigmoid_focal_loss(mask_out, seg_map)
            loss = dice_loss_ + sigmoid_focal_loss_

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        losses.update(loss.item(), batch_size)
        dice_losses.update(dice_loss_.item(), batch_size)
        sigmoid_focal_losses.update(sigmoid_focal_loss_.item(), batch_size)
        cos_losses.update(seg_iou.mean().item(), batch_size)

        # Wall-clock time per iteration, including data loading.
        batch_time.update(time.time() - end)
        end = time.time()

        if rank == 0 and batch_idx % args.print_freq == 0:
            print_str = (
                'Epoch: [{0}][{1}/{2}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'dice_losses {dice_losses.val:.4f} ({dice_losses.avg:.4f})\t'
                'sigmoid_focal_losses {sigmoid_focal_losses.val:.4f} ({sigmoid_focal_losses.avg:.4f})\t'
                'IoU {cos_loss.val:.4f} ({cos_loss.avg:.4f})\t'
            ).format(
                epoch,
                batch_idx,
                len(train_loader),
                batch_time=batch_time,
                loss=losses,
                dice_losses=dice_losses,
                sigmoid_focal_losses=sigmoid_focal_losses,
                cos_loss=cos_losses,
            )
            print(print_str)
            logger.info(print_str)

    return losses.avg
|
|
|
|
|
def validate_epoch(args, val_loader, model, logger, mode='val'):
    """Evaluate ``model`` on ``val_loader`` and report segmentation metrics.

    Only the first sample of every batch is scored (tensors are indexed with
    ``[0]``), so the loader is expected to yield batches of size 1.

    Args:
        args: Options object; ``args.size`` and ``args.seg_thresh`` are read.
        val_loader: Iterable yielding ``(imgs, word_id, word_mask, bbox,
            seg_map, ratio, dw, dh, im_id, phrase, draw_img)`` batches.
        model: Network called as ``model(imgs, word_id, word_mask)``. It is
            switched to evaluation mode and left in that mode.
        logger: Receives the same progress and summary lines that are printed.
        mode: Unused; kept for interface compatibility.

    Returns:
        Tuple ``(mean_iou, overall_iou, prec)``: ``mean_iou`` is the ``avg``
        of the per-sample IoU meter, ``overall_iou`` is the accumulated
        intersection divided by the accumulated union, and ``prec`` maps each
        threshold in ``np.arange(0.5, 1, 0.05)`` to its ``AverageMeter``.
    """
    print('begin test')

    batch_time = AverageMeter()
    miou_seg = AverageMeter()

    thresholds = np.arange(0.5, 1, 0.05)
    prec = {thresh: AverageMeter() for thresh in thresholds}

    model.eval()
    end = time.time()

    total_intersection = 0.0
    total_union = 0.0

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, ratio, dw, dh,
                    im_id, phrase, draw_img) in enumerate(val_loader):
        # Only the model inputs go to the GPU; the ground-truth mask is
        # consumed as a NumPy array, so it is never transferred.
        imgs = imgs.cuda(0)
        word_id = word_id.cuda(0)
        word_mask = word_mask.cuda(0)

        with torch.no_grad():
            mask_out = model(imgs, word_id, word_mask)
            mask_out = mask_out.sigmoid()

        # Region of the network-sized prediction to crop before comparing
        # with the ground truth. NOTE(review): dw/dh/ratio are presumably the
        # padding offsets and scale applied by the data loader - confirm.
        ih = seg_map.shape[-2]
        iw = seg_map.shape[-1]
        nh = int(ih * ratio)
        nw = int(iw * ratio)
        top, bottom = int(dh[0]), nh + int(dh[0])
        left, right = int(dw[0]), nw + int(dw[0])
        new_shape = (iw, ih)  # cv2.resize expects (width, height)

        seg_map_np = seg_map[0, :, :, :].cpu().numpy().transpose(1, 2, 0)
        seg_map_np = cv2.resize(seg_map_np, new_shape, interpolation=cv2.INTER_CUBIC)

        # Bring the prediction to the network input size, crop the region
        # computed above, then rescale to the ground-truth size.
        mask_out = mask_out[0].detach().cpu().numpy().transpose(1, 2, 0)
        mask_out = cv2.resize(mask_out, (args.size, args.size))
        mask_out_np = mask_out[top:bottom, left:right]
        mask_out_np = cv2.resize(mask_out_np, new_shape)

        seg_iou, seg_prec, inter_sum, union_sum = cal_seg_iou2(
            seg_map_np, mask_out_np, args.seg_thresh)

        miou_seg.update(seg_iou, imgs.size(0))
        total_intersection += inter_sum
        total_union += union_sum

        for thresh in thresholds:
            prec[thresh].update(seg_prec[thresh], imgs.size(0))

        # Wall-clock time per iteration, including data loading.
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % 1000 == 0:
            print_str = (
                '[{0}/{1}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'seg_iu {seg.val:.4f} ({seg.avg:.4f})\t'
            ).format(batch_idx, len(val_loader), batch_time=batch_time, seg=miou_seg)
            print(print_str)
            logger.info(print_str)

    # The epsilon keeps the ratio defined when no sample was evaluated.
    overall_iou = (total_intersection + 1e-10) / (total_union + 1e-10)

    print("Mean IoU:", miou_seg.avg)
    print("Overall IoU:", overall_iou)
    logger.info("Mean IoU: %.4f" % miou_seg.avg)
    logger.info("Overall IoU: %.4f" % overall_iou)

    for thresh in thresholds:
        print("prec@%f: %f" % (thresh, float(prec[thresh].avg)))
        logger.info("prec@%f:%f" % (thresh, float(prec[thresh].avg)))

    return miou_seg.avg, overall_iou, prec
|
|
|
|
|
|