File size: 4,409 Bytes
ea1014e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from .segmenter import CGFormer
from .segmenter_rcc import CGFormer_RCC_sbert
from .segmenter_sbert import CGFormer_sbert
from .segmenter_rz_sbert import CGFormer_Refzomsbert
from loguru import logger
import torch
import torch.nn as nn
from .backbone import MultiModalSwinTransformer


def build_model(args):
    """Construct a CGFormer-family model on top of a multi-modal Swin backbone.

    Args:
        args: Configuration namespace. Read here: ``swin_type`` ('tiny',
            'small', 'base' or 'large'), ``swin_pretrain`` (checkpoint path,
            or a falsy value for random initialisation), ``window12``,
            ``mha`` (dash-separated ints such as '1-1-1-1', or falsy),
            ``fusion_drop``, ``dataset`` and ``metric_learning``. The whole
            namespace is also forwarded to the selected model class.

    Returns:
        A ``CGFormer_sbert``, ``CGFormer_RCC_sbert``, ``CGFormer_Refzomsbert``
        or ``CGFormer`` instance, chosen from ``args.dataset`` and
        ``args.metric_learning``.

    Raises:
        ValueError: If ``args.swin_type`` is not one of the known variants.
    """
    # Swin variant -> (embed_dim, depths, num_heads). The literal is rebuilt on
    # every call, so no config list is ever shared between backbones.
    swin_configs = {
        'tiny': (96, [2, 2, 6, 2], [3, 6, 12, 24]),
        'small': (96, [2, 2, 18, 2], [3, 6, 12, 24]),
        'base': (128, [2, 2, 18, 2], [4, 8, 16, 32]),
        'large': (192, [2, 2, 18, 2], [6, 12, 24, 48]),
    }
    try:
        embed_dim, depths, num_heads = swin_configs[args.swin_type]
    except KeyError:
        # An explicit error rather than `assert False`: asserts are stripped
        # under `python -O`, which left embed_dim unbound further down.
        raise ValueError(
            f"Unknown swin_type {args.swin_type!r}; "
            f"expected one of {sorted(swin_configs)}"
        ) from None

    # swin_pretrain may be empty/None when training from scratch; normalise it
    # so the substring test below does not raise TypeError on None.
    swin_pretrain = args.swin_pretrain or ''
    # args.window12 was added for test.py because the state_dict is loaded
    # after model initialization.
    if 'window12' in swin_pretrain or args.window12:
        logger.info('Window size 12!')
        window_size = 12
    else:
        window_size = 7

    # Per-stage values passed to the backbone as num_heads_fusion,
    # e.g. '1-1-1-1' -> [1, 1, 1, 1]; defaults to one per stage.
    mha = [int(heads) for heads in args.mha.split('-')] if args.mha else [1, 1, 1, 1]

    backbone = MultiModalSwinTransformer(embed_dim=embed_dim, depths=depths, num_heads=num_heads,
                                         window_size=window_size,
                                         ape=False, drop_path_rate=0.3, patch_norm=True,
                                         out_indices=(0, 1, 2, 3),
                                         use_checkpoint=False, num_heads_fusion=mha,
                                         fusion_drop=args.fusion_drop,
                                         )
    if swin_pretrain:
        logger.info('Initializing Multi-modal Swin Transformer weights from ' + swin_pretrain)
        backbone.init_weights(pretrained=swin_pretrain)
    else:
        logger.info('Randomly initialize Multi-modal Swin Transformer weights.')
        backbone.init_weights()

    # Metric-learning variants are dataset specific; all other runs fall back
    # to the plain CGFormer (used for reproduction on every dataset).
    if ('refcocog' in args.dataset) and args.metric_learning:
        model = CGFormer_sbert(backbone, args)
    elif (args.dataset in ['refcoco', 'refcoco+']) and args.metric_learning:
        model = CGFormer_RCC_sbert(backbone, args)
    elif ('ref-zom' in args.dataset) and args.metric_learning:
        logger.info("Loading CGFormer_Refzomsbert...")
        model = CGFormer_Refzomsbert(backbone, args)
    else:
        model = CGFormer(backbone, args)

    return model

        
def build_segmenter(args, DDP=True, OPEN=False):
    """Build the segmentation model and, for distributed training, its param groups.

    Args:
        args: Configuration namespace, forwarded to ``build_model``. When
            ``DDP`` is true this also reads ``sync_bn``, ``lr_backbone`` and
            ``lr_text_encoder``.
        DDP: If true, optionally convert BatchNorm layers to SyncBatchNorm and
            also return optimizer parameter groups. The model is NOT wrapped
            in DistributedDataParallel here; per the original note, DeepSpeed
            is expected to take care of that.
        OPEN: If true (only honoured when ``DDP`` is true), freeze every
            parameter of ``model.backbone`` before building the groups.

    Returns:
        ``(model, param_list)`` when ``DDP`` is true, otherwise just ``model``.
        ``param_list`` holds four groups, in this order, each containing only
        parameters with ``requires_grad`` set:

        1. parameters whose name contains none of 'pwam', 'backbone',
           'text_encoder' (optimizer default lr);
        2. parameters whose name contains 'pwam' (optimizer default lr);
        3. remaining parameters whose name contains 'backbone'
           (``lr=args.lr_backbone``);
        4. remaining parameters whose name contains 'text_encoder'
           (``lr=args.lr_text_encoder``).
    """
    model = build_model(args)
    if not DDP:
        return model

    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if OPEN:
        for param in model.backbone.parameters():
            param.requires_grad_(False)

    # Single pass with precedence pwam > backbone > text_encoder > rest, so
    # every parameter lands in exactly one group: torch optimizers reject a
    # parameter that appears in more than one group, which the previous
    # overlapping name filters could produce.
    head_params, pwam_params, backbone_params, text_params = [], [], [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "pwam" in name:
            pwam_params.append(param)
        elif "backbone" in name:
            backbone_params.append(param)
        elif "text_encoder" in name:
            text_params.append(param)
        else:
            head_params.append(param)

    param_list = [
        {"params": head_params},
        {"params": pwam_params},
        {"params": backbone_params, "lr": args.lr_backbone},
        {"params": text_params, "lr": args.lr_text_encoder},
    ]
    return model, param_list