import torch
import torch.nn as nn
from loguru import logger

from .backbone import MultiModalSwinTransformer
from .segmenter import CGFormer
from .segmenter_rcc import CGFormer_RCC_sbert
from .segmenter_rz_sbert import CGFormer_Refzomsbert
from .segmenter_sbert import CGFormer_sbert
|
|
def build_model(args): |
    # Select the Swin Transformer variant that backs the multi-modal encoder.
    if args.swin_type == 'tiny':
        embed_dim = 96
        depths = [2, 2, 6, 2]
        num_heads = [3, 6, 12, 24]
    elif args.swin_type == 'small':
        embed_dim = 96
        depths = [2, 2, 18, 2]
        num_heads = [3, 6, 12, 24]
    elif args.swin_type == 'base':
        embed_dim = 128
        depths = [2, 2, 18, 2]
        num_heads = [4, 8, 16, 32]
    elif args.swin_type == 'large':
        embed_dim = 192
        depths = [2, 2, 18, 2]
        num_heads = [6, 12, 24, 48]
    else:
        raise ValueError(f"Unsupported swin_type: {args.swin_type!r}")
|
    # Checkpoints pre-trained with a 12x12 window (e.g. 384x384 inputs) need a
    # matching window size; otherwise fall back to the standard 7x7 window.
    if (args.swin_pretrain and 'window12' in args.swin_pretrain) or args.window12:
        logger.info('Window size 12!')
        window_size = 12
    else:
        window_size = 7
|
    # Per-stage head counts for the cross-modal fusion modules, e.g. '1-1-1-1'.
    if args.mha:
        mha = [int(a) for a in args.mha.split('-')]
    else:
        mha = [1, 1, 1, 1]
|
    out_indices = (0, 1, 2, 3)
    backbone = MultiModalSwinTransformer(embed_dim=embed_dim,
                                         depths=depths,
                                         num_heads=num_heads,
                                         window_size=window_size,
                                         ape=False,
                                         drop_path_rate=0.3,
                                         patch_norm=True,
                                         out_indices=out_indices,
                                         use_checkpoint=False,
                                         num_heads_fusion=mha,
                                         fusion_drop=args.fusion_drop)
|
    if args.swin_pretrain:
        logger.info('Initializing Multi-modal Swin Transformer weights from ' + args.swin_pretrain)
        backbone.init_weights(pretrained=args.swin_pretrain)
    else:
        logger.info('Randomly initializing Multi-modal Swin Transformer weights.')
        backbone.init_weights()
|
    # Pick the segmenter head: the metric-learning variants are dataset-specific;
    # everything else falls back to the plain CGFormer head.
    if ('refcocog' in args.dataset) and args.metric_learning:
        model = CGFormer_sbert(backbone, args)
    elif (args.dataset in ['refcoco', 'refcoco+']) and args.metric_learning:
        model = CGFormer_RCC_sbert(backbone, args)
    elif ('ref-zom' in args.dataset) and args.metric_learning:
        logger.info('Loading CGFormer_Refzomsbert...')
        model = CGFormer_Refzomsbert(backbone, args)
    else:
        model = CGFormer(backbone, args)

    return model
|
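# Usage sketch (hedged): build_model only needs the argparse attributes read
# above; the flag spellings below are illustrative assumptions, not the repo's
# definitive CLI, and the segmenter classes may read further attributes.
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--swin_type', default='base',
#                         choices=['tiny', 'small', 'base', 'large'])
#     parser.add_argument('--swin_pretrain', default='')
#     parser.add_argument('--window12', action='store_true')
#     parser.add_argument('--mha', default='')  # e.g. '1-1-1-1'
#     parser.add_argument('--fusion_drop', type=float, default=0.0)
#     parser.add_argument('--dataset', default='refcoco')
#     parser.add_argument('--metric_learning', action='store_true')
#     model = build_model(parser.parse_args())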
|
def build_segmenter(args, DDP=True, OPEN=False): |
    model = build_model(args)
    if DDP:
        # Note: DistributedDataParallel wrapping is not done here; this branch
        # only prepares SyncBatchNorm and the per-group learning-rate spec.
        if args.sync_bn:
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

        single_model = model
        if OPEN:
            # Freeze the multi-modal backbone so only the remaining modules train.
            for p in single_model.backbone.parameters():
                p.requires_grad_(False)
|
        param_list = [
            {
                # Everything outside the backbone and text encoder trains at the
                # optimizer's default learning rate.
                "params": [
                    p
                    for n, p in single_model.named_parameters()
                    if "backbone" not in n and "text_encoder" not in n and p.requires_grad
                ],
            },
            {
                # PWAM fusion modules also use the default learning rate.
                "params": [
                    p
                    for n, p in single_model.named_parameters()
                    if "pwam" in n and p.requires_grad
                ],
            },
            {
                # The rest of the backbone trains at its own (typically smaller) rate.
                "params": [
                    p
                    for n, p in single_model.named_parameters()
                    if "backbone" in n and "pwam" not in n and p.requires_grad
                ],
                "lr": args.lr_backbone,
            },
            {
                # The text encoder likewise gets a dedicated learning rate.
                "params": [
                    p
                    for n, p in single_model.named_parameters()
                    if "text_encoder" in n and p.requires_grad
                ],
                "lr": args.lr_text_encoder,
            },
        ]

        return model, param_list
    else:
        return model
|
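# Usage sketch (hedged): param_list is a standard PyTorch per-group optimizer
# spec; groups without an explicit 'lr' fall back to the optimizer's default.
# The optimizer choice and the rates below are illustrative assumptions only.
#
#     model, param_list = build_segmenter(args, DDP=True)
#     optimizer = torch.optim.AdamW(param_list, lr=5e-4, weight_decay=1e-2)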
|
|