| |
|
| | from __future__ import absolute_import |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.init as init |
| | from torch.autograd import Variable |
| | import numpy as np |
| | from . import pretrained_networks as pn |
| | import torch.nn |
| |
|
| | import lpips |
| |
|
def spatial_average(in_tens, keepdim=True):
    """Average a tensor over its dimensions 2 and 3.

    Args:
        in_tens: tensor with at least 4 dimensions; dims 2 and 3 (presumably
            height and width — confirm against callers) are reduced.
        keepdim: if True, the reduced dimensions are kept with size 1.

    Returns:
        The mean of ``in_tens`` taken over dimensions 2 and 3.
    """
    reduce_dims = [2, 3]
    averaged = in_tens.mean(reduce_dims, keepdim=keepdim)
    return averaged
| |
|
def upsample(in_tens, out_HW=(64,64)):
    """Bilinearly resize the last two dimensions of a 4-D tensor.

    Args:
        in_tens: 4-D tensor whose dims 2 and 3 are resized.
        out_HW: target size for dims 2 and 3; any 2-element sequence
            (e.g. a ``torch.Size`` slice of another tensor's shape).

    Returns:
        Tensor with the same first two dimensions as ``in_tens`` and
        spatial size ``out_HW``.
    """
    # Functional call with the same arguments nn.Upsample would forward;
    # avoids constructing a throw-away module on every call.
    return nn.functional.interpolate(in_tens, size=out_HW, mode='bilinear', align_corners=False)
| |
|
| | |
class LPIPS(nn.Module):
    """Learned Perceptual Image Patch Similarity (LPIPS) distance.

    Both inputs are passed through a feature trunk (``self.net``). Each of the
    ``self.L`` returned feature maps is normalized with
    ``lpips.normalize_tensor`` and squared-differenced. The per-layer
    difference is reduced to one channel, either by a learned 1x1 conv
    (``lpips=True``) or by a plain channel sum (``lpips=False``, "baseline").
    It is then spatially averaged (or upsampled to the input size when
    ``spatial=True``), and the layers are summed.
    """

    def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False,
                 pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=True):
        """Build the trunk and, in LPIPS mode, the learned linear layers.

        Args:
            pretrained: in LPIPS mode, load weights from ``model_path``
                (default: ``weights/v<version>/<net>.pth`` next to this
                source file). Not used in baseline mode.
            net: trunk name: 'vgg' / 'vgg16', 'alex' or 'squeeze'.
            version: '0.1' rescales inputs with ``ScalingLayer`` before the
                trunk; any other value feeds them unchanged.
            lpips: True for learned per-channel weights, False for the
                unweighted baseline.
            spatial: True to return a per-pixel map upsampled to the
                input's spatial size instead of a spatial average.
            pnet_rand: the trunk is built with ``pretrained=not pnet_rand``.
            pnet_tune: passed to the trunk as ``requires_grad``.
            use_dropout: prepend ``nn.Dropout`` to each linear layer.
            model_path: explicit weights path overriding the default.
            eval_mode: call ``self.eval()`` at the end of construction.
            verbose: print setup / loading messages.

        Raises:
            ValueError: if ``net`` is not a supported trunk name.
        """
        super(LPIPS, self).__init__()
        if verbose:
            print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]' %
                  ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off'))

        self.pnet_type = net
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.lpips = lpips
        self.version = version
        self.scaling_layer = ScalingLayer()

        if self.pnet_type in ['vgg', 'vgg16']:
            net_type = pn.vgg16
            self.chns = [64, 128, 256, 512, 512]
        elif self.pnet_type == 'alex':
            net_type = pn.alexnet
            self.chns = [64, 192, 384, 256, 256]
        elif self.pnet_type == 'squeeze':
            net_type = pn.squeezenet
            self.chns = [64, 128, 256, 384, 384, 512, 512]
        else:
            # Fail fast with a clear message instead of an AttributeError on self.chns.
            raise ValueError("Unsupported trunk [%s]; expected 'vgg', 'vgg16', 'alex' or 'squeeze'" % net)
        self.L = len(self.chns)

        self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)

        if lpips:
            # Each layer is registered both as an attribute (lin0..) and inside
            # self.lins, so the state_dict carries 'linK.*' and 'lins.K.*' keys.
            # Presumably the weight files rely on one of these — keep both.
            self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
            self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
            self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
            self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
            self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
            lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
            if self.pnet_type == 'squeeze':
                self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
                self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
                lins += [self.lin5, self.lin6]
            self.lins = nn.ModuleList(lins)

            if pretrained:
                if model_path is None:
                    import inspect
                    import os
                    model_path = os.path.abspath(os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth' % (version, net)))

                if verbose:
                    print('Loading model from: %s' % model_path)
                # NOTE(review): torch.load unpickles the file; only load trusted weights.
                self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)

        if eval_mode:
            self.eval()

    def forward(self, in0, in1, retPerLayer=False, normalize=False):
        """Compute the perceptual distance between ``in0`` and ``in1``.

        Args:
            in0, in1: image batches of equal shape; presumably (N, 3, H, W)
                scaled to [-1, 1] (or [0, 1] with ``normalize=True``) —
                TODO confirm against callers.
            retPerLayer: if True, also return the list of per-layer terms.
            normalize: map inputs with ``2 * x - 1`` before processing.

        Returns:
            The summed distance, assuming 4-D inputs:
            - (N, 1, 1, 1) by default;
            - (N, 1, H, W) when ``spatial`` is set;
            - ``(total, per_layer_list)`` when ``retPerLayer`` is True.
        """
        if normalize:
            in0 = 2 * in0 - 1
            in1 = 2 * in1 - 1

        # Only version '0.1' rescales the inputs before the trunk.
        if self.version == '0.1':
            in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)
        else:
            in0_input, in1_input = in0, in1
        outs0, outs1 = self.net(in0_input), self.net(in1_input)

        res = []
        for kk in range(self.L):
            diff = (lpips.normalize_tensor(outs0[kk]) - lpips.normalize_tensor(outs1[kk])) ** 2
            # Collapse channels: learned 1x1 conv (LPIPS) or plain sum (baseline).
            layer_map = self.lins[kk](diff) if self.lpips else diff.sum(dim=1, keepdim=True)
            if self.spatial:
                res.append(upsample(layer_map, out_HW=in0.shape[2:]))
            else:
                res.append(spatial_average(layer_map, keepdim=True))

        # Out-of-place accumulation: an in-place `+=` onto res[0] would
        # overwrite the layer-0 entry that is returned when retPerLayer=True.
        val = sum(res)

        if retPerLayer:
            return (val, res)
        return val
| |
|
| |
|
class ScalingLayer(nn.Module):
    """Fixed per-channel affine rescaling: ``(inp - shift) / scale``.

    ``shift`` and ``scale`` are 3-element constants registered as buffers of
    shape (1, 3, 1, 1), so they broadcast over a 3-channel 4-D input.
    """

    def __init__(self):
        super(ScalingLayer, self).__init__()
        shift_values = [-.030, -.088, -.188]
        scale_values = [.458, .448, .450]
        # Buffer names 'shift' / 'scale' are state_dict keys; keep them stable.
        self.register_buffer('shift', torch.Tensor(shift_values).view(1, 3, 1, 1))
        self.register_buffer('scale', torch.Tensor(scale_values).view(1, 3, 1, 1))

    def forward(self, inp):
        centered = inp - self.shift
        return centered / self.scale
| |
|
| |
|
class NetLinLayer(nn.Module):
    """Bias-free 1x1 convolution, optionally preceded by dropout.

    The layers live in ``self.model`` (an ``nn.Sequential``). With dropout
    the conv is at index 1, otherwise at index 0; these indices appear in
    the state_dict keys.
    """

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        conv = nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False)
        if use_dropout:
            self.model = nn.Sequential(nn.Dropout(), conv)
        else:
            self.model = nn.Sequential(conv)

    def forward(self, x):
        return self.model(x)
| |
|
class Dist2LogitLayer(nn.Module):
    """Map two distance tensors to a single-channel score via 1x1 convs.

    The input stack is five features: d0, d1, d0 - d1, d0 / (d1 + eps) and
    d1 / (d0 + eps). These pass through conv -> LeakyReLU -> conv ->
    LeakyReLU -> conv, and then a Sigmoid when ``use_sigmoid`` is True.
    """

    def __init__(self, chn_mid=32, use_sigmoid=True):
        super(Dist2LogitLayer, self).__init__()

        def conv1x1(c_in, c_out):
            return nn.Conv2d(c_in, c_out, 1, stride=1, padding=0, bias=True)

        stack = [
            conv1x1(5, chn_mid),
            nn.LeakyReLU(0.2, True),
            conv1x1(chn_mid, chn_mid),
            nn.LeakyReLU(0.2, True),
            conv1x1(chn_mid, 1),
        ]
        if use_sigmoid:
            stack.append(nn.Sigmoid())
        self.model = nn.Sequential(*stack)

    def forward(self, d0, d1, eps=0.1):
        # eps keeps the ratio features finite when a distance is zero.
        ratio01 = d0 / (d1 + eps)
        ratio10 = d1 / (d0 + eps)
        features = torch.cat((d0, d1, d0 - d1, ratio01, ratio10), dim=1)
        return self.model.forward(features)
| |
|
class BCERankingLoss(nn.Module):
    """Binary cross-entropy between a learned preference score and a judgment.

    ``Dist2LogitLayer`` turns two distance tensors into a score, which is
    compared with ``(judge + 1) / 2`` using ``torch.nn.BCELoss``.
    """

    def __init__(self, chn_mid=32):
        super(BCERankingLoss, self).__init__()
        self.net = Dist2LogitLayer(chn_mid=chn_mid)
        self.loss = torch.nn.BCELoss()

    def forward(self, d0, d1, judge):
        # Presumably judge lies in [-1, 1]; this affine map sends it to [0, 1].
        target = (judge + 1.) / 2.
        # Kept on self for inspection by callers. Despite the name, it holds
        # the Dist2LogitLayer output, which ends in a Sigmoid by default.
        self.logit = self.net.forward(d0, d1)
        return self.loss(self.logit, target)
| |
|
| | |
class FakeNet(nn.Module):
    """Parameter-free base module that records a device flag and a colour space.

    Args:
        use_gpu: stored as ``self.use_gpu`` for subclasses to consult.
        colorspace: stored as ``self.colorspace`` (e.g. 'Lab' or 'RGB').
    """

    def __init__(self, use_gpu=True, colorspace='Lab'):
        super(FakeNet, self).__init__()
        self.colorspace = colorspace
        self.use_gpu = use_gpu
| |
|
class L2(FakeNet):
    """Mean-squared-error baseline distance between two single-image batches.

    In 'RGB' the error is computed directly on the input tensors. In 'Lab'
    both inputs go through ``lpips.tensor2tensorlab`` / ``lpips.tensor2np``
    and are scored by ``lpips.l2`` with ``range=100.``.
    """

    def forward(self, in0, in1, retPerLayer=None):
        """Return the L2 distance between ``in0`` and ``in1``.

        Args:
            in0, in1: 4-D image tensors with batch size 1.
            retPerLayer: unused; accepted so the call signature matches LPIPS.

        Returns:
            - 'RGB': tensor of shape (N,) holding the mean squared error per sample.
            - 'Lab': 1-element tensor, moved to CUDA when ``self.use_gpu`` is set.

        Raises:
            ValueError: if ``self.colorspace`` is neither 'RGB' nor 'Lab'
                (previously this silently returned None).
        """
        assert in0.size()[0] == 1, 'L2 expects a batch size of 1'

        if self.colorspace == 'RGB':
            # Mean over channel and both spatial dims -> one value per sample.
            return ((in0 - in1) ** 2).mean(dim=(1, 2, 3))
        if self.colorspace == 'Lab':
            lab0 = lpips.tensor2np(lpips.tensor2tensorlab(in0.data, to_norm=False))
            lab1 = lpips.tensor2np(lpips.tensor2tensorlab(in1.data, to_norm=False))
            value = lpips.l2(lab0, lab1, range=100.).astype('float')
            # Variable has been a no-op wrapper since PyTorch 0.4; a plain tensor is equivalent.
            ret_var = torch.Tensor((value,))
            if self.use_gpu:
                ret_var = ret_var.cuda()
            return ret_var
        raise ValueError("Unsupported colorspace [%s]; expected 'RGB' or 'Lab'" % self.colorspace)
| |
|
class DSSIM(FakeNet):
    """Structural-dissimilarity baseline distance computed via ``lpips.dssim``.

    In 'RGB' the inputs are converted with ``lpips.tensor2im`` and scored with
    ``range=255.``. In 'Lab' they go through ``lpips.tensor2tensorlab`` /
    ``lpips.tensor2np`` and are scored with ``range=100.``.
    """

    def forward(self, in0, in1, retPerLayer=None):
        """Return the DSSIM between ``in0`` and ``in1`` as a 1-element tensor.

        Args:
            in0, in1: image tensors with batch size 1.
            retPerLayer: unused; accepted so the call signature matches LPIPS.

        Returns:
            1-element tensor, moved to CUDA when ``self.use_gpu`` is set.

        Raises:
            ValueError: if ``self.colorspace`` is neither 'RGB' nor 'Lab'
                (previously an UnboundLocalError on ``value``).
        """
        assert in0.size()[0] == 1, 'DSSIM expects a batch size of 1'

        if self.colorspace == 'RGB':
            value = lpips.dssim(1. * lpips.tensor2im(in0.data), 1. * lpips.tensor2im(in1.data), range=255.).astype('float')
        elif self.colorspace == 'Lab':
            lab0 = lpips.tensor2np(lpips.tensor2tensorlab(in0.data, to_norm=False))
            lab1 = lpips.tensor2np(lpips.tensor2tensorlab(in1.data, to_norm=False))
            value = lpips.dssim(lab0, lab1, range=100.).astype('float')
        else:
            raise ValueError("Unsupported colorspace [%s]; expected 'RGB' or 'Lab'" % self.colorspace)

        # Variable has been a no-op wrapper since PyTorch 0.4; a plain tensor is equivalent.
        ret_var = torch.Tensor((value,))
        if self.use_gpu:
            ret_var = ret_var.cuda()
        return ret_var
| |
|
def print_network(net):
    """Print a network's repr followed by its total parameter-element count.

    Args:
        net: object whose ``parameters()`` yields tensors (e.g. an ``nn.Module``).
    """
    total = sum(param.numel() for param in net.parameters())
    print('Network', net)
    print('Total number of parameters: %d' % total)
| |
|