import os
import time

from tqdm import tqdm
import cv2
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.nn.functional as F
import wandb
from PIL import Image
from loguru import logger

from utils.misc import (AverageMeter, ProgressMeter, concat_all_gather,
                        concat_all_gather_varsize, trainMetricGPU)


def train(train_loader, model, optimizer, scheduler, scaler, epoch, args):
    """Run one distributed training epoch with AMP.

    Args:
        train_loader: yields (image, text, target, l_mask, params) batches;
            params carries 'hardpos_emb', 'source_type', 'sent', 'hardpos'.
        model: DDP-wrapped model returning (pred, target, loss) in train mode.
        optimizer: optimizer whose first param group's lr is logged.
        scheduler: accepted for interface compatibility; not stepped here —
            presumably stepped per-epoch by the caller (TODO confirm).
        scaler: torch.cuda.amp.GradScaler for mixed-precision backprop.
        epoch: current epoch index (used for logging only).
        args: needs .epochs and .print_freq.

    Side effects: logs meters via ProgressMeter, and wandb metrics on rank 0.
    """
    batch_time = AverageMeter('Batch', ':2.2f')
    data_time = AverageMeter('Data', ':2.2f')
    lr = AverageMeter('Lr', ':1.6f')
    loss_meter = AverageMeter('Loss', ':2.4f')
    iou_meter = AverageMeter('IoU', ':2.2f')
    pr_meter = AverageMeter('Prec@50', ':2.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, lr, loss_meter, iou_meter, pr_meter],
        prefix="Training: Epoch=[{}/{}] ".format(epoch, args.epochs))

    model.train()
    time.sleep(2)
    end = time.time()

    for i, (image, text, target, l_mask, params) in enumerate(train_loader):
        data_time.update(time.time() - end)  # data-loading time
        # Keep ranks in lockstep; skip the batch (on this rank) if the
        # collective fails rather than aborting the whole epoch.
        try:
            dist.barrier()
        except Exception:
            logger.error(
                f"Barrier failed at iteration {i}, rank {dist.get_rank()}")
            continue
        image = image.cuda(non_blocking=True)
        text = text.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        l_mask = l_mask.cuda(non_blocking=True)
        hp_emb = params['hardpos_emb'].cuda(non_blocking=True)
        source_type = params['source_type']
        # Kept for sanity checking / debugging only.
        orig_sent = params['sent']
        orig_hardpos = params['hardpos']

        # Drop the singleton token dimension: (B, 1, L) -> (B, L).
        text = text.squeeze(1)
        l_mask = l_mask.squeeze(1)

        # Forward pass under autocast; the model computes the loss itself.
        with amp.autocast():
            pred, target, loss = \
                model(image, text, l_mask, mask=target,
                      hp_bert_embs=hp_emb, source_type=source_type)
        dist.barrier()

        # Metrics, averaged over all ranks (all_reduce sums in place).
        iou, pr5 = trainMetricGPU(pred, target, 0.35)
        dist.all_reduce(loss.detach())
        dist.all_reduce(iou)
        dist.all_reduce(pr5)
        loss = loss / dist.get_world_size()
        iou = iou / dist.get_world_size()
        pr5 = pr5 / dist.get_world_size()

        # Free activations before backprop to lower peak memory.
        del pred, target, text, l_mask, hp_emb

        # Scaled backward + optimizer step.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_meter.update(loss.item(), image.size(0))
        iou_meter.update(iou.item(), image.size(0))
        pr_meter.update(pr5.item(), image.size(0))
        lr.update(optimizer.param_groups[0]["lr"])
        batch_time.update(time.time() - end)
        end = time.time()

        if (i + 1) % args.print_freq == 0:
            progress.display(i + 1)
            if dist.get_rank() in [-1, 0]:
                wandb.log(
                    {
                        "time/batch": batch_time.val,
                        "time/data": data_time.val,
                        "training/lr": lr.val,
                        "training/loss": loss_meter.val,
                        "training/iou": iou_meter.val,
                        "training/prec@50": pr_meter.val,
                    },
                    step=epoch * len(train_loader) + (i + 1))

        # Flush the CUDA caching allocator every 10 steps.
        if i % 10 == 0:
            torch.cuda.empty_cache()


@torch.no_grad()
def validate(val_loader, model, epoch, args):
    """Distributed validation: per-sample mIoU, overall IoU, Pr@50..90.

    Samples whose source_type is 'zero' (no target object) are scored as a
    binary accuracy (correct iff the prediction is entirely empty) instead
    of IoU.

    Returns:
        (mIoU, oIoU, prec_dict, mean_acc). NOTE(review): mean_acc is reduced
        to a scalar only on rank 0; other ranks return the raw list — confirm
        callers only consume it on rank 0.
    """
    iou_list = []
    I_sum = 0
    U_sum = 0
    mean_acc = []
    model.eval()
    time.sleep(2)

    for idx, (imgs, text, masks, l_mask, source_type) in enumerate(val_loader):
        imgs = imgs.cuda(non_blocking=True)
        text = text.cuda(non_blocking=True)
        l_mask = l_mask.cuda(non_blocking=True)
        text = text.squeeze(1)
        l_mask = l_mask.squeeze(1)

        # Inference under autocast; predictions mapped to [0, 1].
        with amp.autocast():
            preds, maps = model(imgs, text, l_mask)
            preds = torch.sigmoid(preds)

        # Score each sample in the batch on CPU.
        for pred, mask, stype in zip(preds, masks, source_type):
            pred = pred.cpu().numpy()
            mask = mask.cpu().numpy()
            pred = np.array(pred > 0.5)
            if stype == 'zero':
                # No-target sample: correct iff nothing was predicted.
                incorrect_num = np.sum(pred)
                acc = 1 if incorrect_num == 0 else 0
                mean_acc.append(acc)
            else:
                inter_sum = np.sum(np.logical_and(pred, mask))
                union_sum = np.sum(np.logical_or(pred, mask))
                iou = inter_sum / (union_sum + 1e-6)
                iou_list.append(iou)
                I_sum += inter_sum
                U_sum += union_sum

    # Gather variable-length per-sample IoUs and the I/U totals from all
    # ranks. NOTE(review): assumes the loader is non-empty (imgs must exist).
    iou_list = torch.tensor(iou_list, device=imgs.device)
    I_sum = torch.tensor([I_sum], device=imgs.device)
    U_sum = torch.tensor([U_sum], device=imgs.device)
    gathered_iou = concat_all_gather_varsize(iou_list)
    gathered_I = concat_all_gather_varsize(I_sum)
    gathered_U = concat_all_gather_varsize(U_sum)
    gathered_I_sum = gathered_I.sum().item()
    gathered_U_sum = gathered_U.sum().item()
    iou = gathered_iou.mean().item()
    oIoU = gathered_I_sum / (gathered_U_sum + 1e-6)
    torch.cuda.empty_cache()

    # Precision@{50,60,70,80,90}: fraction of samples above each threshold.
    prec_list = []
    for thres in torch.arange(0.5, 1.0, 0.1):
        tmp = (gathered_iou > thres).float().mean()
        prec_list.append(tmp)
    prec = {}
    temp = ' '
    for i, thres in enumerate(range(5, 10)):
        key = 'Pr@{}'.format(thres * 10)
        value = prec_list[i].item()
        prec[key] = value
        temp += "{}: {:.2f} ".format(key, 100. * value)

    dist.barrier()
    if dist.get_rank() == 0:
        head = 'Evaluation: Epoch=[{}/{}] mIoU={:.2f} oIoU={:.2f}'.format(
            epoch, args.epochs, 100. * iou, 100. * (oIoU))
        if mean_acc:
            mean_acc = np.mean(mean_acc)
            head += ' Acc={:.2f}'.format(100. * mean_acc)
        else:
            mean_acc = 0
        logger.info(head + temp)
    return iou, oIoU, prec, mean_acc


@torch.no_grad()
def inference(test_loader, model, args):
    """Single-process test-set inference over all sentences per referral.

    For each (image, sentence) pair, thresholds the sigmoid prediction at
    0.35 and accumulates per-sample IoU plus overall intersection/union.
    'zero' source-type samples are scored as empty-prediction accuracy.

    Returns:
        (mIoU, oIoU, prec_dict, mean_acc).
    """
    iou_list = []
    I_sum = 0
    U_sum = 0
    mean_acc = []
    tbar = tqdm(test_loader, desc='Inference:', ncols=100)
    model.eval()
    time.sleep(2)

    for ori_img, img, texts, mask, l_masks, seg_id, sents, source_type in tbar:
        img = img.cuda(non_blocking=True)
        mask = mask.cpu().numpy()
        # Evaluate every sentence annotated for this referral.
        for text, l_mask, sent in zip(texts, l_masks, sents):
            text = text.cuda(non_blocking=True)
            l_mask = l_mask.cuda(non_blocking=True)
            text = text.squeeze(1)
            l_mask = l_mask.squeeze(1)
            with amp.autocast():
                pred, maps = model(img, text, l_mask)
            pred = torch.sigmoid(pred)
            # Upsample to the original image resolution when needed.
            # BUGFIX: previously compared against ori_img.shape[:-1], which
            # includes the batch dim and could never equal pred.shape[-2:];
            # compare against the actual (H, W) interpolation target.
            # Assumes ori_img is laid out (B, H, W, C) — TODO confirm.
            if pred.shape[-2:] != ori_img.shape[1:-1]:
                pred = F.interpolate(pred, size=ori_img.shape[1:-1],
                                     mode='bicubic', align_corners=True)

            pred = pred.cpu().numpy()
            pred = np.array(pred > 0.35)
            if source_type == 'zero':
                # No-target sample: correct iff nothing was predicted.
                incorrect_num = np.sum(pred)
                acc = 1 if incorrect_num == 0 else 0
                mean_acc.append(acc)
            else:
                inter_sum = np.sum(np.logical_and(pred, mask))  # |P ∩ G|
                union_sum = np.sum(np.logical_or(pred, mask))   # |P ∪ G|
                if union_sum == 0:
                    iou = 0.0
                else:
                    iou = inter_sum * 1.0 / union_sum
                iou_list.append(iou)
                I_sum += inter_sum
                U_sum += union_sum

    logger.info('=> Metric Calculation <=')
    iou_list = np.stack(iou_list)
    iou_list = torch.from_numpy(iou_list).to(img.device)
    # Guard against U_sum == 0 (e.g. every sample was source-type 'zero').
    overall_IoU = I_sum / U_sum if U_sum > 0 else 0.0

    prec_list = []
    for thres in torch.arange(0.5, 1.0, 0.1):
        tmp = (iou_list > thres).float().mean()
        prec_list.append(tmp)
    iou = iou_list.mean()
    prec = {}
    for i, thres in enumerate(range(5, 10)):
        key = 'Pr@{}'.format(thres * 10)
        value = prec_list[i].item()
        prec[key] = value

    logger.info('oIoU={:.2f}'.format(100. * overall_IoU))
    logger.info('mIoU={:.2f}'.format(100. * iou.item()))
    if mean_acc:
        # Accuracy over the 'zero' (no-target) cases.
        mean_acc = np.mean(mean_acc)
        logger.info('Acc={:.2f}'.format(100. * mean_acc))
    else:
        mean_acc = 0
    for k, v in prec.items():
        logger.info('{}: {:.2f}.'.format(k, 100. * v))
    return iou.item(), overall_IoU, prec, mean_acc