| import datetime |
| import os |
| import time |
|
|
| import torch |
| import torch.utils.data |
| from torch import nn |
|
|
| from functools import reduce |
| import operator |
| from bert.modeling_bert import BertModel |
| import json |
| from lib import segmentation |
| import pdb |
| import transforms |
| from transforms import transform |
| from data.dataset_zom import Refzom_DistributedSampler,Referzom_Dataset |
| from data.dataset_zom_rev import Refzom_DistributedSampler, Referzom_Dataset_HP |
| from data.dataset_rev import ReferDataset_HP |
| import utils |
| import numpy as np |
| from torch.utils.tensorboard import SummaryWriter |
| import gc |
|
|
|
|
|
|
def get_dataset(image_set, transform, args, eval_mode):
    """Build the hard-positive-aware dataset for split ``image_set``.

    ``args.dataset == 'ref-zom'`` selects ``Referzom_Dataset_HP``; any other dataset
    name falls through to ``ReferDataset_HP``. Both are constructed with identical
    keyword arguments and no target transforms.

    Returns:
        ``(dataset, num_classes)`` where ``num_classes`` is always 2.
    """
    dataset_cls = Referzom_Dataset_HP if args.dataset == 'ref-zom' else ReferDataset_HP
    dataset = dataset_cls(
        args,
        split=image_set,
        image_transforms=transform,
        target_transforms=None,
        eval_mode=eval_mode,
    )
    return dataset, 2
|
|
|
|
|
|
def computeIoU(pred_seg, gd_seg):
    """Return the raw intersection and union pixel counts of two masks.

    Any non-zero value counts as foreground. The inputs must be broadcast-compatible.

    Returns:
        ``(I, U)``: the number of pixels set in both masks and the number set in
        either. The IoU itself is ``I / U``, and the caller handles ``U == 0``.
    """
    overlap = np.logical_and(pred_seg, gd_seg)
    coverage = np.logical_or(pred_seg, gd_seg)
    return np.sum(overlap), np.sum(coverage)
|
|
|
|
|
|
def get_transform(args):
    """Build the transform pipeline used for the evaluation split.

    It resizes with ``args.img_size`` passed for both size arguments, converts to a
    tensor, and normalises with the ImageNet per-channel mean/std.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize(args.img_size, args.img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return transforms.Compose(steps)
|
|
|
|
def criterion(input, target):
    """Class-weighted cross-entropy for the 2-class segmentation logits.

    Class 0 is weighted 0.9 and class 1 is weighted 1.1. The weight tensor is created
    on ``input``'s device and with ``input``'s dtype. This keeps the loss usable on
    CPU and on any GPU, and keeps the weight's dtype matched to the logits, instead of
    pinning it to the current CUDA device as float32.

    Args:
        input: logits whose class dimension (dim 1) has size 2.
        target: integer class indices, as accepted by ``F.cross_entropy``.

    Returns:
        Scalar weighted-mean loss.
    """
    weight = torch.tensor([0.9, 1.1], dtype=input.dtype, device=input.device)
    return nn.functional.cross_entropy(input, target, weight=weight)
|
|
|
|
|
|
def return_mask(emb_distance, verb_mask=None):
    """Build positive/negative pair masks for a square similarity matrix.

    Every sample is its own positive (the diagonal). In addition, samples are paired
    as (anchor, hard positive):

    * If ``verb_mask`` is ``None``, or is longer than the matrix, consecutive rows
      ``(0, 1), (2, 3), ...`` are paired. The longer-mask case is the one where the
      caller already filtered the batch with ``verb_mask`` before building the matrix.
    * Otherwise ``verb_mask`` is indexed alongside the rows. Whenever ``verb_mask[i]``
      is set and row ``i + 1`` exists, rows ``i`` and ``i + 1`` are paired and the
      scan skips past both. A flag on the last row is left unpaired instead of
      indexing out of range.

    Args:
        emb_distance: square ``(B, B)`` matrix. Only its shape, dtype and device are used.
        verb_mask: optional per-sample flags (sequence or 1-D tensor).

    Returns:
        ``(positive_mask, negative_mask)``, both shaped like ``emb_distance``, with
        ``negative_mask == 1 - positive_mask``.

    Raises:
        ValueError: if ``emb_distance`` is not a square 2-D matrix.
    """
    if emb_distance.dim() != 2 or emb_distance.shape[0] != emb_distance.shape[1]:
        raise ValueError(
            f"expected a square 2-D matrix, got shape {tuple(emb_distance.shape)}")
    size = emb_distance.shape[0]

    positive_mask = torch.zeros_like(emb_distance)
    positive_mask.fill_diagonal_(1)

    if verb_mask is None or size < len(verb_mask):
        # The matrix only holds paired samples, laid out as consecutive pairs.
        for i in range(0, size - 1, 2):
            positive_mask[i, i + 1] = 1
            positive_mask[i + 1, i] = 1
    else:
        i = 0
        while i < size:
            if verb_mask[i] == 1 and i + 1 < size:
                positive_mask[i, i + 1] = 1
                positive_mask[i + 1, i] = 1
                i += 2
            else:
                i += 1

    negative_mask = torch.ones_like(emb_distance) - positive_mask
    return positive_mask, negative_mask
|
|
|
|
def UniAngularContrastLoss(total_fq, verb_mask, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
    """Contrastive loss with an additive angular margin on positive pairs.

    Feature maps are spatially averaged into one embedding per sample. The embeddings
    are compared by cosine similarity, clamped to +/-0.9999 so ``acos`` stays finite.
    Positive pairs come from ``return_mask``: each sample with itself, plus its paired
    hard positive. For those pairs the similarity is replaced by ``cos(theta + m)``.
    Each row's loss is ``-log(sum_pos exp(s / tau) / sum_all exp(s / tau))``, and the
    result is averaged over rows.

    Args:
        total_fq: 4-D feature tensor ``(N, C, H, W)``.
        verb_mask: selects samples from the batch (a boolean tensor in
            ``train_one_epoch``). With ``verbonly`` only the selected samples are
            embedded, and their count must be even.
        alpha: unused; kept for signature compatibility.
        verbonly: if False, every sample in ``total_fq`` is embedded.
        m: angular margin in degrees.
        tau: softmax temperature.
        args: unused; kept for signature compatibility.

    Returns:
        Scalar loss tensor, or ``0.0`` when no samples are selected.
    """
    _, C, _, _ = total_fq.shape
    if verbonly:
        selected = total_fq[verb_mask]
        emb = selected.mean(dim=(-1, -2)).reshape(selected.shape[0], C)
        assert emb.shape[0] % 2 == 0, f"Embedding count {emb.shape[0]} is not divisible by 2."
    else:
        # Pool both spatial dims so emb is (N, C), as in the verbonly branch. Pooling
        # only the last dim would leave (N, C, H) and break the (N, N) matrix below.
        emb = total_fq.mean(dim=(-1, -2))

    if emb.shape[0] == 0:
        return torch.tensor(0.0, device=total_fq.device)

    # Pairwise cosine similarity via broadcasting: (N, 1, C) vs (1, N, C) -> (N, N).
    sim_matrix = nn.functional.cosine_similarity(emb.unsqueeze(1), emb.unsqueeze(0), dim=-1, eps=1e-6)
    sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)

    positive_mask, _ = return_mask(sim_matrix, verb_mask)
    is_pos = positive_mask.bool()

    margin = m * torch.pi / 180.0  # degrees -> radians
    sim_with_margin = sim_matrix.clone()
    sim_with_margin[is_pos] = torch.cos(torch.acos(sim_matrix[is_pos]) + margin)

    logits = sim_with_margin / tau
    # Compute -log(sum_pos exp / sum_all exp) as a difference of log-sum-exps, so a
    # small tau cannot overflow exp(). The diagonal is always positive, so no masked
    # row is entirely -inf.
    pos_lse = torch.logsumexp(logits.masked_fill(~is_pos, float('-inf')), dim=-1)
    loss = torch.logsumexp(logits, dim=-1) - pos_lse
    return loss.mean()
|
|
|
|
| |
def UniAngularLogitContrastLoss(total_fq, verb_mask, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
    """Contrastive loss on signed angles, with a margin subtracted from positive pairs.

    Feature maps are spatially averaged into one embedding per sample. Pairwise
    cosine similarity, clamped to +/-0.9999, is mapped to
    ``theta = pi/2 - acos(sim)``. This value lies in (-pi/2, pi/2) and grows with
    similarity. Positive pairs come from ``return_mask``: each sample with itself,
    plus its paired hard positive. For those pairs the margin ``m`` is subtracted from
    ``theta``. Each row's loss is
    ``-log(sum_pos exp(theta / tau) / sum_all exp(theta / tau))``, and the result is
    averaged over rows.

    Args:
        total_fq: 4-D feature tensor ``(N, C, H, W)``.
        verb_mask: selects samples from the batch (a boolean tensor in
            ``train_one_epoch``). With ``verbonly`` only the selected samples are
            embedded, and their count must be even.
        alpha: unused; kept for signature compatibility.
        verbonly: if False, every sample in ``total_fq`` is embedded.
        m: angular margin in degrees.
        tau: softmax temperature. It is computed in log-space, so small values do not
            overflow.
        args: unused; kept for signature compatibility.

    Returns:
        Scalar loss tensor. Returns ``0.0`` when no samples are selected, instead of
        the NaN that averaging an empty loss vector would give.
    """
    _, C, _, _ = total_fq.shape
    if verbonly:
        selected = total_fq[verb_mask]
        emb = selected.mean(dim=(-1, -2)).reshape(selected.shape[0], C)
        assert emb.shape[0] % 2 == 0, f"Embedding count {emb.shape[0]} is not divisible by 2."
    else:
        # Pool both spatial dims so emb is (N, C), as in the verbonly branch. Pooling
        # only the last dim would leave (N, C, H) and break the (N, N) matrix below.
        emb = total_fq.mean(dim=(-1, -2))

    if emb.shape[0] == 0:
        return torch.tensor(0.0, device=total_fq.device)

    # Pairwise cosine similarity via broadcasting: (N, 1, C) vs (1, N, C) -> (N, N).
    sim_matrix = nn.functional.cosine_similarity(emb.unsqueeze(1), emb.unsqueeze(0), dim=-1, eps=1e-6)
    sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)
    # pi/2 - acos(s) == asin(s): a signed angle that increases with similarity.
    theta_matrix = (torch.pi / 2) - torch.acos(sim_matrix)

    positive_mask, _ = return_mask(sim_matrix, verb_mask)
    is_pos = positive_mask.bool()

    theta_with_margin = theta_matrix.clone()
    theta_with_margin[is_pos] -= m * torch.pi / 180.0  # margin: degrees -> radians

    logits = theta_with_margin / tau
    # Compute -log(sum_pos exp / sum_all exp) as a difference of log-sum-exps. With
    # |theta| up to ~pi/2, exp(theta / tau) overflows float32 once tau drops below
    # roughly 0.018. The diagonal is always positive, so no masked row is entirely -inf.
    pos_lse = torch.logsumexp(logits.masked_fill(~is_pos, float('-inf')), dim=-1)
    loss = torch.logsumexp(logits, dim=-1) - pos_lse
    return loss.mean()
|
|
|
|
|
|
def evaluate(model, data_loader, bert_model):
    """Run the validation loop and report referring-segmentation metrics.

    Each batch is unpacked as ``(image, target, source_type, sentences, sentences1,
    attentions)``. The model runs once per sentence along the last axis of
    ``sentences``. Only ``source_type[0]`` is inspected, so this assumes batch size 1,
    which is how ``main`` builds the test loader.

    * ``source_type == 'zero'`` samples are scored by accuracy: correct iff the
      predicted mask is empty. The mean accuracy is printed only when at least one
      such sample was seen.
    * All other samples contribute the per-sample IoU, the cumulative I/U, and the
      precision@{0.5, ..., 0.9}.

    When a metric has no contributing samples, it is reported as 0.0 instead of
    raising ``ZeroDivisionError`` or producing NaN.

    Returns:
        ``(mIoU, oIoU, precs)``:
        * ``mIoU``: mean per-sample IoU in [0, 1].
        * ``oIoU``: overall IoU as a percentage (``100 * sum(I) / sum(U)``).
        * ``precs``: precision percentages for the thresholds 0.5 to 0.9.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    cum_I, cum_U = 0, 0
    eval_seg_iou_list = [.5, .6, .7, .8, .9]
    seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
    seg_total = 0
    per_sample_ious = []
    zero_sample_accs = []

    with torch.no_grad():
        for data in metric_logger.log_every(data_loader, 100, header):
            # sentences1 is not used on the evaluation path.
            image, target, source_type, sentences, _sentences1, attentions = data
            image = image.cuda(non_blocking=True)
            sentences = sentences.cuda(non_blocking=True).squeeze(1)
            attentions = attentions.cuda(non_blocking=True).squeeze(1)
            target = target.data.numpy()

            for j in range(sentences.size(-1)):
                last_hidden_states = bert_model(sentences[:, :, j], attention_mask=attentions[:, :, j])[0]
                embedding = last_hidden_states.permute(0, 2, 1)
                # The same embedding is fed for both language inputs at eval time.
                # NOTE(review): training_flag=True with a 3-value unpack, while
                # train_one_epoch unpacks four values; confirm against the model's forward.
                _, _, output = model(image, embedding, embedding,
                                     l_mask=attentions[:, :, j].unsqueeze(-1), training_flag=True)
                output_mask = output.argmax(1).cpu().data.numpy()

                if source_type[0] == 'zero':
                    zero_sample_accs.append(1 if np.sum(output_mask) == 0 else 0)
                    continue

                I, U = computeIoU(output_mask, target)
                this_iou = 0.0 if U == 0 else I * 1.0 / U
                per_sample_ious.append(this_iou)
                cum_I += I
                cum_U += U
                for k, threshold in enumerate(eval_seg_iou_list):
                    seg_correct[k] += (this_iou >= threshold)
                seg_total += 1

    mIoU = np.mean(per_sample_ious) if per_sample_ious else 0.0
    precs = [(seg_correct[k] * 100. / seg_total) if seg_total else 0.0
             for k in range(len(eval_seg_iou_list))]
    overall_iou = (100 * cum_I / cum_U) if cum_U else 0.0

    print('Final results:')
    results_str = ''
    for threshold, prec in zip(eval_seg_iou_list, precs):
        results_str += ' precision@%s = %.2f\n' % (str(threshold), prec)
    results_str += ' overall IoU = %.2f\n' % overall_iou
    results_str += ' mean IoU = %.2f\n' % (mIoU * 100.)
    print(results_str)
    if zero_sample_accs:
        print('Mean accuracy for one-to-zero sample is %.2f\n' % (np.mean(zero_sample_accs) * 100))

    return mIoU, overall_iou, precs
|
|
|
|
def _expand_hardpos_batch(image, target, sentences, sentences_masked, attentions,
                          pos_sent, pos_attn_mask, hardpos_flag):
    """Insert a hard-positive copy after every flagged sample and move the batch to GPU.

    Every original sample is kept. When ``hardpos_flag[idx]`` is set, a second entry
    follows it. That entry reuses the same image, target and masked sentence, but
    swaps in ``pos_sent[idx]`` and ``pos_attn_mask[idx]``.

    Returns:
        ``(image, target, sentences, sentences_masked, attentions, verb_masks,
        cl_masks)`` as CUDA tensors.
        * ``verb_masks`` marks both members of every (anchor, hard positive) pair.
        * ``cl_masks`` marks exactly the original samples.
    """
    images, targets, sents, sents_masked, attns = [], [], [], [], []
    verb_masks, cl_masks = [], []
    for idx in range(len(image)):
        images.append(image[idx])
        targets.append(target[idx])
        sents.append(sentences[idx])
        sents_masked.append(sentences_masked[idx])
        attns.append(attentions[idx])
        if hardpos_flag[idx]:
            verb_masks.extend([1, 1])
            cl_masks.extend([1, 0])
            images.append(image[idx])
            targets.append(target[idx])
            sents.append(pos_sent[idx])
            # NOTE(review): the positive keeps the anchor's masked sentence but uses
            # the positive's attention mask; confirm this pairing is intended.
            sents_masked.append(sentences_masked[idx])
            attns.append(pos_attn_mask[idx])
        else:
            verb_masks.append(0)
            cl_masks.append(1)
    return (
        torch.stack(images).cuda(non_blocking=True),
        torch.stack(targets).cuda(non_blocking=True),
        torch.stack(sents).cuda(non_blocking=True),
        torch.stack(sents_masked).cuda(non_blocking=True),
        torch.stack(attns).cuda(non_blocking=True),
        torch.tensor(verb_masks, dtype=torch.bool, device='cuda'),
        torch.tensor(cl_masks, dtype=torch.bool, device='cuda'),
    )


def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, print_freq,
                    iterations, bert_model, metric_learning, args):
    """Run one training epoch with optional hard-positive angular metric learning.

    Each batch is expanded so that every sample with a hard positive is followed by a
    copy paired with its hard-positive sentence. The segmentation loss uses only the
    original samples (``cl_masks``).

    When ``metric_learning`` is on and the batch contains at least one hard positive,
    ``UniAngularLogitContrastLoss`` is added with weight ``args.metric_loss_weight``.
    In that case the total loss is divided by the sum of the loss weights, and the
    metric loss is logged as ``metric_loss``.

    The LR scheduler is stepped once per batch.

    Returns:
        ``(iterations, loss_log)``:
        * ``iterations``: the counter advanced by the number of batches processed.
        * ``loss_log['loss']``: the metric logger's global average of the total loss.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    header = 'Epoch: [{}]'.format(epoch)
    mlw = args.metric_loss_weight

    for data in metric_logger.log_every(data_loader, print_freq, header):
        (image, target, source_type, sentences, sentences_masked, attentions,
         pos_sent, pos_attn_mask, pos_type) = data
        source_type = np.array(source_type)
        pos_type = np.array(pos_type)
        target_flag = torch.tensor(np.where(source_type == 'zero', 0, 1))
        if args.addzero:
            hardpos_flag = torch.tensor(np.where(pos_type == 'hardpos', 1, 0))
        else:
            # Without addzero, 'zero' samples never receive a hard positive.
            hardpos_flag = torch.tensor(np.where((source_type != 'zero') & (pos_type == 'hardpos'), 1, 0))

        image, target, sentences, sentences_masked, attentions, verb_masks, cl_masks = _expand_hardpos_batch(
            image, target, sentences.squeeze(1), sentences_masked.squeeze(1), attentions.squeeze(1),
            pos_sent.squeeze(1), pos_attn_mask.squeeze(1), hardpos_flag)

        last_hidden_states = bert_model(sentences, attention_mask=attentions)[0]
        last_hidden_states1 = bert_model(sentences_masked, attention_mask=attentions)[0]
        embedding = last_hidden_states.permute(0, 2, 1)
        embedding1 = last_hidden_states1.permute(0, 2, 1)
        attentions = attentions.unsqueeze(dim=-1)

        loss_contra, loss_lansim, output, metric_tensors = model(
            image, embedding, embedding1, l_mask=attentions, cl_masks=cl_masks,
            target_flag=target_flag, training_flag=True)

        # Segment only the original samples; the hard-positive copies feed metric learning.
        loss_seg = criterion(output[cl_masks], target[cl_masks])

        metric_loss = None
        if metric_learning and bool(hardpos_flag.any()):
            metric_loss = UniAngularLogitContrastLoss(metric_tensors, verb_masks, m=args.margin_value,
                                                      tau=args.temperature, verbonly=True, args=args)
            total_weight = 1 + 0.01 + 0.01 + mlw
            loss = (loss_seg + loss_lansim * 0.01 + loss_contra * 0.01 + metric_loss * mlw) / total_weight
        else:
            loss = loss_seg + loss_lansim * 0.01 + loss_contra * 0.01

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        torch.cuda.synchronize()
        iterations += 1
        metric_logger.update(loss=loss.item(), loss_seg=loss_seg.item(), loss_lansim=loss_lansim.item(),
                             loss_contra=loss_contra.item(), lr=optimizer.param_groups[0]["lr"])
        if metric_loss is not None:
            metric_logger.update(metric_loss=metric_loss.item())

        # Drop per-iteration references before releasing cached GPU memory.
        del image, target, sentences, sentences_masked, attentions, verb_masks, cl_masks
        del loss, output, metric_tensors, metric_loss, data
        del last_hidden_states, embedding, last_hidden_states1, embedding1
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    loss_log = {
        'loss': metric_logger.meters['loss'].global_avg
    }
    return iterations, loss_log
|
|
|
|
def _build_param_groups(single_model, single_bert_model):
    """Return the AdamW parameter groups for the segmentation model and BERT.

    Backbone parameters whose name contains 'norm', 'absolute_pos_embed' or
    'relative_position_bias_table' get ``weight_decay=0.0``. Every other group uses
    the optimizer's default weight decay. Only the first 10 BERT encoder layers are
    optimised.
    """
    no_decay_keys = ('norm', 'absolute_pos_embed', 'relative_position_bias_table')
    backbone_no_decay, backbone_decay = [], []
    for name, param in single_model.backbone.named_parameters():
        if any(key in name for key in no_decay_keys):
            backbone_no_decay.append(param)
        else:
            backbone_decay.append(param)

    return [
        {'params': backbone_no_decay, 'weight_decay': 0.0},
        {'params': backbone_decay},
        {"params": [p for p in single_model.classifier.parameters() if p.requires_grad]},
        {"params": [p for p in single_model.contrastive.parameters() if p.requires_grad]},
        {"params": [p for i in range(10)
                    for p in single_bert_model.encoder.layer[i].parameters()
                    if p.requires_grad]},
    ]


def main(args):
    """Train the segmentation model and BERT, keeping the checkpoint with the best oIoU.

    Steps:
    1. Build the train/eval datasets and loaders.
    2. Wrap the segmentation model and BERT in DistributedDataParallel.
    3. Optionally resume from ``args.resume``.
    4. Alternate ``train_one_epoch`` and ``evaluate`` for each epoch.

    Whenever the overall IoU improves, the model, BERT, optimizer and scheduler state
    are saved via ``utils.save_on_master`` to
    ``<args.output_dir>/model_best_<args.model_id>.pth``. Only the main process
    creates and writes the TensorBoard ``SummaryWriter``, and it is closed at the end.
    """
    # A SummaryWriter opens an event file as soon as it is constructed. Creating one
    # only on the main process avoids empty duplicate event files from other ranks.
    writer = None
    if utils.is_main_process():
        writer = SummaryWriter('./experiments/{}/{}'.format("_".join([args.dataset, args.splitBy]), args.model_id))

    dataset, _ = get_dataset("train", transform(args=args), args=args, eval_mode=False)
    dataset_test, _ = get_dataset(args.split, get_transform(args=args), args=args, eval_mode=True)

    print(f"local rank {args.local_rank} / global rank {utils.get_rank()} successfully built train dataset.")
    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    if args.dataset == 'ref-zom':
        sampler_cls = Refzom_DistributedSampler
    else:
        sampler_cls = torch.utils.data.distributed.DistributedSampler
    train_sampler = sampler_cls(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, sampler=train_sampler,
        num_workers=args.workers, pin_memory=args.pin_mem, drop_last=True)
    # Batch size 1: evaluate() inspects only the first sample's source_type.
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers)

    print(args.model)
    model = segmentation.__dict__[args.model](pretrained=args.pretrained_backbone, args=args)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)
    single_model = model.module

    bert_model = BertModel.from_pretrained(args.ck_bert)
    bert_model.pooler = None  # only output [0] (last hidden states) is consumed in this file
    bert_model.cuda()
    bert_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(bert_model)
    bert_model = torch.nn.parallel.DistributedDataParallel(bert_model, device_ids=[args.local_rank])
    single_bert_model = bert_model.module

    checkpoint = None
    if args.resume:
        # torch.load unpickles the file, so only resume from trusted checkpoints.
        checkpoint = torch.load(args.resume, map_location='cpu')
        single_model.load_state_dict(checkpoint['model'])
        single_bert_model.load_state_dict(checkpoint['bert_model'])

    optimizer = torch.optim.AdamW(_build_param_groups(single_model, single_bert_model),
                                  lr=args.lr,
                                  weight_decay=args.weight_decay,
                                  amsgrad=args.amsgrad)

    # Polynomial decay (power 0.9) to zero over all training iterations, stepped per batch.
    total_iters = len(data_loader) * args.epochs
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda x: (1 - x / total_iters) ** 0.9)

    start_time = time.time()
    iterations = 0
    best_oIoU = -0.1

    if checkpoint is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        resume_epoch = checkpoint['epoch']
    else:
        resume_epoch = -999

    for epoch in range(max(0, resume_epoch + 1), args.epochs):
        data_loader.sampler.set_epoch(epoch)
        iterations, loss_log = train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch,
                                               args.print_freq, iterations, bert_model,
                                               metric_learning=args.metric_learning, args=args)
        mean_IoU, overall_IoU, precs = evaluate(model, data_loader_test, bert_model)

        print('Average object IoU {}'.format(mean_IoU))
        print('Overall IoU {}'.format(overall_IoU))

        if best_oIoU < overall_IoU:
            print('Better epoch: {}\n'.format(epoch))
            dict_to_save = {'model': single_model.state_dict(), 'bert_model': single_bert_model.state_dict(),
                            'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args,
                            'lr_scheduler': lr_scheduler.state_dict()}
            utils.save_on_master(dict_to_save, os.path.join(args.output_dir,
                                                            'model_best_{}.pth'.format(args.model_id)))
            best_oIoU = overall_IoU
        print('The best_performance is {}'.format(best_oIoU))

        if writer is not None:
            writer.add_scalar('val/mIoU', mean_IoU, epoch)
            writer.add_scalar('val/oIoU', overall_IoU, epoch)
            for tag, prec in zip(('50', '60', '70', '80', '90'), precs):
                writer.add_scalar('val/Prec/{}'.format(tag), prec, epoch)
            writer.add_scalar('train/loss', loss_log['loss'], epoch)
            writer.flush()

    if writer is not None:
        writer.close()

    print('The final_best_performance is {}'.format(best_oIoU))
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
|
|
|
|
if __name__ == "__main__":
    from args import get_parser

    args = get_parser().parse_args()

    # LOCAL_RANK / WORLD_SIZE come from the distributed launcher's environment;
    # default to a single process when they are absent.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    print(f"Local Rank: {local_rank}, World Size: {os.environ.get('WORLD_SIZE', '1')}")

    utils.init_distributed_mode(args)

    # Echo the run configuration before training starts.
    print('Image size: {}'.format(str(args.img_size)))
    print('Metric Learning Ops')
    print('metric learning flag : ', args.metric_learning)
    print('metric loss weight : ', args.metric_loss_weight)
    print('metric mode and hardpos selection : ', args.metric_mode, args.hp_selection)
    print('margin value : ', args.margin_value)
    print('temperature : ', args.temperature)

    main(args)
| |