|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Edited in September 2022 |
|
|
@author: fabrizio.guillaro, davide.cozzolino |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import os |
|
|
|
|
|
from lib.models.cmx.init_func import init_weight |
|
|
|
|
|
import logging |
|
|
|
|
|
|
|
|
def preprc_imagenet_torch(x):
    """Normalize ``x`` channel-wise with fixed mean / std constants.

    The three-element constants are moved to ``x.device`` and reshaped to
    ``(1, 3, 1, 1)`` so they broadcast along dimension 1 of ``x``
    (the function name indicates these are the ImageNet statistics).

    Args:
        x: tensor broadcastable against shape ``(1, 3, 1, 1)``.

    Returns:
        ``(x - mean) / std`` computed element-wise.
    """
    channel_mean = torch.Tensor([0.485, 0.456, 0.406]).to(x.device).view(1, -1, 1, 1)
    channel_std = torch.Tensor([0.229, 0.224, 0.225]).to(x.device).view(1, -1, 1, 1)
    return (x - channel_mean) / channel_std
|
|
|
|
|
def preprc_xception_torch(x):
    """Apply the affine map ``2 * x - 1`` to ``x`` and return the result.

    An input of 0 maps to -1 and an input of 1 maps to 1.
    """
    doubled = 2.0 * x
    return doubled - 1.0
|
|
|
|
|
|
|
|
def create_backbone(typ, norm_layer):
    """Build the dual-input SegFormer backbone selected by ``typ``.

    Args:
        typ: one of ``'mit_b0'``, ``'mit_b1'``, ``'mit_b2'``, ``'mit_b4'``,
            ``'mit_b5'``.
        norm_layer: forwarded to the backbone factory as ``norm_fuse``.

    Returns:
        ``(backbone, channels)`` where ``channels`` is the four-element list
        associated with the chosen variant.

    Raises:
        NotImplementedError: if ``typ`` is not one of the supported names.
    """
    # typ -> (log line emitted on selection, channel list returned to the caller)
    variants = {
        'mit_b5': ('Using backbone: Segformer-B5', [64, 128, 320, 512]),
        'mit_b4': ('Using backbone: Segformer-B4', [64, 128, 320, 512]),
        'mit_b2': ('Using backbone: Segformer-B2', [64, 128, 320, 512]),
        'mit_b1': ('Using backbone: Segformer-B1', [64, 128, 320, 512]),
        'mit_b0': ('Using backbone: Segformer-B0', [32, 64, 160, 256]),
    }
    if typ not in variants:
        raise NotImplementedError('Backbone not implemented')

    log_line, channels = variants[typ]
    logging.info(log_line)

    # Imported lazily so the encoder package is only loaded for a valid variant.
    # The factory functions in dual_segformer carry the same names as ``typ``.
    from .encoders import dual_segformer
    build = getattr(dual_segformer, typ)
    backbone = build(norm_fuse=norm_layer)
    return backbone, channels
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EncoderDecoder(nn.Module):
    """Dual-input backbone with optional localization, confidence and detection heads.

    A DnCNN-style network (``self.dncnn``) can turn the RGB input into a second
    modality that is fed to the backbone alongside (or instead of) the RGB
    tensor. Which parts exist and which are frozen is driven by
    ``cfg.MODEL.EXTRA.MODULES`` and ``cfg.MODEL.EXTRA.FIX_MODULES``.
    """

    def __init__(self, cfg=None, norm_layer=nn.BatchNorm2d):
        """Build all sub-networks from ``cfg`` and initialise their weights.

        Args:
            cfg: configuration object; ``cfg.MODEL.EXTRA``, ``cfg.MODEL.MODS``,
                ``cfg.MODEL.PRETRAINED`` and ``cfg.DATASET.NUM_CLASSES`` are read.
            norm_layer: normalisation layer class handed to backbone and heads.

        Raises:
            AssertionError: on unknown module names, a missing ``'backbone'``
                entry, a ``det_head`` without ``conf_head``, or an unknown
                ``PREPRC`` value.
            NotImplementedError: on an unsupported backbone, decoder or
                detection mechanism.
        """
        super().__init__()

        self.norm_layer = norm_layer
        self.cfg = cfg.MODEL.EXTRA
        self.mods = cfg.MODEL.MODS

        # Number of output channels of the DnCNN network; defaults to 1.
        self.np_out_ch = self.cfg.NP_OUT_CHANNELS if 'NP_OUT_CHANNELS' in self.cfg else 1

        # Validate the requested / frozen module names.
        known_modules = ['NP++', 'backbone', 'loc_head', 'conf_head', 'det_head']
        assert all(name in known_modules for name in self.cfg.MODULES)
        assert 'backbone' in self.cfg.MODULES
        assert all(name in known_modules for name in self.cfg.FIX_MODULES)

        # Sub-modules are assigned in a fixed order: backbone, heads, DnCNN.
        self.backbone, self.channels = create_backbone(self.cfg.BACKBONE, norm_layer)

        self.decode_head = None
        self.decode_head_conf = None
        self.detection = None

        if self.cfg.DECODER != 'MLPDecoder':
            raise NotImplementedError('Decoder not implemented')

        logging.info('Using MLP Decoder')
        from .decoders.MLPDecoder import DecoderHead

        requested = self.cfg.MODULES

        # Localization head: one output channel per dataset class.
        if 'loc_head' in requested:
            self.decode_head = DecoderHead(
                in_channels=self.channels,
                num_classes=cfg.DATASET.NUM_CLASSES,
                norm_layer=norm_layer,
                embed_dim=self.cfg.DECODER_EMBED_DIM,
            )

        # Confidence head: a single output channel.
        if 'conf_head' in requested:
            self.decode_head_conf = DecoderHead(
                in_channels=self.channels,
                num_classes=1,
                norm_layer=norm_layer,
                embed_dim=self.cfg.DECODER_EMBED_DIM,
            )

        # Detection head: small MLP over 8 pooled statistics.
        self.conf_detection = self.cfg.DETECTION
        if 'det_head' in requested:
            if self.conf_detection != 'confpool':
                raise NotImplementedError('Detection mechanism not implemented')
            assert 'conf_head' in requested
            self.detection = nn.Sequential(
                nn.Linear(in_features=8, out_features=128),
                nn.ReLU(),
                nn.Dropout(p=0.5),
                nn.Linear(in_features=128, out_features=1),
            )

        # DnCNN: 17 conv layers with 3x3 kernels, batch-norm on the inner
        # layers only, ReLU everywhere except the linear last layer.
        from lib.models.DnCNN import make_net
        depth = 17
        last_features = self.np_out_ch
        self.dncnn = make_net(
            3,
            kernels=[3] * depth,
            features=[64] * (depth - 1) + [last_features],
            bns=[False] + [True] * (depth - 2) + [False],
            acts=['relu'] * (depth - 1) + ['linear'],
            dilats=[1] * depth,
            bn_momentum=0.1,
            padding=1,
        )

        # Optional RGB preprocessing applied in forward().
        preprc = self.cfg.PREPRC
        if preprc is None or preprc == 'none':
            self.prepro = None
        elif preprc == 'imagenet':
            self.prepro = preprc_imagenet_torch
        elif preprc == 'xception':
            self.prepro = preprc_xception_torch
        else:
            assert False

        self.init_weights(pretrained=cfg.MODEL.PRETRAINED)

    def init_weights(self, pretrained=None):
        """Load / initialise weights and freeze the modules in ``FIX_MODULES``.

        Args:
            pretrained: path to a backbone checkpoint; skipped when falsy.

        Raises:
            AssertionError: if a configured checkpoint path is not a file.
        """
        # DnCNN weights, taken from the 'network' entry of the checkpoint.
        np_weights = self.cfg.NP_WEIGHTS if 'NP_WEIGHTS' in self.cfg else None
        if np_weights is not None and np_weights != '':
            assert os.path.isfile(np_weights)
            checkpoint = torch.load(np_weights, map_location=torch.device('cpu'))
            state = checkpoint['network']
            logging.info(f'Noiseprint++ weights: {np_weights}')
            self.dncnn.load_state_dict(state)

        # Backbone weights.
        if pretrained:
            logging.info('Loading backbone model: {}'.format(pretrained))
            assert os.path.isfile(pretrained)
            self.backbone.init_weights(pretrained=pretrained)

        # Heads that were built get Kaiming-normal initialisation.
        logging.info('Initing heads weights ...')
        for head in (self.decode_head, self.decode_head_conf):
            if head:
                init_weight(head, nn.init.kaiming_normal_,
                            self.norm_layer, self.cfg.BN_EPS, self.cfg.BN_MOMENTUM,
                            mode='fan_in', nonlinearity='relu')

        # Disable gradients for every module listed in FIX_MODULES.
        freezable = (
            ('NP++', self.dncnn),
            ('backbone', self.backbone),
            ('loc_head', self.decode_head),
            ('conf_head', self.decode_head_conf),
        )
        for name, module in freezable:
            if name in self.cfg.FIX_MODULES:
                for param in module.parameters():
                    param.requires_grad = False

    def _run_module(self, name, module, *inputs):
        """Call ``module(*inputs)``; if ``name`` is frozen, do so in eval mode without grad."""
        if name in self.cfg.FIX_MODULES:
            with torch.no_grad():
                module.eval()
                return module(*inputs)
        return module(*inputs)

    def encode_decode(self, rgb, modal_x):
        """Run backbone and heads on the two input modalities.

        Args:
            rgb: RGB tensor or ``None``.
            modal_x: second-modality tensor; its shape is used for the output
                size when ``rgb`` is ``None``.

        Returns:
            ``(out, conf, det)``: ``out`` and ``conf`` are bilinearly resized to
            the input's trailing (``shape[2:]``) dimensions; ``conf`` and
            ``det`` are ``None`` when the corresponding head is absent.
        """
        reference = rgb if rgb is not None else modal_x
        orisize = reference.shape

        features = self._run_module('backbone', self.backbone, rgb, modal_x)

        # Localization map.
        out = self._run_module('loc_head', self.decode_head, features)
        out = F.interpolate(out, size=orisize[2:], mode='bilinear', align_corners=False)

        # Confidence map.
        if self.decode_head_conf:
            conf = self._run_module('conf_head', self.decode_head_conf, features)
            conf = F.interpolate(conf, size=orisize[2:], mode='bilinear', align_corners=False)
        else:
            conf = None

        # Detection score from pooled statistics of conf and of the
        # channel-1-minus-channel-0 difference of out.
        if not self.detection:
            det = None
        elif self.conf_detection == 'confpool':
            from .layer_utils import weighted_statistics_pooling
            batch = out.shape[0]
            conf_stats = weighted_statistics_pooling(conf).view(batch, -1)
            logit_gap = out[:, 1:2, :, :] - out[:, 0:1, :, :]
            gap_stats = weighted_statistics_pooling(logit_gap, F.logsigmoid(conf)).view(batch, -1)
            det = self.detection(torch.cat((conf_stats, gap_stats), -1))
        else:
            assert False

        return out, conf, det

    def forward(self, rgb, save_np=False):
        """Compute the model outputs for an RGB input.

        Args:
            rgb: input tensor.
            save_np: when true, the DnCNN output is returned as the 4th element.

        Returns:
            ``(out, conf, det, modal_x)``; the last element is ``None`` unless
            ``save_np`` is true (and it is also ``None`` when ``'NP++'`` is not
            in ``self.mods``).
        """
        if 'NP++' in self.mods:
            modal_x = self._run_module('NP++', self.dncnn, rgb)
            if self.np_out_ch == 1:
                # reps has 3 entries; for a 4-D input torch.tile left-pads it with
                # a 1, so this repeats dim 1 three times (assumes modal_x is 4-D
                # -- TODO confirm against make_net's output).
                modal_x = torch.tile(modal_x, (3, 1, 1))
            else:
                assert self.np_out_ch == 3
        else:
            modal_x = None

        if 'RGB' not in self.mods:
            rgb = None
        elif self.prepro is not None:
            rgb = self.prepro(rgb)

        out, conf, det = self.encode_decode(rgb, modal_x)

        return out, conf, det, (modal_x if save_np else None)
|
|
|