|
|
import os |
|
|
import time |
|
|
from tqdm import tqdm |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.cuda.amp as amp |
|
|
import torch.distributed as dist |
|
|
import torch.nn.functional as F |
|
|
import wandb |
|
|
from PIL import Image |
|
|
from loguru import logger |
|
|
from utils.misc import (AverageMeter, ProgressMeter, concat_all_gather, concat_all_gather_varsize, trainMetricGPU) |
|
|
|
|
|
|
|
|
def train(train_loader, model, optimizer, scheduler, scaler, epoch, args):
    """Run one epoch of mixed-precision, distributed training.

    Loss / IoU / Prec@50 are all-reduced (averaged) across ranks each step;
    rank 0 logs the running meters to wandb every ``args.print_freq`` steps.

    Args:
        train_loader: iterable yielding (image, text, target, l_mask, params)
            batches; ``params`` is a dict carrying at least 'hardpos_emb',
            'source_type', 'sent' and 'hardpos'.
        model: model whose training-mode forward returns (pred, target, loss).
        optimizer: optimizer; ``param_groups[0]["lr"]`` is the logged LR.
        scheduler: accepted for interface compatibility; not stepped here
            (presumably stepped by the caller once per epoch — confirm).
        scaler: torch.cuda.amp.GradScaler driving the AMP update.
        epoch: current epoch index, used for display and the wandb step.
        args: namespace providing at least ``epochs`` and ``print_freq``.
    """
    batch_time = AverageMeter('Batch', ':2.2f')
    data_time = AverageMeter('Data', ':2.2f')
    lr = AverageMeter('Lr', ':1.6f')
    loss_meter = AverageMeter('Loss', ':2.4f')
    iou_meter = AverageMeter('IoU', ':2.2f')
    pr_meter = AverageMeter('Prec@50', ':2.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, lr, loss_meter, iou_meter, pr_meter],
        prefix="Training: Epoch=[{}/{}] ".format(epoch, args.epochs))

    model.train()
    time.sleep(2)  # brief pause so dataloader workers warm up before timing
    end = time.time()

    for i, (image, text, target, l_mask, params) in enumerate(train_loader):
        data_time.update(time.time() - end)

        # Sync all ranks before the step.
        # FIX: was a bare `except:`, which also swallows KeyboardInterrupt /
        # SystemExit; catch Exception instead.
        # NOTE(review): `continue` on a failed barrier makes this rank skip a
        # batch the other ranks process, which can desync later collectives —
        # confirm this is the intended recovery strategy.
        try:
            dist.barrier()
        except Exception:
            logger.error(f"Barrier failed at iteration {i}, rank {dist.get_rank()}")
            continue

        image = image.cuda(non_blocking=True)
        text = text.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        l_mask = l_mask.cuda(non_blocking=True)
        hp_emb = params['hardpos_emb'].cuda(non_blocking=True)
        source_type = params['source_type']

        # Raw text kept around for potential inspection; not used below.
        orig_sent = params['sent']
        orig_hardpos = params['hardpos']

        # Drop the singleton dim added by the dataset: (B, 1, L) -> (B, L).
        text = text.squeeze(1)
        l_mask = l_mask.squeeze(1)

        # Forward under autocast; the model also returns the (possibly
        # resized) target it computed the loss against.
        with amp.autocast():
            pred, target, loss = \
                model(image, text, l_mask, mask=target, hp_bert_embs=hp_emb, source_type=source_type)
        dist.barrier()

        # Average metrics across ranks. NOTE(review): all_reduce on
        # loss.detach() mutates storage shared with `loss` in place, so the
        # summed value is what gets divided and backpropagated — confirm the
        # resulting gradient scale is intended alongside DDP's own averaging.
        iou, pr5 = trainMetricGPU(pred, target, 0.35)
        dist.all_reduce(loss.detach())
        dist.all_reduce(iou)
        dist.all_reduce(pr5)
        loss = loss / dist.get_world_size()
        iou = iou / dist.get_world_size()
        pr5 = pr5 / dist.get_world_size()

        # Free forward-pass tensors early to lower peak GPU memory.
        del pred, target, text, l_mask, hp_emb

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_meter.update(loss.item(), image.size(0))
        iou_meter.update(iou.item(), image.size(0))
        pr_meter.update(pr5.item(), image.size(0))
        lr.update(optimizer.param_groups[0]["lr"])
        batch_time.update(time.time() - end)
        end = time.time()

        if (i + 1) % args.print_freq == 0:
            progress.display(i + 1)
            if dist.get_rank() in [-1, 0]:
                wandb.log(
                    {
                        "time/batch": batch_time.val,
                        "time/data": data_time.val,
                        "training/lr": lr.val,
                        "training/loss": loss_meter.val,
                        "training/iou": iou_meter.val,
                        "training/prec@50": pr_meter.val,
                    },
                    step=epoch * len(train_loader) + (i + 1))

        # Periodically release cached allocator blocks.
        if i % 10 == 0:
            torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
@torch.no_grad()
def validate(val_loader, model, epoch, args):
    """Distributed evaluation pass.

    Computes per-sample IoU (mean IoU), overall IoU (summed intersections /
    summed unions) and Pr@{50..90}; samples whose source_type is 'zero' are
    scored as accuracy (prediction must be entirely empty) instead of IoU.
    Per-rank results are gathered with concat_all_gather_varsize before the
    final reduction; rank 0 logs a summary line.

    Args:
        val_loader: iterable yielding (imgs, text, masks, l_mask, source_type).
        model: model whose eval-mode forward returns (preds, maps).
        epoch: epoch index, only used in the log header.
        args: namespace providing at least ``epochs``.

    Returns:
        (mIoU, oIoU, prec_dict, mean_acc) — mean_acc is 0 when no
        'zero'-type samples were seen.
    """
    iou_list = []
    I_sum = 0
    U_sum = 0
    mean_acc = []

    model.eval()
    time.sleep(2)  # let dataloader workers spin up

    for idx, (imgs, text, masks, l_mask, source_type) in enumerate(val_loader):
        imgs = imgs.cuda(non_blocking=True)
        text = text.cuda(non_blocking=True)
        l_mask = l_mask.cuda(non_blocking=True)

        # Drop the singleton dim added by the dataset: (B, 1, L) -> (B, L).
        text = text.squeeze(1)
        l_mask = l_mask.squeeze(1)

        with amp.autocast():
            preds, maps = model(imgs, text, l_mask)
            preds = torch.sigmoid(preds)

        # Score each sample in the batch individually on CPU.
        for pred, mask, stype in zip(preds, masks, source_type):
            pred = pred.cpu().numpy()
            mask = mask.cpu().numpy()
            pred = np.array(pred > 0.5)

            if stype == 'zero':
                # No-target sample: correct only if nothing was predicted.
                incorrect_num = np.sum(pred)
                acc = 1 if incorrect_num == 0 else 0
                mean_acc.append(acc)
            else:
                inter_sum = np.sum(np.logical_and(pred, mask))
                union_sum = np.sum(np.logical_or(pred, mask))
                iou = inter_sum / (union_sum + 1e-6)
                iou_list.append(iou)
                I_sum += inter_sum
                U_sum += union_sum

    # FIX: removed a stray trailing line-continuation backslash on this
    # statement that made the module fail to parse.
    # NOTE(review): `imgs` comes from the last loop iteration — this assumes
    # val_loader is non-empty on every rank.
    iou_list = torch.tensor(iou_list, device=imgs.device)
    I_sum = torch.tensor([I_sum], device=imgs.device)
    U_sum = torch.tensor([U_sum], device=imgs.device)

    # Ranks may hold different sample counts, hence the var-size gather.
    gathered_iou = concat_all_gather_varsize(iou_list)
    gathered_I = concat_all_gather_varsize(I_sum)
    gathered_U = concat_all_gather_varsize(U_sum)

    gathered_I_sum = gathered_I.sum().item()
    gathered_U_sum = gathered_U.sum().item()

    iou = gathered_iou.mean().item()
    oIoU = gathered_I_sum / (gathered_U_sum + 1e-6)

    torch.cuda.empty_cache()

    # Precision@threshold for thresholds 0.5 .. 0.9.
    prec_list = []
    for thres in torch.arange(0.5, 1.0, 0.1):
        tmp = (gathered_iou > thres).float().mean()
        prec_list.append(tmp)

    prec = {}
    temp = ' '
    for i, thres in enumerate(range(5, 10)):
        key = 'Pr@{}'.format(thres * 10)
        value = prec_list[i].item()
        prec[key] = value
        temp += "{}: {:.2f} ".format(key, 100. * value)

    dist.barrier()

    # FIX: collapse mean_acc to a scalar on every rank (previously only rank 0
    # converted it, so non-zero ranks returned the raw list).
    if mean_acc:
        mean_acc = np.mean(mean_acc)
    else:
        mean_acc = 0

    if dist.get_rank() == 0:
        head = 'Evaluation: Epoch=[{}/{}] mIoU={:.2f} oIoU={:.2f}'.format(
            epoch, args.epochs, 100. * iou, 100.*(oIoU))
        if mean_acc:
            head += ' Acc={:.2f}'.format(100. * mean_acc)
        logger.info(head + temp)

    return iou, oIoU, prec, mean_acc
|
|
|
|
|
|
|
|
@torch.no_grad()
def inference(test_loader, model, args):
    """Single-process test-set inference with per-sentence evaluation.

    For each image, every referring sentence is run through the model; the
    sigmoid prediction is upsampled to the original image size, thresholded
    at 0.35, and scored against the ground-truth mask. 'zero'-type samples
    (no target) are scored as accuracy instead of IoU.

    Args:
        test_loader: iterable yielding
            (ori_img, img, texts, mask, l_masks, seg_id, sents, source_type),
            presumably batch size 1 (one source_type per batch) — confirm.
        model: model whose eval-mode forward returns (pred, maps).
        args: unused here; kept for interface compatibility.

    Returns:
        (mIoU, oIoU, prec_dict, mean_acc) — mean_acc is 0 when no
        'zero'-type samples were seen.
    """
    iou_list = []
    I_sum = 0
    U_sum = 0
    mean_acc = []

    tbar = tqdm(test_loader, desc='Inference:', ncols=100)
    model.eval()
    time.sleep(2)  # let dataloader workers spin up

    for ori_img, img, texts, mask, l_masks, seg_id, sents, source_type in tbar:
        img = img.cuda(non_blocking=True)
        mask = mask.cpu().numpy()

        # Evaluate each sentence for this image independently.
        for text, l_mask, sent in zip(texts, l_masks, sents):
            text = text.cuda(non_blocking=True)
            l_mask = l_mask.cuda(non_blocking=True)

            # Drop the singleton dim: (B, 1, L) -> (B, L).
            text = text.squeeze(1)
            l_mask = l_mask.squeeze(1)

            with amp.autocast():
                pred, maps = model(img, text, l_mask)
            pred = torch.sigmoid(pred)
            # FIX: compare against ori_img.shape[1:-1] — the (H, W) actually
            # used as the interpolation target below. The old test used
            # ori_img.shape[:-1], which includes the batch dim and can never
            # equal a 2-element shape, so it resampled unconditionally.
            if pred.shape[-2:] != ori_img.shape[1:-1]:
                pred = F.interpolate(pred, size=ori_img.shape[1:-1], mode='bicubic', align_corners=True)

            pred = pred.cpu().numpy()
            pred = np.array(pred > 0.35)

            if source_type == 'zero':
                # No-target sample: correct only if nothing was predicted.
                incorrect_num = np.sum(pred)
                acc = 1 if incorrect_num == 0 else 0
                mean_acc.append(acc)
            else:
                inter_sum = np.sum(np.logical_and(pred, mask))
                union_sum = np.sum(np.logical_or(pred, mask))

                if union_sum == 0:
                    iou = 0.0
                else:
                    iou = inter_sum * 1.0 / union_sum

                iou_list.append(iou)
                I_sum += inter_sum
                U_sum += union_sum

    logger.info('=> Metric Calculation <=')

    iou_list = np.stack(iou_list)
    iou_list = torch.from_numpy(iou_list).to(img.device)

    # FIX: guard against U_sum == 0 (e.g. all samples 'zero'-type) with the
    # same epsilon validate() uses; previously this divided by zero.
    overall_IoU = I_sum / (U_sum + 1e-6)

    # Precision@threshold for thresholds 0.5 .. 0.9.
    prec_list = []
    for thres in torch.arange(0.5, 1.0, 0.1):
        tmp = (iou_list > thres).float().mean()
        prec_list.append(tmp)

    iou = iou_list.mean()

    prec = {}
    for i, thres in enumerate(range(5, 10)):
        key = 'Pr@{}'.format(thres*10)
        value = prec_list[i].item()
        prec[key] = value

    # Reuse the guarded ratio instead of recomputing I_sum/U_sum.
    logger.info('oIoU={:.2f}'.format(100. * overall_IoU))
    logger.info('mIoU={:.2f}'.format(100.*iou.item()))

    if mean_acc:
        mean_acc = np.mean(mean_acc)
        logger.info('Acc={:.2f}'.format(100. * mean_acc))
    else:
        mean_acc = 0

    for k, v in prec.items():
        logger.info('{}: {:.2f}.'.format(k, 100.*v))

    return iou.item(), overall_IoU, prec, mean_acc
|
|
|
|
|
|
|
|
|