"""
Encoder-decoder model that fuses an RGB image with its Noiseprint++ map.

Edited in September 2022
@author: fabrizio.guillaro, davide.cozzolino
"""

import logging
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils.init_func import init_weight


def preprc_imagenet_torch(x):
    """Normalize ``x`` with the ImageNet per-channel RGB mean and std.

    Args:
        x: channel-first tensor with 3 channels on dim 1, e.g. ``(B, 3, H, W)``.
            Presumably RGB scaled to ``[0, 1]`` -- TODO confirm with the loader.

    Returns:
        ``(x - mean) / std``, broadcast over the channel axis.
    """
    mean = torch.Tensor([0.485, 0.456, 0.406]).to(x.device)
    std = torch.Tensor([0.229, 0.224, 0.225]).to(x.device)
    x = (x - mean[None, :, None, None]) / std[None, :, None, None]
    return x


def create_backbone(typ, norm_layer):
    """Build an encoder backbone by name.

    Args:
        typ: backbone identifier; only ``'mit_b2'`` (Segformer-B2) is supported.
        norm_layer: normalization layer class, forwarded as ``norm_fuse``.

    Returns:
        ``(backbone, channels)``, where ``channels`` is the list of feature
        widths used as ``in_channels`` when building the decoder heads.

    Raises:
        NotImplementedError: if ``typ`` is not a known backbone.
    """
    if typ == 'mit_b2':
        logging.info('Using backbone: Segformer-B2')
        # Imported inside the branch, so the encoder module is only loaded
        # when this backbone is actually selected.
        from .encoders.dual_segformer import mit_b2 as backbone_
        backbone = backbone_(norm_fuse=norm_layer)
        channels = [64, 128, 320, 512]
    else:
        raise NotImplementedError('backbone not implemented')
    return backbone, channels


class myEncoderDecoder(nn.Module):
    """RGB + Noiseprint++ encoder-decoder with optional confidence/detection heads.

    ``forward`` returns ``(out, conf, det, modal_x)``:

    * ``out``: localization logits from ``decode_head``, resized to the input's
      spatial size, with ``cfg.DATASET.NUM_CLASSES`` channels.
    * ``conf``: 1-channel confidence map from ``decode_head_conf``, or ``None``
      when ``MODEL.EXTRA.CONF`` is off.
    * ``det``: output of the ``'confpool'`` detection head, or ``None``.
    * ``modal_x``: Noiseprint++ map produced by ``dncnn``, or ``None`` when
      ``'NP++'`` is not in ``cfg.MODEL.MODS``.
    """

    def __init__(self, cfg=None, norm_layer=nn.BatchNorm2d):
        """Build all sub-networks from ``cfg`` and load/initialize weights.

        Args:
            cfg: full experiment config (required despite the ``None`` default).
                Reads ``cfg.MODEL.EXTRA`` (kept as ``self.cfg``),
                ``cfg.MODEL.MODS``, ``cfg.MODEL.PRETRAINED`` and
                ``cfg.DATASET.NUM_CLASSES``.
            norm_layer: normalization layer class for backbone(s) and decoders.

        Raises:
            NotImplementedError: unsupported backbone, decoder, detection
                mechanism or preprocessing.
            ValueError: ``DETECTION == 'confpool'`` without ``CONF`` enabled.
        """
        super().__init__()
        self.norm_layer = norm_layer
        # NOTE: self.cfg is only the MODEL.EXTRA sub-config, not the full cfg.
        self.cfg = cfg.MODEL.EXTRA
        self.mods = cfg.MODEL.MODS

        # import backbone and decoder
        self.backbone, self.channels = create_backbone(self.cfg.BACKBONE, norm_layer)
        if 'CONF_BACKBONE' in self.cfg:
            # Dedicated encoder for the confidence branch.
            self.backbone_conf, self.channels_conf = create_backbone(self.cfg.CONF_BACKBONE, norm_layer)
        else:
            self.backbone_conf = None

        if self.cfg.DECODER == 'MLPDecoder':
            logging.info('Using MLP Decoder')
            from .decoders.MLPDecoder import DecoderHead
            self.decode_head = DecoderHead(in_channels=self.channels,
                                           num_classes=cfg.DATASET.NUM_CLASSES,
                                           norm_layer=norm_layer,
                                           embed_dim=self.cfg.DECODER_EMBED_DIM)
            if self.cfg.CONF:
                # The confidence head consumes backbone_conf features when a
                # dedicated confidence backbone exists, so size it from that
                # backbone's channels rather than the main one's.
                conf_channels = self.channels_conf if self.backbone_conf is not None else self.channels
                self.decode_head_conf = DecoderHead(in_channels=conf_channels,
                                                    num_classes=1,
                                                    norm_layer=norm_layer,
                                                    embed_dim=self.cfg.DECODER_EMBED_DIM)
            else:
                self.decode_head_conf = None

            self.conf_detection = None
            if self.cfg.DETECTION not in (None, 'none'):
                if self.cfg.DETECTION != 'confpool':
                    raise NotImplementedError('Detection mechanism not implemented')
                # 'confpool' pools statistics of the confidence map, so the
                # confidence head must exist.
                if not self.cfg.CONF:
                    raise ValueError("DETECTION='confpool' requires MODEL.EXTRA.CONF to be enabled")
                self.conf_detection = 'confpool'
                # 8 input features = concatenation of the two pooled vectors
                # computed in encode_decode.
                self.detection = nn.Sequential(
                    nn.Linear(in_features=8, out_features=128),
                    nn.ReLU(),
                    nn.Dropout(p=0.5),
                    nn.Linear(in_features=128, out_features=1),
                )
        else:
            raise NotImplementedError('decoder not implemented')

        # Noiseprint++ extractor: 17-layer DnCNN with 3x3 kernels, 64 features,
        # BN on the inner layers and a linear single-channel output layer.
        from ..DnCNN import make_net
        num_levels = 17
        out_channel = 1
        self.dncnn = make_net(3, kernels=[3, ] * num_levels,
                              features=[64, ] * (num_levels - 1) + [out_channel],
                              bns=[False, ] + [True, ] * (num_levels - 2) + [False, ],
                              acts=['relu', ] * (num_levels - 1) + ['linear', ],
                              dilats=[1, ] * num_levels,
                              bn_momentum=0.1, padding=1)

        if self.cfg.PREPRC == 'imagenet':  # RGB per-channel mean/std normalization
            self.prepro = preprc_imagenet_torch
        else:
            # Raise instead of `assert False`: asserts vanish under `python -O`,
            # which would leave self.prepro undefined until forward() fails.
            raise NotImplementedError('preprocessing not implemented')

        self.init_weights(pretrained=cfg.MODEL.PRETRAINED)

    def init_weights(self, pretrained=None):
        """Load pretrained backbone / Noiseprint++ weights and init the decoder.

        Args:
            pretrained: forwarded to the backbone(s)' ``init_weights``
                (presumably a checkpoint path). When falsy, neither the
                backbones nor ``dncnn`` load any weights.

        Raises:
            FileNotFoundError: ``pretrained`` is set but ``MODEL.EXTRA.NP_WEIGHTS``
                is not an existing file.
        """
        if pretrained:
            logging.info('Loading pretrained model: {}'.format(pretrained))
            self.backbone.init_weights(pretrained=pretrained)
            if self.backbone_conf is not None:
                self.backbone_conf.init_weights(pretrained=pretrained)

            np_weights = self.cfg.NP_WEIGHTS
            if not os.path.isfile(np_weights):
                raise FileNotFoundError(f'Noiseprint++ weights not found: {np_weights}')
            # SECURITY: torch.load unpickles arbitrary objects -- only load
            # trusted weight files.
            dat = torch.load(np_weights, map_location=torch.device('cpu'))
            logging.info(f'Noiseprint++ weights: {np_weights}')
            # Accept both a bare state dict and one nested under 'network'.
            if 'network' in dat:
                dat = dat['network']
            self.dncnn.load_state_dict(dat)

        logging.info('Initing weights ...')
        # Only decode_head is re-initialized here; decode_head_conf and the
        # detection head keep the initialization they got at construction.
        init_weight(self.decode_head, nn.init.kaiming_normal_,
                    self.norm_layer, self.cfg.BN_EPS, self.cfg.BN_MOMENTUM,
                    mode='fan_in', nonlinearity='relu')

    def encode_decode(self, rgb, modal_x):
        """Run the encoder(s) and the localization / confidence / detection heads.

        Args:
            rgb: preprocessed image batch, or ``None``.
            modal_x: Noiseprint++ map, or ``None``. The spatial size (dims 2:)
                of ``rgb`` -- or of ``modal_x`` when ``rgb`` is ``None`` -- sets
                the output resolution, so at least one must be provided.

        Returns:
            ``(out, conf, det)``; ``conf`` and ``det`` may be ``None``.
        """
        orisize = rgb.shape if rgb is not None else modal_x.shape

        # cmx: joint encoding of both modalities
        x = self.backbone(rgb, modal_x)
        out, _feats = self.decode_head(x, return_feats=True)
        out = F.interpolate(out, size=orisize[2:], mode='bilinear', align_corners=False)

        # confidence
        conf = None
        if self.decode_head_conf is not None:
            if self.backbone_conf is not None:
                x_conf = self.backbone_conf(rgb, modal_x)
            else:
                x_conf = x  # same encoder of Localization Network
            conf = self.decode_head_conf(x_conf)
            conf = F.interpolate(conf, size=orisize[2:], mode='bilinear', align_corners=False)

        # detection
        det = None
        if self.conf_detection is not None:
            if self.conf_detection != 'confpool':
                raise NotImplementedError('Detection mechanism not implemented')
            from .layer_utils import weighted_statistics_pooling
            batch = out.shape[0]
            f1 = weighted_statistics_pooling(conf).view(batch, -1)
            # Logit difference (channel 1 minus channel 0), pooled with
            # logsigmoid(conf) as the second (presumably weighting) argument.
            f2 = weighted_statistics_pooling(out[:, 1:2, :, :] - out[:, 0:1, :, :],
                                             F.logsigmoid(conf)).view(batch, -1)
            det = self.detection(torch.cat((f1, f2), -1))

        return out, conf, det

    def forward(self, rgb):
        """Extract Noiseprint++, normalize the image and run encode_decode.

        Args:
            rgb: image batch. ``dncnn`` sees it before ``self.prepro`` is
                applied; the encoder sees it after.

        Returns:
            ``(out, conf, det, modal_x)``.
        """
        # Noiseprint++ extraction
        modal_x = None
        if 'NP++' in self.mods:
            modal_x = self.dncnn(rgb)
            # torch.tile left-pads reps with ones, so for a 4-D tensor (3, 1, 1)
            # acts as (1, 3, 1, 1): the single channel is repeated 3 times.
            modal_x = torch.tile(modal_x, (3, 1, 1))

        if self.prepro is not None:
            rgb = self.prepro(rgb)

        out, conf, det = self.encode_decode(rgb, modal_x)
        return out, conf, det, modal_x