# Image_Classifier/trufor_native/models/cmx/builder_np_conf.py
# Jatin-tec
# Add application file
# 65d7391
"""
Edited in September 2022
@author: fabrizio.guillaro, davide.cozzolino
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from .utils.init_func import init_weight
import logging
def preprc_imagenet_torch(x):
    """Normalize a batch with the ImageNet per-channel mean and std.

    The three-element statistics are moved to ``x.device`` and broadcast
    over axes 0, 2 and 3, so ``x`` must be broadcastable against a
    ``(1, 3, 1, 1)`` tensor (i.e. carry its channels on axis 1).

    Args:
        x: input tensor to normalize.

    Returns:
        A new tensor equal to ``(x - mean) / std``.
    """
    device = x.device
    # Reshape the per-channel statistics to (1, 3, 1, 1) for broadcasting.
    imagenet_mean = torch.Tensor([0.485, 0.456, 0.406]).to(device).view(1, -1, 1, 1)
    imagenet_std = torch.Tensor([0.229, 0.224, 0.225]).to(device).view(1, -1, 1, 1)
    return (x - imagenet_mean) / imagenet_std
def create_backbone(typ, norm_layer):
    """Instantiate the encoder backbone selected by ``typ``.

    Only ``'mit_b2'`` is supported; it is built from
    ``.encoders.dual_segformer.mit_b2`` with ``norm_fuse=norm_layer``.

    Args:
        typ: backbone identifier string.
        norm_layer: normalization layer forwarded to the backbone as
            ``norm_fuse``.

    Returns:
        A ``(backbone, channels)`` tuple, where ``channels`` is the list
        ``[64, 128, 320, 512]``.

    Raises:
        NotImplementedError: if ``typ`` is not a supported backbone.
    """
    if typ != 'mit_b2':
        raise NotImplementedError('backbone not implemented')

    logging.info('Using backbone: Segformer-B2')
    # Imported lazily, only once the backbone type is known to be supported.
    from .encoders.dual_segformer import mit_b2 as build_mit_b2

    channel_dims = [64, 128, 320, 512]
    return build_mit_b2(norm_fuse=norm_layer), channel_dims
class myEncoderDecoder(nn.Module):
    """Encoder-decoder network with optional confidence and detection heads.

    The module combines:

    * a backbone fed with the RGB image and a second modality
      (the Noiseprint++ map produced by ``self.dncnn`` when ``'NP++'`` is
      listed in ``cfg.MODEL.MODS``),
    * a decoder head producing the localization output,
    * optionally a confidence decoder head (``cfg.MODEL.EXTRA.CONF``), which
      may use its own backbone (``cfg.MODEL.EXTRA.CONF_BACKBONE``),
    * optionally a ``'confpool'`` detection head that turns pooled statistics
      of the localization/confidence maps into a single score.
    """

    def __init__(self, cfg=None, norm_layer=nn.BatchNorm2d):
        """Build all sub-networks described by ``cfg`` and initialize weights.

        Args:
            cfg: configuration object. Despite the ``None`` default it is
                required: ``cfg.MODEL.EXTRA``, ``cfg.MODEL.MODS``,
                ``cfg.MODEL.PRETRAINED`` and ``cfg.DATASET.NUM_CLASSES`` are
                read.
            norm_layer: normalization layer class passed to the backbone(s)
                and decoder head(s).

        Raises:
            NotImplementedError: for an unsupported backbone, decoder,
                detection mechanism or preprocessing.
            ValueError: if ``'confpool'`` detection is requested without the
                confidence head (``CONF``) being enabled.
        """
        super().__init__()
        self.norm_layer = norm_layer
        self.cfg = cfg.MODEL.EXTRA
        self.mods = cfg.MODEL.MODS

        # import backbone and decoder
        self.backbone, self.channels = create_backbone(self.cfg.BACKBONE, norm_layer)
        if 'CONF_BACKBONE' in self.cfg:
            self.backbone_conf, self.channels_conf = create_backbone(self.cfg.CONF_BACKBONE, norm_layer)
        else:
            # The confidence head will reuse the features of self.backbone.
            self.backbone_conf = None

        if self.cfg.DECODER != 'MLPDecoder':
            raise NotImplementedError('decoder not implemented')

        logging.info('Using MLP Decoder')
        from .decoders.MLPDecoder import DecoderHead
        self.decode_head = DecoderHead(in_channels=self.channels,
                                       num_classes=cfg.DATASET.NUM_CLASSES,
                                       norm_layer=norm_layer,
                                       embed_dim=self.cfg.DECODER_EMBED_DIM)
        if self.cfg.CONF:
            # Single-channel confidence map head.
            self.decode_head_conf = DecoderHead(in_channels=self.channels,
                                                num_classes=1,
                                                norm_layer=norm_layer,
                                                embed_dim=self.cfg.DECODER_EMBED_DIM)
        else:
            self.decode_head_conf = None

        self.conf_detection = None
        detection = self.cfg.DETECTION
        if detection == 'confpool':
            # Explicit exception instead of `assert`, which `python -O` strips:
            # confpool pools statistics of the confidence map, so it cannot
            # work without the confidence head.
            if not self.cfg.CONF:
                raise ValueError("DETECTION 'confpool' requires CONF to be enabled")
            self.conf_detection = 'confpool'
            # Small MLP mapping the 8 pooled features to one detection score.
            self.detection = nn.Sequential(
                nn.Linear(in_features=8, out_features=128),
                nn.ReLU(),
                nn.Dropout(p=0.5),
                nn.Linear(in_features=128, out_features=1),
            )
        elif detection not in (None, 'none'):
            raise NotImplementedError('Detection mechanism not implemented')

        # Noiseprint++ extractor: 17 conv levels with 3x3 kernels, 64 features
        # per hidden level and a single output channel; batch-norm on every
        # level except the first and last, ReLU everywhere but the last level.
        from ..DnCNN import make_net
        num_levels = 17
        out_channel = 1
        self.dncnn = make_net(3, kernels=[3, ] * num_levels,
                              features=[64, ] * (num_levels - 1) + [out_channel],
                              bns=[False, ] + [True, ] * (num_levels - 2) + [False, ],
                              acts=['relu', ] * (num_levels - 1) + ['linear', ],
                              dilats=[1, ] * num_levels,
                              bn_momentum=0.1, padding=1)

        if self.cfg.PREPRC != 'imagenet':
            # Explicit exception instead of `assert False`: under `python -O`
            # the assert vanished and self.prepro was silently left undefined.
            raise NotImplementedError(
                'preprocessing not implemented: {!r}'.format(self.cfg.PREPRC))
        self.prepro = preprc_imagenet_torch  # RGB (mean and variance)

        self.init_weights(pretrained=cfg.MODEL.PRETRAINED)

    def init_weights(self, pretrained=None):
        """Load pretrained weights (if any) and initialize the decoder head.

        When ``pretrained`` is truthy, the backbone(s) load it through their
        own ``init_weights`` and the Noiseprint++ network is loaded from
        ``cfg.MODEL.EXTRA.NP_WEIGHTS``. ``self.decode_head`` is then always
        (re)initialized with Kaiming-normal weights.

        Args:
            pretrained: value forwarded to the backbones' ``init_weights``;
                falsy to skip loading pretrained weights.

        Raises:
            FileNotFoundError: if ``pretrained`` is set but the Noiseprint++
                weights file does not exist.
        """
        if pretrained:
            logging.info('Loading pretrained model: {}'.format(pretrained))
            self.backbone.init_weights(pretrained=pretrained)
            if self.backbone_conf is not None:
                self.backbone_conf.init_weights(pretrained=pretrained)

            np_weights = self.cfg.NP_WEIGHTS
            # Explicit exception instead of `assert` (stripped by `python -O`).
            if not os.path.isfile(np_weights):
                raise FileNotFoundError(
                    'Noiseprint++ weights not found: {}'.format(np_weights))
            # NOTE(security): torch.load unpickles the file; only load
            # weight files that come from a trusted source.
            dat = torch.load(np_weights, map_location=torch.device('cpu'))
            logging.info(f'Noiseprint++ weights: {np_weights}')
            # Checkpoints may wrap the state dict under a 'network' key.
            if 'network' in dat:
                dat = dat['network']
            self.dncnn.load_state_dict(dat)

        logging.info('Initing weights ...')
        init_weight(self.decode_head, nn.init.kaiming_normal_,
                    self.norm_layer, self.cfg.BN_EPS, self.cfg.BN_MOMENTUM,
                    mode='fan_in', nonlinearity='relu')

    def encode_decode(self, rgb, modal_x):
        """Run backbone, decoder head(s) and the optional detection head.

        Args:
            rgb: preprocessed RGB tensor, or ``None``.
            modal_x: second-modality tensor, or ``None``. Only its shape is
                used for the output size when ``rgb`` is ``None``.

        Returns:
            Tuple ``(out, conf, det)``:
            ``out`` is the localization output bilinearly resized to the input
            spatial size (``shape[2:]``); ``conf`` is the resized confidence
            map, or ``None`` without a confidence head; ``det`` is the
            detection head output, or ``None`` when detection is disabled.

        Raises:
            NotImplementedError: if ``self.conf_detection`` holds an
                unsupported mechanism.
        """
        orisize = rgb.shape if rgb is not None else modal_x.shape

        # cmx
        x = self.backbone(rgb, modal_x)
        # The intermediate features returned alongside the output are unused.
        out, _ = self.decode_head(x, return_feats=True)
        out = F.interpolate(out, size=orisize[2:], mode='bilinear', align_corners=False)

        # confidence
        conf = None
        if self.decode_head_conf is not None:
            if self.backbone_conf is not None:
                x_conf = self.backbone_conf(rgb, modal_x)
            else:
                x_conf = x  # same encoder of Localization Network
            conf = self.decode_head_conf(x_conf)
            conf = F.interpolate(conf, size=orisize[2:], mode='bilinear', align_corners=False)

        # detection
        det = None
        if self.conf_detection == 'confpool':
            from .layer_utils import weighted_statistics_pooling
            batch = out.shape[0]
            # Pooled statistics of the confidence map itself ...
            f1 = weighted_statistics_pooling(conf).view(batch, -1)
            # ... and of the difference between channels 1 and 0 of the
            # localization output, weighted by the log-sigmoid of the confidence.
            f2 = weighted_statistics_pooling(out[:, 1:2, :, :] - out[:, 0:1, :, :],
                                             F.logsigmoid(conf)).view(batch, -1)
            # The detection MLP expects f1 and f2 to total 8 features.
            det = self.detection(torch.cat((f1, f2), -1))
        elif self.conf_detection is not None:
            # Explicit exception instead of `assert False` (stripped by -O).
            raise NotImplementedError('Detection mechanism not implemented')

        return out, conf, det

    def forward(self, rgb):
        """Compute localization, confidence, detection and Noiseprint++ outputs.

        Args:
            rgb: input RGB tensor; it is passed unnormalized to ``self.dncnn``
                and normalized with ``self.prepro`` before the backbone.

        Returns:
            Tuple ``(out, conf, det, modal_x)`` where the first three come
            from ``encode_decode`` and ``modal_x`` is the Noiseprint++ map
            tiled to 3 channels, or ``None`` when ``'NP++'`` is not in
            ``self.mods``.
        """
        # Noiseprint++ extraction
        if 'NP++' in self.mods:
            modal_x = self.dncnn(rgb)
            # Repeat 3x along the third-from-last axis (the channel axis of a
            # 4-D N,C,H,W tensor) so the map has 3 channels like the RGB input.
            modal_x = torch.tile(modal_x, (3, 1, 1))
        else:
            modal_x = None

        if self.prepro is not None:
            rgb = self.prepro(rgb)

        out, conf, det = self.encode_decode(rgb, modal_x)
        return out, conf, det, modal_x