# MRaCL / CGFormer / engine / engine_refzom_sbert.py
# Uploaded by dianecy using huggingface_hub (commit ea1014e, verified).
import os
import time
from tqdm import tqdm
import cv2
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.nn.functional as F
import wandb
from PIL import Image
from loguru import logger
from utils.misc import (AverageMeter, ProgressMeter, concat_all_gather, concat_all_gather_varsize, trainMetricGPU)
def train(train_loader, model, optimizer, scheduler, scaler, epoch, args):
    """Run one mixed-precision training epoch and log rank-averaged metrics.

    Args:
        train_loader: iterable of ``(image, text, target, l_mask, params)``
            batches; ``params`` must contain ``'hardpos_emb'`` (a tensor) and
            ``'source_type'``.
        model: segmentation model called as
            ``model(image, text, l_mask, mask=..., hp_bert_embs=..., source_type=...)``
            and returning ``(pred, target, loss)``. Presumably DDP-wrapped —
            TODO confirm against the caller.
        optimizer: optimizer stepped once per iteration through ``scaler``.
        scheduler: accepted for interface compatibility; not stepped here.
        scaler: AMP grad scaler (``scale``/``step``/``update`` are used).
        epoch: current epoch index; used in the progress prefix and wandb step.
        args: config namespace; ``epochs`` and ``print_freq`` are read.
    """
    batch_time = AverageMeter('Batch', ':2.2f')
    data_time = AverageMeter('Data', ':2.2f')
    lr = AverageMeter('Lr', ':1.6f')
    loss_meter = AverageMeter('Loss', ':2.4f')
    iou_meter = AverageMeter('IoU', ':2.2f')
    pr_meter = AverageMeter('Prec@50', ':2.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, lr, loss_meter, iou_meter, pr_meter],
        prefix="Training: Epoch=[{}/{}] ".format(epoch, args.epochs))

    model.train()
    time.sleep(2)
    world_size = dist.get_world_size()
    end = time.time()

    for i, (image, text, target, l_mask, params) in enumerate(train_loader):
        data_time.update(time.time() - end)

        try:
            dist.barrier()
        except Exception:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
            # still propagate. NOTE(review): skipping the step on one rank can
            # desynchronize later collectives; consider re-raising instead.
            logger.exception(f"Barrier failed at iteration {i}, rank {dist.get_rank()}")
            continue

        # data
        image = image.cuda(non_blocking=True)
        text = text.cuda(non_blocking=True).squeeze(1)
        target = target.cuda(non_blocking=True)
        l_mask = l_mask.cuda(non_blocking=True).squeeze(1)
        hp_emb = params['hardpos_emb'].cuda(non_blocking=True)
        source_type = params['source_type']

        # forward
        with amp.autocast():
            pred, target, loss = model(
                image, text, l_mask,
                mask=target, hp_bert_embs=hp_emb, source_type=source_type)
        dist.barrier()

        # backward on the *local* loss. Reducing/dividing the loss before
        # backward would scale every gradient by 1/world_size (the division is
        # part of the autograd graph); the cross-rank average is for logging.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # metrics, averaged over ranks for reporting only
        iou, pr5 = trainMetricGPU(pred, target, 0.35)
        loss_avg = loss.detach().clone()
        dist.all_reduce(loss_avg)
        dist.all_reduce(iou)
        dist.all_reduce(pr5)
        loss_avg = loss_avg / world_size
        iou = iou / world_size
        pr5 = pr5 / world_size

        batch_size = image.size(0)
        del pred, target, text, l_mask, hp_emb, loss

        loss_meter.update(loss_avg.item(), batch_size)
        iou_meter.update(iou.item(), batch_size)
        pr_meter.update(pr5.item(), batch_size)
        lr.update(optimizer.param_groups[0]["lr"])
        batch_time.update(time.time() - end)
        end = time.time()

        if (i + 1) % args.print_freq == 0:
            progress.display(i + 1)
            if dist.get_rank() in [-1, 0]:
                wandb.log(
                    {
                        "time/batch": batch_time.val,
                        "time/data": data_time.val,
                        "training/lr": lr.val,
                        "training/loss": loss_meter.val,
                        "training/iou": iou_meter.val,
                        "training/prec@50": pr_meter.val,
                    },
                    step=epoch * len(train_loader) + (i + 1))

        # Return cached allocator blocks to the driver every 10 iterations.
        if i % 10 == 0:
            torch.cuda.empty_cache()
@torch.no_grad()
def validate(val_loader, model, epoch, args):
    """Evaluate ``model`` on ``val_loader`` and aggregate metrics over all ranks.

    Samples whose ``source_type`` is ``'zero'`` are scored by accuracy (correct
    iff the thresholded prediction is empty). All other samples contribute:
    per-sample IoU (mIoU), cumulative intersection/union (oIoU), and the
    fraction of samples with IoU above 0.5..0.9 (Pr@50..Pr@90).

    Args:
        val_loader: iterable of ``(imgs, text, masks, l_mask, source_type)``
            batches, where ``source_type`` is a per-sample sequence of strings.
        model: called as ``model(imgs, text, l_mask)`` returning ``(preds, maps)``.
        epoch: current epoch index (logging only).
        args: config namespace; ``epochs`` is read for logging.

    Returns:
        ``(mIoU, oIoU, prec, mean_acc)``. ``prec`` maps ``'Pr@50'``..``'Pr@90'``
        to fractions in [0, 1]. ``mean_acc`` is the zero-target accuracy, or 0
        when there are no such samples. Every value is aggregated across ranks
        and is identical on every rank.
    """
    iou_list = []
    I_sum = 0
    U_sum = 0
    zero_correct = 0
    zero_total = 0
    # Same device that ``.cuda()`` below moves the inputs to; defined up front
    # so an empty loader does not leave it undefined.
    device = torch.device('cuda', torch.cuda.current_device())

    model.eval()
    time.sleep(2)
    for imgs, text, masks, l_mask, source_type in val_loader:
        imgs = imgs.cuda(non_blocking=True)
        text = text.cuda(non_blocking=True).squeeze(1)
        l_mask = l_mask.cuda(non_blocking=True).squeeze(1)

        with amp.autocast():
            preds, _ = model(imgs, text, l_mask)
        preds = torch.sigmoid(preds)

        for pred, mask, stype in zip(preds, masks, source_type):
            # NOTE(review): validation thresholds at 0.5 while train/inference
            # use 0.35; kept as-is, confirm this is intended.
            pred = pred.cpu().numpy() > 0.5
            mask = mask.cpu().numpy()
            if stype == 'zero':
                zero_total += 1
                zero_correct += int(not pred.any())
            else:
                inter_sum = np.sum(np.logical_and(pred, mask))
                union_sum = np.sum(np.logical_or(pred, mask))
                iou_list.append(inter_sum / (union_sum + 1e-6))
                I_sum += inter_sum
                U_sum += union_sum

    # Explicit dtypes keep every rank's tensors consistent for the gather,
    # even when a rank saw no non-zero samples (an empty list would default
    # to float32).
    iou_tensor = torch.tensor(iou_list, dtype=torch.float64, device=device)
    I_tensor = torch.tensor([I_sum], dtype=torch.int64, device=device)
    U_tensor = torch.tensor([U_sum], dtype=torch.int64, device=device)
    gathered_iou = concat_all_gather_varsize(iou_tensor)
    gathered_I = concat_all_gather_varsize(I_tensor)
    gathered_U = concat_all_gather_varsize(U_tensor)

    gathered_I_sum = gathered_I.sum().item()
    gathered_U_sum = gathered_U.sum().item()
    oIoU = gathered_I_sum / (gathered_U_sum + 1e-6)

    # Zero-target accuracy must be aggregated too; previously it reflected
    # only rank 0's shard.
    zero_stats = torch.tensor([zero_correct, zero_total], dtype=torch.int64, device=device)
    dist.all_reduce(zero_stats)
    zero_correct, zero_total = zero_stats.tolist()
    mean_acc = zero_correct / zero_total if zero_total > 0 else 0

    torch.cuda.empty_cache()

    thresholds = torch.arange(0.5, 1.0, 0.1)
    has_iou = gathered_iou.numel() > 0
    iou = gathered_iou.mean().item() if has_iou else 0.0
    prec = {}
    temp = ' '
    for pct, thres in zip(range(50, 100, 10), thresholds):
        key = 'Pr@{}'.format(pct)
        value = (gathered_iou > thres).float().mean().item() if has_iou else 0.0
        prec[key] = value
        temp += "{}: {:.2f} ".format(key, 100. * value)

    dist.barrier()
    if dist.get_rank() == 0:
        head = 'Evaluation: Epoch=[{}/{}] mIoU={:.2f} oIoU={:.2f}'.format(
            epoch, args.epochs, 100. * iou, 100. * oIoU)
        if zero_total > 0:
            head += ' Acc={:.2f}'.format(100. * mean_acc)
        logger.info(head + temp)
    return iou, oIoU, prec, mean_acc
@torch.no_grad()
def inference(test_loader, model, args):
    """Run single-process inference over every referring sentence of each image.

    For each batch, every sentence is evaluated against the same image. The
    prediction is resized to the original image's spatial size and then
    thresholded at 0.35. Samples whose ``source_type`` is ``'zero'`` are
    scored by accuracy (correct iff nothing is predicted). All other samples
    contribute to mIoU, oIoU and Pr@50..Pr@90.

    Args:
        test_loader: iterable of
            ``(ori_img, img, texts, mask, l_masks, seg_id, sents, source_type)``.
            Assumes batch size 1 and ``ori_img`` laid out channels-last (…, H, W, C),
            as the original resize call implied — TODO confirm.
        model: called as ``model(img, text, l_mask)`` returning ``(pred, maps)``.
        args: unused here; kept for interface compatibility.

    Returns:
        ``(mIoU, oIoU, prec, mean_acc)``. ``prec`` maps ``'Pr@50'``..``'Pr@90'``
        to fractions. mIoU/oIoU/prec are 0 when no non-zero samples exist.
        ``mean_acc`` is 0 when no zero-target samples exist.
    """
    iou_list = []
    I_sum = 0
    U_sum = 0
    mean_acc = []
    tbar = tqdm(test_loader, desc='Inference:', ncols=100)
    model.eval()
    time.sleep(2)
    for ori_img, img, texts, mask, l_masks, seg_id, sents, source_type in tbar:
        img = img.cuda(non_blocking=True)
        mask = mask.cpu().numpy()
        # Default collation turns the per-sample string into a list/tuple of
        # strings, so comparing the raw value to 'zero' was always False.
        stype = source_type[0] if isinstance(source_type, (list, tuple)) else source_type
        # Spatial (H, W) of the original image; works whether or not ori_img
        # carries a leading batch dimension.
        target_hw = tuple(ori_img.shape[-3:-1])

        # one forward pass per referring sentence of this image
        for text, l_mask, _ in zip(texts, l_masks, sents):
            text = text.cuda(non_blocking=True).squeeze(1)
            l_mask = l_mask.cuda(non_blocking=True).squeeze(1)
            with amp.autocast():
                pred, _ = model(img, text, l_mask)
            pred = torch.sigmoid(pred)
            if tuple(pred.shape[-2:]) != target_hw:
                pred = F.interpolate(pred, size=target_hw, mode='bicubic', align_corners=True)
            pred = pred.cpu().numpy() > 0.35

            if stype == 'zero':
                mean_acc.append(1 if not pred.any() else 0)
            else:
                inter_sum = np.sum(np.logical_and(pred, mask))
                union_sum = np.sum(np.logical_or(pred, mask))
                iou = 0.0 if union_sum == 0 else inter_sum * 1.0 / union_sum
                iou_list.append(iou)
                I_sum += inter_sum
                U_sum += union_sum

    logger.info('=> Metric Calculation <=')
    overall_IoU = I_sum / U_sum if U_sum > 0 else 0.0
    thresholds = torch.arange(0.5, 1.0, 0.1)
    prec = {}
    if iou_list:
        iou_tensor = torch.from_numpy(np.asarray(iou_list, dtype=np.float64))
        iou = iou_tensor.mean().item()
        for pct, thres in zip(range(50, 100, 10), thresholds):
            prec['Pr@{}'.format(pct)] = (iou_tensor > thres).float().mean().item()
    else:
        iou = 0.0
        for pct in range(50, 100, 10):
            prec['Pr@{}'.format(pct)] = 0.0

    logger.info('oIoU={:.2f}'.format(100. * overall_IoU))
    logger.info('mIoU={:.2f}'.format(100. * iou))
    if mean_acc:
        # accuracy over 'zero' (no-target) samples
        mean_acc = np.mean(mean_acc)
        logger.info('Acc={:.2f}'.format(100. * mean_acc))
    else:
        mean_acc = 0
    for k, v in prec.items():
        logger.info('{}: {:.2f}.'.format(k, 100. * v))
    return iou, overall_IoU, prec, mean_acc