File size: 6,241 Bytes
65d7391
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
"""
Edited in September 2022
@author: fabrizio.guillaro, davide.cozzolino
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import os

from .utils.init_func import init_weight

import logging


def preprc_imagenet_torch(x):
    """Normalize a (B, 3, H, W) RGB batch with the standard ImageNet statistics."""
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=x.device)
    imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=x.device)
    # Reshape to (1, 3, 1, 1) so the stats broadcast over batch and spatial dims.
    return (x - imagenet_mean.view(1, -1, 1, 1)) / imagenet_std.view(1, -1, 1, 1)


def create_backbone(typ, norm_layer):
    """Build the encoder backbone named by *typ*.

    Returns a ``(backbone, channels)`` pair, where ``channels`` lists the
    feature widths of the four backbone stages. Only ``'mit_b2'``
    (SegFormer-B2) is currently supported.
    """
    if typ != 'mit_b2':
        raise NotImplementedError('backbone not implemented')
    logging.info('Using backbone: Segformer-B2')
    # Imported lazily so unused backbones are never loaded.
    from .encoders.dual_segformer import mit_b2 as backbone_
    return backbone_(norm_fuse=norm_layer), [64, 128, 320, 512]


class myEncoderDecoder(nn.Module):
    """Dual-branch encoder/decoder for image forgery analysis.

    Pipeline (see ``forward``): a DnCNN extracts a 1-channel Noiseprint++
    residual from the RGB input, the RGB image is ImageNet-normalized, and
    both streams go through a dual SegFormer backbone. The decoder head
    produces a per-pixel localization map; optionally a second head produces
    a confidence map and a small MLP produces a whole-image detection logit.
    """

    def __init__(self, cfg=None, norm_layer=nn.BatchNorm2d):
        # NOTE(review): cfg is assumed to be a yacs/EasyDict-style config
        # exposing MODEL.EXTRA, MODEL.MODS, MODEL.PRETRAINED and
        # DATASET.NUM_CLASSES -- confirm against the configuration files.
        super(myEncoderDecoder, self).__init__()
        
        self.norm_layer = norm_layer
        self.cfg  = cfg.MODEL.EXTRA   # architecture options (BACKBONE, DECODER, CONF, ...)
        self.mods = cfg.MODEL.MODS    # enabled modalities; 'NP++' activates the DnCNN branch
        
        # import backbone and decoder
        self.backbone, self.channels = create_backbone(self.cfg.BACKBONE, norm_layer)
        
        # Optional separate encoder for the confidence branch; when absent,
        # the confidence head reuses the localization encoder's features.
        if 'CONF_BACKBONE' in self.cfg:
            self.backbone_conf, self.channels_conf = create_backbone(self.cfg.CONF_BACKBONE, norm_layer)
        else:
            self.backbone_conf = None

        if self.cfg.DECODER == 'MLPDecoder':
            logging.info('Using MLP Decoder')
            from .decoders.MLPDecoder import DecoderHead
            # Localization head: per-pixel class scores.
            self.decode_head = DecoderHead(in_channels=self.channels, num_classes=cfg.DATASET.NUM_CLASSES, norm_layer=norm_layer, embed_dim=self.cfg.DECODER_EMBED_DIM)

            # Confidence head: single-channel per-pixel confidence logit.
            if self.cfg.CONF:
                self.decode_head_conf = DecoderHead(in_channels=self.channels, num_classes=1, norm_layer=norm_layer, embed_dim=self.cfg.DECODER_EMBED_DIM)
            else:
                self.decode_head_conf = None
            
            self.conf_detection = None
            if self.cfg.DETECTION is not None:
                if self.cfg.DETECTION == 'none':
                    pass
                elif self.cfg.DETECTION == 'confpool':
                    # Image-level detection MLP over 8 pooled statistics drawn
                    # from the confidence and localization maps (see
                    # encode_decode() for how the 8 features are assembled).
                    self.conf_detection = 'confpool'
                    assert self.cfg.CONF   # confpool detection requires the confidence head
                    self.detection  = nn.Sequential(
                            nn.Linear(in_features=8, out_features=128),
                            nn.ReLU(),
                            nn.Dropout(p=0.5),
                            nn.Linear(in_features=128, out_features=1),
                            )
                else:
                    raise NotImplementedError('Detection mechanism not implemented')

        else:
            raise NotImplementedError('decoder not implemented')

        # Noiseprint++ extractor: 17-layer DnCNN, RGB in, 1-channel residual out.
        from ..DnCNN import make_net
        num_levels = 17
        out_channel = 1
        self.dncnn = make_net(3, kernels=[3, ] * num_levels,
                       features=[64, ] * (num_levels - 1) + [out_channel],
                       bns=[False, ] + [True, ] * (num_levels - 2) + [False, ],
                       acts=['relu', ] * (num_levels - 1) + ['linear', ],
                       dilats=[1, ] * num_levels,
                       bn_momentum=0.1, padding=1)
        
        if self.cfg.PREPRC == 'imagenet': #RGB (mean and variance)
            self.prepro = preprc_imagenet_torch
        else:
            assert False   # only ImageNet preprocessing is supported
        
        self.init_weights(pretrained=cfg.MODEL.PRETRAINED)

        
    
    def init_weights(self, pretrained=None):
        """Load pretrained encoder and Noiseprint++ weights, then init the decoder.

        ``pretrained`` is forwarded to the backbone(s)' own ``init_weights``.
        The DnCNN weights come from ``cfg.NP_WEIGHTS`` (a torch checkpoint,
        optionally wrapped under a 'network' key). The localization decoder is
        always Kaiming-initialized, even when no pretrained path is given.
        """
        if pretrained:
            logging.info('Loading pretrained model: {}'.format(pretrained))
            self.backbone.init_weights(pretrained=pretrained)
            if self.backbone_conf is not None:
                self.backbone_conf.init_weights(pretrained=pretrained)

            np_weights = self.cfg.NP_WEIGHTS
            assert os.path.isfile(np_weights)
            dat = torch.load(np_weights, map_location=torch.device('cpu'))
            logging.info(f'Noiseprint++ weights: {np_weights}')
            # Some checkpoints store the state dict under a 'network' key.
            if 'network' in dat:
                dat = dat['network']
            self.dncnn.load_state_dict(dat)

        logging.info('Initing weights ...')
        # NOTE(review): only decode_head is re-initialized here; decode_head_conf
        # and the detection MLP keep PyTorch's defaults -- confirm this is intended.
        init_weight(self.decode_head, nn.init.kaiming_normal_,
                    self.norm_layer, self.cfg.BN_EPS, self.cfg.BN_MOMENTUM,
                    mode='fan_in', nonlinearity='relu')




    def encode_decode(self, rgb, modal_x):
        """Run backbone + heads; returns (localization, confidence, detection).

        Either input may be None depending on the configured modalities. The
        localization and confidence maps are upsampled back to the input's
        spatial size; ``conf`` / ``det`` are None when the corresponding head
        is disabled.
        """
        # Use whichever input is present to recover the original spatial size.
        if rgb is not None:
            orisize = rgb.shape
        else:
            orisize = modal_x.shape
        
        # cmx
        x = self.backbone(rgb, modal_x)
        out, feats = self.decode_head(x, return_feats=True)
        out = F.interpolate(out, size=orisize[2:], mode='bilinear', align_corners=False)
        
        # confidence
        if self.decode_head_conf is not None:
            if self.backbone_conf is not None:
                x_conf = self.backbone_conf(rgb, modal_x)
            else:
                x_conf = x # same encoder of Localization Network

            conf = self.decode_head_conf(x_conf)
            conf = F.interpolate(conf, size=orisize[2:], mode='bilinear', align_corners=False)
        else:
            conf = None

        
        # detection
        if self.conf_detection is not None:
            if self.conf_detection == 'confpool':
                from .layer_utils import weighted_statistics_pooling
                # f1: pooled statistics of the confidence map.
                # f2: pooled statistics of the difference between localization
                #     logit channels 1 and 0, weighted by log-sigmoid confidence.
                f1 = weighted_statistics_pooling(conf).view(out.shape[0],-1)
                f2 = weighted_statistics_pooling(out[:,1:2,:,:]-out[:,0:1,:,:], F.logsigmoid(conf)).view(out.shape[0],-1)
                det = self.detection(torch.cat((f1,f2),-1))
            else:
                assert False
        else:
            det = None
        
        return out, conf, det


    def forward(self, rgb):
        """Full inference pass from an RGB batch.

        Returns ``(localization, confidence, detection, noiseprint)``; the
        last three may be None depending on the configuration.
        """
        # Noiseprint++ extraction
        if 'NP++' in self.mods:
            modal_x = self.dncnn(rgb)
            # Replicate the 1-channel residual to 3 channels for the dual backbone.
            modal_x = torch.tile(modal_x, (3, 1, 1))
        else:
            modal_x = None

        # Normalization happens AFTER Noiseprint++ extraction, so the DnCNN
        # sees the raw (un-normalized) RGB input.
        if self.prepro is not None:
            rgb = self.prepro(rgb)

        out, conf, det = self.encode_decode(rgb, modal_x)
        return out, conf, det, modal_x