"""
Edited in September 2022
@author: fabrizio.guillaro, davide.cozzolino
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from .utils.init_func import init_weight
import logging
def preprc_imagenet_torch(x):
    """Normalize an NCHW RGB batch with the standard ImageNet channel statistics."""
    imagenet_mean = torch.Tensor([0.485, 0.456, 0.406]).to(x.device).view(1, 3, 1, 1)
    imagenet_std = torch.Tensor([0.229, 0.224, 0.225]).to(x.device).view(1, 3, 1, 1)
    # broadcast (1, 3, 1, 1) statistics over the (N, 3, H, W) batch
    return (x - imagenet_mean) / imagenet_std
def create_backbone(typ, norm_layer):
    """Instantiate the encoder backbone named by ``typ``.

    Returns a ``(backbone, channels)`` pair, where ``channels`` lists the
    per-stage feature widths. Only 'mit_b2' (Segformer-B2) is supported;
    any other name raises NotImplementedError.
    """
    channels = [64, 128, 320, 512]
    if typ != 'mit_b2':
        raise NotImplementedError('backbone not implemented')
    logging.info('Using backbone: Segformer-B2')
    from .encoders.dual_segformer import mit_b2 as backbone_
    return backbone_(norm_fuse=norm_layer), channels
class myEncoderDecoder(nn.Module):
    """TruFor-style encoder-decoder for forgery localization and detection.

    Couples a Noiseprint++ extractor (a DnCNN, built in ``__init__``) with a
    dual-input backbone and an MLP decoder head. Depending on the config it
    also builds a single-channel confidence head and a small MLP detection
    head. ``forward`` returns ``(out, conf, det, modal_x)``; ``conf``/``det``
    are None when the corresponding head is disabled.
    """
    def __init__(self, cfg=None, norm_layer=nn.BatchNorm2d):
        """Build all sub-modules from ``cfg``.

        ``cfg.MODEL.EXTRA`` drives the architecture choices (BACKBONE,
        DECODER, CONF, DETECTION, PREPRC, ...). Raises NotImplementedError
        for unsupported DECODER/DETECTION options; asserts PREPRC=='imagenet'.
        """
        super(myEncoderDecoder, self).__init__()
        self.norm_layer = norm_layer
        self.cfg = cfg.MODEL.EXTRA
        self.mods = cfg.MODEL.MODS  # active input modalities, e.g. contains 'NP++'
        # import backbone and decoder
        self.backbone, self.channels = create_backbone(self.cfg.BACKBONE, norm_layer)
        # optional separate encoder dedicated to the confidence branch;
        # when absent, the confidence head reuses the main backbone features
        if 'CONF_BACKBONE' in self.cfg:
            self.backbone_conf, self.channels_conf = create_backbone(self.cfg.CONF_BACKBONE, norm_layer)
        else:
            self.backbone_conf = None
        if self.cfg.DECODER == 'MLPDecoder':
            logging.info('Using MLP Decoder')
            from .decoders.MLPDecoder import DecoderHead
            self.decode_head = DecoderHead(in_channels=self.channels, num_classes=cfg.DATASET.NUM_CLASSES, norm_layer=norm_layer, embed_dim=self.cfg.DECODER_EMBED_DIM)
            # single-channel confidence head sharing the decoder architecture
            if self.cfg.CONF:
                self.decode_head_conf = DecoderHead(in_channels=self.channels, num_classes=1, norm_layer=norm_layer, embed_dim=self.cfg.DECODER_EMBED_DIM)
            else:
                self.decode_head_conf = None
            self.conf_detection = None
            if self.cfg.DETECTION is not None:
                if self.cfg.DETECTION == 'none':
                    pass
                elif self.cfg.DETECTION == 'confpool':
                    # detection MLP over pooled statistics of the localization
                    # and confidence maps (computed in encode_decode); requires
                    # the confidence head. in_features=8 must equal the pooled
                    # feature size (f1 and f2 concatenated) — presumably 4
                    # statistics per map; TODO confirm against
                    # weighted_statistics_pooling.
                    self.conf_detection = 'confpool'
                    assert self.cfg.CONF
                    self.detection = nn.Sequential(
                        nn.Linear(in_features=8, out_features=128),
                        nn.ReLU(),
                        nn.Dropout(p=0.5),
                        nn.Linear(in_features=128, out_features=1),
                    )
                else:
                    raise NotImplementedError('Detection mechanism not implemented')
        else:
            raise NotImplementedError('decoder not implemented')
        # Noiseprint++ extractor: a 17-level DnCNN mapping 3-channel RGB to a
        # 1-channel noise-residual map (first/last layers without BN, all but
        # the last followed by ReLU)
        from ..DnCNN import make_net
        num_levels = 17
        out_channel = 1
        self.dncnn = make_net(3, kernels=[3, ] * num_levels,
                              features=[64, ] * (num_levels - 1) + [out_channel],
                              bns=[False, ] + [True, ] * (num_levels - 2) + [False, ],
                              acts=['relu', ] * (num_levels - 1) + ['linear', ],
                              dilats=[1, ] * num_levels,
                              bn_momentum=0.1, padding=1)
        if self.cfg.PREPRC == 'imagenet': #RGB (mean and variance)
            self.prepro = preprc_imagenet_torch
        else:
            assert False
        self.init_weights(pretrained=cfg.MODEL.PRETRAINED)
    def init_weights(self, pretrained=None):
        """Load pretrained backbone and Noiseprint++ weights, then apply
        kaiming init to the localization decoder.

        NOTE(review): only ``decode_head`` is explicitly initialised here;
        ``decode_head_conf`` and ``detection`` keep their default init —
        confirm this is intentional.
        """
        if pretrained:
            logging.info('Loading pretrained model: {}'.format(pretrained))
            self.backbone.init_weights(pretrained=pretrained)
            if self.backbone_conf is not None:
                self.backbone_conf.init_weights(pretrained=pretrained)
        # Noiseprint++ weights are mandatory; checkpoints may wrap the
        # state dict under a 'network' key
        np_weights = self.cfg.NP_WEIGHTS
        assert os.path.isfile(np_weights)
        dat = torch.load(np_weights, map_location=torch.device('cpu'))
        logging.info(f'Noiseprint++ weights: {np_weights}')
        if 'network' in dat:
            dat = dat['network']
        self.dncnn.load_state_dict(dat)
        logging.info('Initing weights ...')
        init_weight(self.decode_head, nn.init.kaiming_normal_,
            self.norm_layer, self.cfg.BN_EPS, self.cfg.BN_MOMENTUM,
            mode='fan_in', nonlinearity='relu')
    def encode_decode(self, rgb, modal_x):
        """Run backbone and heads; return ``(out, conf, det)``.

        ``out`` and ``conf`` are upsampled to the input spatial size;
        ``conf``/``det`` are None when the matching head is disabled.
        """
        # output spatial size follows whichever input is available
        if rgb is not None:
            orisize = rgb.shape
        else:
            orisize = modal_x.shape
        # cmx: dual-input backbone consumes RGB and the extra modality jointly
        x = self.backbone(rgb, modal_x)
        out, feats = self.decode_head(x, return_feats=True)  # feats unused here
        out = F.interpolate(out, size=orisize[2:], mode='bilinear', align_corners=False)
        # confidence
        if self.decode_head_conf is not None:
            if self.backbone_conf is not None:
                x_conf = self.backbone_conf(rgb, modal_x)
            else:
                x_conf = x # same encoder of Localization Network
            conf = self.decode_head_conf(x_conf)
            conf = F.interpolate(conf, size=orisize[2:], mode='bilinear', align_corners=False)
        else:
            conf = None
        # detection: pool statistics of the confidence map (f1) and of the
        # confidence-weighted class-logit difference (f2), then score with
        # the detection MLP
        if self.conf_detection is not None:
            if self.conf_detection == 'confpool':
                from .layer_utils import weighted_statistics_pooling
                f1 = weighted_statistics_pooling(conf).view(out.shape[0],-1)
                f2 = weighted_statistics_pooling(out[:,1:2,:,:]-out[:,0:1,:,:], F.logsigmoid(conf)).view(out.shape[0],-1)
                det = self.detection(torch.cat((f1,f2),-1))
            else:
                assert False
        else:
            det = None
        return out, conf, det
    def forward(self, rgb):
        """Full pipeline on an RGB batch: extract Noiseprint++ (if enabled),
        normalize RGB, then localize/score.

        Returns ``(out, conf, det, modal_x)`` where ``modal_x`` is the
        (replicated) Noiseprint++ map, or None when 'NP++' is not in MODS.
        Note the DnCNN sees the raw RGB; only the backbone input is normalized.
        """
        # Noiseprint++ extraction
        if 'NP++' in self.mods:
            modal_x = self.dncnn(rgb)
            # replicate the single-channel residual to 3 channels for the
            # dual backbone — assumes modal_x is (B, 1, H, W); TODO confirm
            modal_x = torch.tile(modal_x, (3, 1, 1))
        else:
            modal_x = None
        if self.prepro is not None:
            rgb = self.prepro(rgb)
        out, conf, det = self.encode_decode(rgb, modal_x)
        return out, conf, det, modal_x