| import datetime |
| import os |
| import time |
|
|
| import torch |
| import torch.utils.data |
| from torch import nn |
|
|
| from functools import reduce |
| import operator |
| from bert.multimodal_bert import MultiModalBert |
|
|
| import torchvision |
| from lib import multimodal_segmentation_ppm |
|
|
| import transforms as T |
| import utils |
| import numpy as np |
|
|
| import torch.nn.functional as F |
|
|
| import gc |
| from collections import OrderedDict |
|
|
| import torch.backends.cudnn as cudnn |
|
|
| |
| from modeling.MaskFormerModel import MaskFormerHead |
| from addict import Dict |
|
|
| from mask2former_utils.criterion import SetCriterion, Criterion |
| from mask2former_utils.matcher import HungarianMatcher |
| from bert.modeling_bert import BertLMPredictionHead, BertEncoder |
|
|
|
|
|
|
|
|
| class WrapperModel(nn.Module): |
| def __init__(self, image_model, language_model, classifier, args) : |
| super(WrapperModel, self).__init__() |
| self.image_model = image_model |
| self.language_model = language_model |
| self.classifier = classifier |
|
|
| self.lang_proj = nn.Linear(768,256) |
|
|
| config = Dict({ |
| "architectures": [ |
| "BertForMaskedLM" |
| ], |
| "attention_probs_dropout_prob": 0.1, |
| "gradient_checkpointing": False, |
| "hidden_act": "gelu", |
| "hidden_dropout_prob": 0.1, |
| "hidden_size": 512, |
| "initializer_range": 0.02, |
| "intermediate_size": 3072, |
| "layer_norm_eps": 1e-12, |
| |
| "model_type": "bert", |
| "num_attention_heads": 8, |
| "num_hidden_layers": 8, |
| "pad_token_id": 0, |
| "position_embedding_type": "absolute", |
| "transformers_version": "4.6.0.dev0", |
| "type_vocab_size": 2, |
| "use_cache": True, |
| "vocab_size": 30522 |
| }) |
| self.mlm_transformer = BertEncoder(config) |
|
|
| self.lang_proj = nn.Linear(768,256) |
| self.mlm_vis_proj = nn.Conv2d(1024,512,1) |
| self.mlm_lang_proj = nn.Linear(768,512) |
| |
| self.mlm_head = BertLMPredictionHead(config) |
|
|
| assert args.img_size % 4 == 0 |
| num_img_tokens = 20 + ((args.img_size // 4)//8) ** 2 |
| print(num_img_tokens) |
| self.mlm_pos_embeds = nn.Embedding(num_img_tokens+1, 512) |
| self.mlm_modal_embeds = nn.Embedding(3, 512) |
|
|
| self.mlm_mask_embed = nn.Embedding(1, 512) |
| self.mlm_pos_mlp = nn.Sequential( |
| nn.Linear(2, 512), |
| nn.LayerNorm(512), |
| nn.Linear(512,512), |
| nn.GELU() |
| ) |
|
|
| def _get_binary_mask(self, target): |
| |
| y, x = target.size() |
| target_onehot = torch.zeros(self.num_classes + 1, y, x) |
| target_onehot = target_onehot.scatter(dim=0, index=target.unsqueeze(0), value=1) |
| return target_onehot[1:] |
|
|
| def semantic_inference(self, mask_cls, mask_pred): |
| mask_cls = F.softmax(mask_cls, dim=1)[...,1:] |
| mask_pred = mask_pred.sigmoid() |
| semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred) |
| return semseg |
|
|
| def forward(self, image, sentences, attentions, mlm_targets, mlm_masks, position): |
| input_shape = image.shape[-2:] |
| l_mask = attentions.unsqueeze(dim=-1) |
|
|
| i0, Wh, Ww = self.image_model.forward_stem(image) |
| l0, extended_attention_mask = self.language_model.forward_stem(mlm_targets.squeeze(1), attentions) |
|
|
| i1 = self.image_model.forward_stage1(i0, Wh, Ww) |
| l1 = self.language_model.forward_stage1(l0, extended_attention_mask) |
| i1_residual, H, W, i1_temp, Wh, Ww = self.image_model.forward_pwam1(i1, Wh, Ww, l1, l_mask) |
| l1_residual, l1 = self.language_model.forward_pwam1(i1, l1, extended_attention_mask) |
| i1 = i1_temp |
|
|
| i2 = self.image_model.forward_stage2(i1, Wh, Ww) |
| l2 = self.language_model.forward_stage2(l1, extended_attention_mask) |
| i2_residual, H, W, i2_temp, Wh, Ww = self.image_model.forward_pwam2(i2, Wh, Ww, l2, l_mask) |
| l2_residual, l2 = self.language_model.forward_pwam2(i2, l2, extended_attention_mask) |
| i2 = i2_temp |
|
|
| i3 = self.image_model.forward_stage3(i2, Wh, Ww) |
| l3 = self.language_model.forward_stage3(l2, extended_attention_mask) |
| i3_residual, H, W, i3_temp, Wh, Ww = self.image_model.forward_pwam3(i3, Wh, Ww, l3, l_mask) |
| l3_residual, l3 = self.language_model.forward_pwam3(i3, l3, extended_attention_mask) |
| i3 = i3_temp |
|
|
| i4 = self.image_model.forward_stage4(i3, Wh, Ww) |
| l4 = self.language_model.forward_stage4(l3, extended_attention_mask) |
| i4_residual, H, W, i4_temp, Wh, Ww = self.image_model.forward_pwam4(i4, Wh, Ww, l4, l_mask) |
| l4_residual, l4 = self.language_model.forward_pwam4(i4, l4, extended_attention_mask) |
| i4 = i4_temp |
|
|
| |
| |
| |
| outputs = {} |
| outputs['s1'] = i1_residual |
| outputs['s2'] = i2_residual |
| outputs['s3'] = i3_residual |
| outputs['s4'] = i4_residual |
|
|
| predictions, mask_features = self.classifier(outputs) |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| l0, extended_attention_mask = self.language_model.forward_stem(sentences, attentions) |
| l1 = self.language_model.forward_stage1(l0, extended_attention_mask) |
| l2 = self.language_model.forward_stage2(l1, extended_attention_mask) |
| l3 = self.language_model.forward_stage3(l2, extended_attention_mask) |
| l4 = self.language_model.forward_stage4(l3, extended_attention_mask) |
|
|
|
|
| mlp_embed = self.mlm_pos_mlp(position) |
| |
|
|
| mlm_targets = torch.where( |
| mlm_masks > 0, |
| mlm_targets, |
| torch.ones_like(mlm_targets) * (-1) |
| ) |
|
|
| |
| vis_features = self.mlm_vis_proj(i4_residual).flatten(2).permute(0,2,1) |
| |
| lang_features = self.mlm_lang_proj(l4) |
| |
| |
| mm_features = torch.cat([lang_features, vis_features, mlp_embed.unsqueeze(1)], dim=1) |
| |
|
|
| |
| modal_embeds = torch.cat([self.mlm_modal_embeds.weight[0].unsqueeze(0).repeat(1, lang_features.shape[1], 1), self.mlm_modal_embeds.weight[1].unsqueeze(0).repeat(1, vis_features.shape[1], 1), self.mlm_modal_embeds.weight[2].unsqueeze(0).repeat(1,1,1)], dim=1) |
| |
|
|
| |
|
|
|
|
| |
| mixed_attention_mask = torch.cat([attentions.unsqueeze(-1), torch.ones(attentions.shape[0], vis_features.shape[1]+1, 1).to(attentions.device)], dim=1) |
| mixed_attention_mask = mixed_attention_mask.permute(0,2,1).unsqueeze(1) |
| mixed_attention_mask = (1-mixed_attention_mask)* -10000.0 |
| head_mask = [None] * 8 |
| |
| |
| |
| head_features = self.mlm_transformer(mm_features + self.mlm_pos_embeds.weight.unsqueeze(0) + modal_embeds, mixed_attention_mask, head_mask)[0] |
| |
| head_features = head_features[:, :20][attentions.bool()] |
| |
| |
| mlm_predictions = self.mlm_head(head_features) |
| mlm_predictions = mlm_predictions.reshape(-1, self.language_model.config.vocab_size) |
| mlm_targets = mlm_targets.squeeze(1)[attentions.bool()] |
| |
| |
| |
|
|
| return predictions, mask_features, self.lang_proj((l4_residual * l_mask).sum(1)/l_mask.sum(1)), mlm_predictions, mlm_targets |
| |
def IoU(pred, gt):
    """Score a soft prediction map against a ground-truth mask.

    ``pred`` is binarised at 0.5. The union is computed as
    ``sum(pred + gt) - intersection``, which is only correct when ``gt`` holds
    0/1 values.

    Returns:
        ``(iou, intersection, union)``:

        * ``iou`` is a Python float, or the int ``0`` when either the
          intersection or the union is zero;
        * ``intersection`` and ``union`` are the tensor sums used to compute it.
    """
    hard_pred = pred > 0.5

    overlap = (hard_pred * gt).sum()
    combined = (hard_pred + gt).sum() - overlap

    iou = float(overlap) / float(combined) if overlap != 0 and combined != 0 else 0
    return iou, overlap, combined
|
|
def get_dataset(image_set, transform, args):
    """Build the ``ReferDataset`` for one split.

    Args:
        image_set: split name forwarded as ``split`` (e.g. ``"train"``/``"val"``).
        transform: joint transform forwarded as ``image_transforms``.
        args: parsed command-line arguments, passed through to the dataset.

    Returns:
        ``(dataset, 2)``. The second element is a fixed class count; callers in
        this file ignore it.
    """
    # Imported lazily so the dataset module is only loaded when needed.
    from data.dataset_refer_bert_mlm import ReferDataset

    dataset = ReferDataset(
        args,
        split=image_set,
        image_transforms=transform,
        target_transforms=None,
    )
    return dataset, 2
|
|
|
|
|
|
def get_transform(args):
    """Return the resize -> tensor -> normalise pipeline used for every split.

    ``T.Resize`` receives ``args.img_size`` twice (presumably height and width).
    The normalisation constants are the widely used ImageNet mean/std.
    """
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    return T.Compose([
        T.Resize(args.img_size, args.img_size),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ])
|
|
|
|
| |
| |
| |
|
|
|
|
def evaluate(model, data_loader):
    """Evaluate referring segmentation and print a summary.

    Args:
        model: DistributedDataParallel-wrapped model; ``model.module`` must
            provide ``semantic_inference``.
        data_loader: yields ``(image, target, sentences, attentions,
            mlm_targets, mlm_masks, position)``. Only ``pred_masks[0]`` is
            scored, so a batch size of 1 is assumed.

    Returns:
        ``(mean_iou, overall_iou)`` as percentages (plain Python floats).
        ``overall_iou`` is cumulative intersection over cumulative union, and is
        0.0 if the total union is zero.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    eval_seg_iou_list = [.5, .6, .7, .8, .9]
    seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
    # Accumulated as Python ints so the results are not CUDA tensors.
    cum_I, cum_U = 0, 0
    per_sample_iou = []

    with torch.no_grad():
        for data in metric_logger.log_every(data_loader, 100, header):
            image, target, sentences, attentions, mlm_targets, mlm_masks, position = (
                t.cuda(non_blocking=True) for t in data)

            sentences = sentences.squeeze(1)
            attentions = attentions.squeeze(1)

            output, _, _, _, _ = model(
                image, sentences, attentions, mlm_targets, mlm_masks, position)

            mask_pred_results = F.interpolate(
                output["pred_masks"], size=target.shape[-2:],
                mode='bilinear', align_corners=True)
            pred_masks = model.module.semantic_inference(
                output["pred_logits"], mask_pred_results)

            iou, I, U = IoU(pred_masks[0], target)
            per_sample_iou.append(iou)
            cum_I += int(I)
            cum_U += int(U)
            for idx, threshold in enumerate(eval_seg_iou_list):
                seg_correct[idx] += (iou >= threshold)

    if not per_sample_iou:
        raise ValueError("evaluate() received an empty data_loader")

    seg_total = len(per_sample_iou)
    mIoU = sum(per_sample_iou) / seg_total
    overall_iou = cum_I * 100. / cum_U if cum_U else 0.0

    print('Final results:')
    print('Mean IoU is %.2f\n' % (mIoU * 100.))
    results_str = ''
    for threshold, correct in zip(eval_seg_iou_list, seg_correct):
        results_str += ' precision@%s = %.2f\n' % \
                       (str(threshold), correct * 100. / seg_total)
    results_str += ' overall IoU = %.2f\n' % overall_iou
    print(results_str)

    return 100 * mIoU, overall_iou
|
|
|
|
def _weighted_contrastive_loss(score_matrix, alpha, temperature):
    """Hardness-weighted cross-entropy where column 0 of each row is the positive.

    Args:
        score_matrix: ``(B, N, 1 + K)`` similarity scores. ``[:, n, 0]`` is the
            positive for anchor ``n``; the remaining columns are negatives.
        alpha: exponent applied to the weight ``1 - softmax(scores)[..., 0]``.
        temperature: divisor for the cross-entropy logits. The weight uses the
            untempered scores.

    Returns:
        Scalar tensor: mean over all anchors of ``weight ** alpha * CE``.
    """
    hardness = 1.0 - torch.softmax(score_matrix, -1)[:, :, 0]
    logits = score_matrix.reshape(-1, score_matrix.shape[-1]) / temperature
    labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
    ce = F.cross_entropy(logits, labels, reduction='none')
    return (torch.pow(hardness.reshape(-1), alpha) * ce).mean()


def _plic_losses(mask_features, target, lang_embedding, args):
    """Compute the weighted, batch-averaged PLIC (pos, neg, lang) losses.

    The label map is nearest-downsampled to the ``mask_features`` grid. Only
    samples that still contain both label 0 and label 1 contribute. Skipped
    samples still count in the ``/ B`` denominator.

    Returns:
        Three scalar tensors. They are always tensors (zero when no sample
        qualifies), so ``.item()`` is safe.
    """
    B, C = mask_features.shape[0], mask_features.shape[1]
    target_reshape = F.interpolate(
        target.unsqueeze(1).float(), size=mask_features.shape[-2:], mode='nearest').long()
    target_reshape = target_reshape.repeat(1, C, 1, 1)

    pos_total = mask_features.new_zeros(())
    neg_total = mask_features.new_zeros(())
    lang_total = mask_features.new_zeros(())
    for i in range(B):
        feats = mask_features[[i]]
        tgt = target_reshape[[i]]
        if (tgt == 0).sum() == 0 or (tgt == 1).sum() == 0:
            continue

        bg = 1.0 - tgt
        avg_pos = F.normalize((feats * tgt).sum(-1).sum(-1) / tgt.sum(-1).sum(-1), dim=1)
        avg_neg = F.normalize((feats * bg).sum(-1).sum(-1) / bg.sum(-1).sum(-1), dim=1)

        # The mask is identical across channels and boolean indexing flattens
        # channel-major, so view(1, C, -1) yields (1, C, n_pixels).
        pos_feats = F.normalize(feats[tgt == 1].view(1, C, -1), dim=1)
        neg_feats = F.normalize(feats[tgt == 0].view(1, C, -1), dim=1)

        lang = lang_embedding[[i]]
        lang_pos = torch.einsum("bq,bqn->bn", lang, pos_feats)
        lang_neg = torch.einsum("bq,bqn->bn", lang, neg_feats)
        lang_matrix = torch.cat(
            [lang_pos.unsqueeze(-1), lang_neg.unsqueeze(1).repeat(1, lang_pos.shape[1], 1)],
            dim=2)

        pos_matrix = torch.cat([
            torch.einsum("bq,bqn->bn", avg_pos, pos_feats).unsqueeze(-1),
            torch.einsum("bqn,bqm->bnm", pos_feats, neg_feats),
        ], dim=2)
        neg_matrix = torch.cat([
            torch.einsum("bq,bqn->bn", avg_neg, neg_feats).unsqueeze(-1),
            torch.einsum("bqn,bqm->bnm", neg_feats, pos_feats),
        ], dim=2)

        pos_total = pos_total + _weighted_contrastive_loss(
            pos_matrix, args.plic_pos_alpha, args.plic_pos_temp)
        neg_total = neg_total + _weighted_contrastive_loss(
            neg_matrix, args.plic_neg_alpha, args.plic_neg_temp)
        lang_total = lang_total + _weighted_contrastive_loss(
            lang_matrix, args.plic_lang_alpha, args.plic_lang_temp)

    return (args.plic_pos_weight * pos_total / B,
            args.plic_neg_weight * neg_total / B,
            args.plic_lang_weight * lang_total / B)


def _weighted_criterion_losses(losses, weight_dict, device):
    """Weight criterion losses and bucket them into (ce, dice, mask) totals.

    Keys absent from ``weight_dict`` are ignored. Returned values are tensors.
    """
    loss_ce = torch.zeros((), device=device)
    loss_dice = torch.zeros((), device=device)
    loss_mask = torch.zeros((), device=device)
    for name, value in losses.items():
        if name not in weight_dict:
            continue
        weighted = value * weight_dict[name]
        if '_ce' in name:
            loss_ce = loss_ce + weighted
        elif '_dice' in name:
            loss_dice = loss_dice + weighted
        else:
            loss_mask = loss_mask + weighted
    return loss_ce, loss_dice, loss_mask


def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, print_freq,
                    iterations, args):
    """Train for one epoch.

    The total loss is MaskFormer criterion losses + PLIC losses + weighted MLM
    loss. The LR scheduler is stepped once per iteration.

    ``iterations`` is accepted for call-site compatibility but is unused: the
    previous code only incremented a local copy.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    header = 'Epoch: [{}]'.format(epoch)

    for data in metric_logger.log_every(data_loader, print_freq, header):
        image, target, sentences, attentions, mlm_targets, mlm_masks, position = (
            t.cuda(non_blocking=True) for t in data)
        sentences = sentences.squeeze(1)
        attentions = attentions.squeeze(1)

        output, mask_features, avg_lang_feature, mlm_predictions, mlm_targets = model(
            image, sentences, attentions, mlm_targets, mlm_masks, position)
        avg_lang_feature = F.normalize(avg_lang_feature, dim=1)

        # Upsample predicted masks (and deep-supervision outputs) to label size.
        target_shape = target.shape[-2:]
        output['pred_masks'] = F.interpolate(
            output['pred_masks'], size=target_shape, mode='bilinear', align_corners=True)
        if "aux_outputs" in output:
            for aux_outputs in output["aux_outputs"]:
                aux_outputs['pred_masks'] = F.interpolate(
                    aux_outputs['pred_masks'], size=target_shape,
                    mode='bilinear', align_corners=True)

        plic_pos_loss, plic_neg_loss, plic_lang_loss = _plic_losses(
            mask_features, target, avg_lang_feature, args)
        plic_loss = plic_pos_loss + plic_neg_loss + plic_lang_loss

        losses = criterion(output, target)
        loss_ce, loss_dice, loss_mask = _weighted_criterion_losses(
            losses, criterion.weight_dict, mask_features.device)

        if (mlm_targets != -1).any():
            smlm_loss = args.smlm_weight * nn.CrossEntropyLoss(ignore_index=-1)(
                mlm_predictions, mlm_targets)
        else:
            # All targets ignored: the mean cross-entropy would be NaN.
            smlm_loss = mlm_predictions.sum() * 0.0

        loss = loss_ce + loss_dice + loss_mask + plic_loss + smlm_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        torch.cuda.synchronize()
        metric_logger.update(
            loss=loss.item(), lr=optimizer.param_groups[0]["lr"],
            loss_ce=loss_ce.item(), loss_dice=loss_dice.item(), loss_mask=loss_mask.item(),
            plic_loss=plic_loss.item(), plic_lang_loss=plic_lang_loss.item(),
            plic_pos_loss=plic_pos_loss.item(), plic_neg_loss=plic_neg_loss.item(),
            smlm_loss=smlm_loss.item())

    torch.cuda.synchronize()
|
|
|
|
def _build_maskformer_cfg(args):
    """Return the addict config consumed by MaskFormerHead and the criterion."""
    cfg = Dict()
    cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
    cfg.MODEL.MASK_FORMER.DROPOUT = 0.0
    cfg.MODEL.MASK_FORMER.NHEADS = 8
    cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = args.transformer_enc_layers
    cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256
    cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
    cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"]

    cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1
    cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
    cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = args.num_object_queries
    cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = args.dim_feedforward
    cfg.MODEL.MASK_FORMER.DEC_LAYERS = args.dec_layers
    cfg.MODEL.MASK_FORMER.PRE_NORM = False

    cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION = True
    cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT = args.no_object_weight
    cfg.MODEL.MASK_FORMER.CLASS_WEIGHT = args.class_weight
    cfg.MODEL.MASK_FORMER.DICE_WEIGHT = args.dice_weight
    cfg.MODEL.MASK_FORMER.MASK_WEIGHT = args.mask_weight

    cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS = args.train_num_points
    cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO = 3.0
    cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO = 0.75
    return cfg


def _build_criterion(cfg):
    """Create the Hungarian matcher and SetCriterion described by ``cfg``."""
    mf = cfg.MODEL.MASK_FORMER
    class_weight = mf.CLASS_WEIGHT
    dice_weight = mf.DICE_WEIGHT
    mask_weight = mf.MASK_WEIGHT

    matcher = HungarianMatcher(
        cost_class=class_weight,
        cost_mask=mask_weight,
        cost_dice=dice_weight,
        num_points=mf.TRAIN_NUM_POINTS,
    )

    weight_dict = {"loss_ce": class_weight, "loss_mask": mask_weight, "loss_dice": dice_weight}
    if mf.DEEP_SUPERVISION:
        # One suffixed copy of each loss per intermediate decoder layer.
        aux_weight_dict = {}
        for i in range(mf.DEC_LAYERS - 1):
            aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
        weight_dict.update(aux_weight_dict)

    return SetCriterion(
        cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
        matcher=matcher,
        weight_dict=weight_dict,
        eos_coef=mf.NO_OBJECT_WEIGHT,
        losses=["labels", "masks"],
        num_points=mf.TRAIN_NUM_POINTS,
        oversample_ratio=mf.OVERSAMPLE_RATIO,
        importance_sample_ratio=mf.IMPORTANCE_SAMPLE_RATIO,
        device='cuda'
    )


def _build_param_groups(single_model):
    """Return AdamW parameter groups.

    Backbone norm/positional parameters get no weight decay. The group order
    must stay fixed so that saved optimizer states still line up on resume.
    """
    backbone_no_decay = []
    backbone_decay = []
    for name, param in single_model.image_model.named_parameters():
        if 'norm' in name or 'absolute_pos_embed' in name or 'relative_position_bias_table' in name:
            backbone_no_decay.append(param)
        else:
            backbone_decay.append(param)

    language_model = single_model.language_model
    return [
        {'params': backbone_no_decay, 'weight_decay': 0.0},
        {'params': backbone_decay},
        {"params": [p for p in single_model.classifier.parameters() if p.requires_grad]},
        # NOTE(review): only the first 10 encoder layers are optimised — confirm intent.
        {"params": [p for i in range(10)
                    for p in language_model.encoder.layer[i].parameters()
                    if p.requires_grad]},
        {"params": language_model.pwams.parameters()},
        {"params": language_model.res_gates.parameters()},
        {"params": language_model.norms.parameters()},
        {"params": single_model.lang_proj.parameters()},
        {'params': single_model.mlm_head.parameters()},
        {'params': single_model.mlm_vis_proj.parameters()},
        {'params': single_model.mlm_lang_proj.parameters()},
        {'params': single_model.mlm_transformer.parameters()},
        {'params': single_model.mlm_pos_embeds.parameters()},
        {'params': single_model.mlm_modal_embeds.parameters()},
        {'params': single_model.mlm_mask_embed.parameters()},
        {'params': single_model.mlm_pos_mlp.parameters()},
    ]


def _latest_checkpoint(output_dir, epochs):
    """Return the highest-epoch existing ``checkpoint-{e}.pth`` path, or ``""``."""
    latest = ""
    for e in range(epochs):
        path = os.path.join(output_dir, f'checkpoint-{e}.pth')
        if os.path.exists(path):
            latest = path
    return latest


def _prune_checkpoints(output_dir, epochs, max_ckpt):
    """Delete all but the newest ``max_ckpt`` per-epoch checkpoints.

    As before, the slice ``[:-max_ckpt]`` means ``max_ckpt == 0`` keeps every
    checkpoint.
    """
    existing = [os.path.join(output_dir, f'checkpoint-{e}.pth') for e in range(epochs)]
    existing = [p for p in existing if os.path.exists(p)]
    print(existing)
    for path in existing[:-max_ckpt]:
        os.remove(path)
        print("remove {:s}".format(path))


def _save_atomically(obj, path):
    """Save on the main process to ``<path>_TEMP``, then move it over ``path``."""
    utils.save_on_master(obj, str(path) + '_TEMP')
    if utils.is_main_process():
        os.replace(str(path) + '_TEMP', str(path))


def main(args):
    """Build data, model, criterion and optimizer, then train and evaluate.

    Files written to ``args.output_dir``:

    * per-epoch checkpoints ``checkpoint-{epoch}.pth``, pruned to
      ``args.max_ckpt``;
    * the best model by overall IoU, ``model_best_{model_id}.pth``.

    ``args.resume`` may be a checkpoint path or ``"auto"`` (the newest per-epoch
    checkpoint). Checkpoints store ``best_oIoU``, so the best-model tracking
    survives a resume.
    """
    seed = args.seed + utils.get_rank()
    print('seed', seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    dataset, _ = get_dataset("train", get_transform(args=args), args=args)
    dataset_test, _ = get_dataset("val", get_transform(args=args), args=args)

    print(f"local rank {args.local_rank} / global rank {utils.get_rank()} successfully built train dataset.")
    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers, pin_memory=True, drop_last=True)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, pin_memory=True, num_workers=args.workers)

    print(args.model)
    model = multimodal_segmentation_ppm.__dict__[args.model](
        pretrained=args.pretrained_swin_weights, args=args)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if args.model != 'lavt_one':
        bert_model = MultiModalBert.from_pretrained(args.ck_bert, embed_dim=model.backbone.embed_dim)
        bert_model.pooler = None
        bert_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(bert_model)
    else:
        bert_model = None

    # Channel count and stride of each backbone stage fed to the MaskFormer head.
    input_shape = {
        's1': Dict({'channel': 128, 'stride': 4}),
        's2': Dict({'channel': 256, 'stride': 8}),
        's3': Dict({'channel': 512, 'stride': 16}),
        's4': Dict({'channel': 1024, 'stride': 32}),
    }

    cfg = _build_maskformer_cfg(args)
    print(cfg)

    maskformer_head = MaskFormerHead(cfg, input_shape)
    maskformer_head = torch.nn.SyncBatchNorm.convert_sync_batchnorm(maskformer_head)

    model = WrapperModel(model.backbone, bert_model, maskformer_head, args)
    model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[args.local_rank], find_unused_parameters=True)
    single_model = model.module

    criterion = _build_criterion(cfg)

    if args.resume == "auto":
        args.resume = _latest_checkpoint(args.output_dir, args.epochs)

    checkpoint = None
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        single_model.load_state_dict(checkpoint['model'])

    optimizer = torch.optim.AdamW(_build_param_groups(single_model),
                                  lr=args.lr,
                                  weight_decay=args.weight_decay,
                                  amsgrad=args.amsgrad
                                  )

    # Polynomial decay (power 0.9) over all training iterations.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)

    start_time = time.time()
    iterations = 0
    best_oIoU = -0.1
    start_epoch = 0
    if checkpoint is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        start_epoch = max(0, checkpoint['epoch'] + 1)
        # Older checkpoints lack this key; fall back to the fresh-run value.
        best_oIoU = checkpoint.get('best_oIoU', best_oIoU)
        del checkpoint  # release the CPU copy of the weights

    for epoch in range(start_epoch, args.epochs):
        data_loader.sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, args.print_freq,
                        iterations, args)
        iou, overallIoU = evaluate(model, data_loader_test)

        print('Average object IoU {}'.format(iou))
        print('Overall IoU {}'.format(overallIoU))

        # Decide "best" before writing the per-epoch checkpoint, so that the
        # persisted best_oIoU already includes this epoch's score.
        is_best = best_oIoU < overallIoU
        if is_best:
            best_oIoU = float(overallIoU)

        dict_to_save = {'model': single_model.state_dict(),
                        'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args,
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'best_oIoU': best_oIoU}

        _save_atomically(dict_to_save,
                         os.path.join(args.output_dir, 'checkpoint-{}.pth'.format(epoch)))
        if utils.is_main_process():
            _prune_checkpoints(args.output_dir, args.epochs, args.max_ckpt)

        if is_best:
            print('Better epoch: {}\n'.format(epoch))
            _save_atomically(dict_to_save,
                             os.path.join(args.output_dir, 'model_best_{}.pth'.format(args.model_id)))
        torch.cuda.empty_cache()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse CLI args, prepare the output directory and
    # distributed setup, then hand off to main().
    from args import get_parser

    args = get_parser().parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    utils.init_distributed_mode(args)
    print('Image size: {}'.format(str(args.img_size)))
    main(args)
| |
|
|