# Yash Nagraj
# Move the pretrained vgg16 model to another directory
# f6182a3
from __future__ import absolute_import
import inspect
import os
from collections import namedtuple
import torch
import torch.nn as nn
import torchvision
# Taken from https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def spatial_average(in_tens, keepdim=True):
    """Average ``in_tens`` over dimensions 2 and 3.

    For an (N, C, H, W) tensor this is the mean over the spatial axes; with
    ``keepdim=True`` (the default) the result has shape (N, C, 1, 1),
    otherwise (N, C).
    """
    return torch.mean(in_tens, dim=(2, 3), keepdim=keepdim)
class vgg16(torch.nn.Module):
    """VGG16 feature extractor that exposes five intermediate ReLU activations.

    torchvision's VGG16 ``features`` stack is split into five sequential
    slices covering layer indices [0, 4), [4, 9), [9, 16), [16, 23) and
    [23, 30). Their outputs are returned as relu1_2, relu2_2, relu3_3,
    relu4_3 and relu5_3.

    Args:
        requires_grad: If False (default), every parameter is frozen.
        pretrained: If True (default), load torchvision's ImageNet weights
            (``IMAGENET1K_V1``); otherwise the network is randomly initialised.
    """

    # Built once here instead of on every forward() call.
    VggOutputs = namedtuple(
        "VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])

    # Slice i (0-based) covers features[_SLICE_BOUNDS[i]:_SLICE_BOUNDS[i + 1]].
    _SLICE_BOUNDS = (0, 4, 9, 16, 23, 30)

    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = self._load_features(pretrained)
        self.N_slices = 5
        bounds = self._SLICE_BOUNDS
        for idx in range(self.N_slices):
            block = torch.nn.Sequential()
            for x in range(bounds[idx], bounds[idx + 1]):
                # Keep the original layer index as the child name so the
                # state_dict keys stay e.g. "slice2.5.weight".
                block.add_module(str(x), vgg_pretrained_features[x])
            # Registers self.slice1 ... self.slice5, in that order.
            setattr(self, 'slice%d' % (idx + 1), block)
        # Freeze vgg model
        if not requires_grad:
            self.requires_grad_(False)

    @staticmethod
    def _load_features(pretrained):
        """Return torchvision's VGG16 ``features`` stack, supporting both weight APIs."""
        models = torchvision.models
        if hasattr(models, 'VGG16_Weights'):
            # torchvision >= 0.13 deprecates ``pretrained=`` in favour of
            # ``weights=``; IMAGENET1K_V1 is what ``pretrained=True`` meant.
            weights = models.VGG16_Weights.IMAGENET1K_V1 if pretrained else None
            return models.vgg16(weights=weights).features
        # Older torchvision only understands the boolean flag.
        return models.vgg16(pretrained=pretrained).features

    def forward(self, X):
        """Run ``X`` through the five slices and return every slice output.

        Args:
            X: Input batch; presumably (N, 3, H, W) images already normalised
                for VGG (TODO: confirm against callers).

        Returns:
            ``VggOutputs`` namedtuple of (relu1_2, relu2_2, relu3_3, relu4_3,
            relu5_3), each the output of the corresponding slice.
        """
        h_relu1_2 = self.slice1(X)
        h_relu2_2 = self.slice2(h_relu1_2)
        h_relu3_3 = self.slice3(h_relu2_2)
        h_relu4_3 = self.slice4(h_relu3_3)
        h_relu5_3 = self.slice5(h_relu4_3)
        return self.VggOutputs(h_relu1_2, h_relu2_2,
                               h_relu3_3, h_relu4_3, h_relu5_3)
# Learned perceptual metric
class LPIPS(nn.Module):
    """Learned perceptual metric (LPIPS) built on a frozen VGG16 backbone.

    Both inputs are normalised by ``ScalingLayer`` and passed through VGG16.
    At each of the five backbone outputs, features are unit-normalised along
    dim 1 and their squared difference is weighted by a learned 1x1 conv.
    The weighted differences are spatially averaged and summed over layers.

    Args:
        net: Name of the weight file to load (``weights/v<version>/<net>.pth``
            next to this source file). Only a VGG16 backbone is constructed,
            so this should name VGG-compatible weights.
        version: Weight-file version directory.
        use_dropout: Whether each learned 1x1 layer is preceded by dropout.
            This changes the conv's index inside ``NetLinLayer.model``, and
            therefore the state_dict key names the weight file must provide.

    Raises:
        RuntimeError: If the weight file provides no weights for one of the
            learned 1x1 layers. Previously ``strict=False`` let such a layer
            silently stay randomly initialised.
    """

    def __init__(self, net='vgg', version='0.1', use_dropout=True):
        super(LPIPS, self).__init__()
        self.version = version
        # Imagenet normalization
        self.scaling_layer = ScalingLayer()
        # Frozen VGG16 backbone and the channel count of each of its outputs.
        self.chns = [64, 128, 256, 512, 512]
        self.L = len(self.chns)
        self.net = vgg16(pretrained=True, requires_grad=False)
        # One learned 1x1 conv per backbone output. Each is registered as
        # lin0..lin4 AND inside the ModuleList; the parameters are shared.
        lins = [NetLinLayer(chn, use_dropout=use_dropout) for chn in self.chns]
        for kk, lin in enumerate(lins):
            setattr(self, 'lin%d' % kk, lin)
        self.lins = nn.ModuleList(lins)
        # Load the weights of trained LPIPS model
        model_path = os.path.abspath(
            os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth' % (version, net)))
        print('Loading model from: %s' % model_path)
        # load_state_dict copies into the existing (CPU) parameters, so
        # loading onto the GPU first would only add a round trip.
        state_dict = torch.load(model_path, map_location='cpu')
        # strict=False so keys for the torchvision backbone and scaling buffers
        # may be absent from the file; the learned layers are verified below.
        load_result = self.load_state_dict(state_dict, strict=False)
        self._check_lins_loaded(load_result.missing_keys)
        # Freeze all parameters
        self.eval()
        self.requires_grad_(False)

    def _check_lins_loaded(self, missing_keys):
        """Raise RuntimeError if any learned 1x1 parameter was absent from the checkpoint."""
        missing = set(missing_keys)
        unloaded = [
            'lin%d.%s' % (kk, name)
            for kk, lin in enumerate(self.lins)
            for name, _ in lin.named_parameters()
            # A parameter is reachable as both linK.* and lins.K.*; it was
            # loaded if the file supplied either alias.
            if 'lin%d.%s' % (kk, name) in missing
            and 'lins.%d.%s' % (kk, name) in missing
        ]
        if unloaded:
            raise RuntimeError(
                'LPIPS weight file did not provide: %s' % ', '.join(unloaded))

    def forward(self, in0, in1, normalize=False):
        """Return the LPIPS distance between two input batches.

        Args:
            in0, in1: Inputs expected in [-1, 1], or in [0, 1] when
                ``normalize=True``. Presumably (N, 3, H, W) images
                (TODO: confirm against callers).
            normalize: If True, map inputs from [0, 1] to [-1, 1] first.

        Returns:
            Sum over the five layers of the spatially averaged (keepdim)
            learned distances. For 4-D inputs this is (N, 1, 1, 1).
        """
        if normalize:
            in0 = 2 * in0 - 1
            in1 = 2 * in1 - 1
        # Normalize the inputs according to imagenet normalization
        in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)
        # Call the module (not .forward) so registered hooks still run.
        outs0, outs1 = self.net(in0_input), self.net(in1_input)
        res = []
        for kk in range(self.L):
            # Unit-normalise each feature vector along the channel dimension.
            feat0 = torch.nn.functional.normalize(outs0[kk], dim=1)
            feat1 = torch.nn.functional.normalize(outs1[kk], dim=1)
            diff = (feat0 - feat1) ** 2
            # 1x1 convolution followed by spatial average.
            res.append(spatial_average(self.lins[kk](diff), keepdim=True))
        # Aggregate the per-layer results.
        return sum(res)
class ScalingLayer(nn.Module):
    """Apply ImageNet mean/std normalisation to inputs given in the [-1, 1] range.

    The usual ImageNet statistics for [0, 1] inputs are
    mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
    Rewritten for [-1, 1] inputs they become shift = 2 * mean - 1 and
    scale = 2 * std, which are the constants stored below.
    """

    def __init__(self):
        super(ScalingLayer, self).__init__()
        # Shaped (1, 3, 1, 1) so they broadcast per channel. Registered as
        # buffers so they move with the module on .to()/.cuda().
        shift = torch.Tensor([-.030, -.088, -.188]).view(1, 3, 1, 1)
        scale = torch.Tensor([.458, .448, .450]).view(1, 3, 1, 1)
        self.register_buffer('shift', shift)
        self.register_buffer('scale', scale)

    def forward(self, inp):
        """Return ``(inp - shift) / scale`` with per-channel broadcasting."""
        centred = inp - self.shift
        return centred / self.scale
class NetLinLayer(nn.Module):
    """Learned linear weighting: optional dropout followed by a bias-free 1x1 conv.

    Args:
        chn_in: Number of input channels.
        chn_out: Number of output channels (default 1).
        use_dropout: If True, ``nn.Dropout()`` is placed before the conv, so the
            conv lives at index 1 of ``self.model`` rather than index 0.
    """

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        conv = nn.Conv2d(chn_in, chn_out, 1, stride=1,
                         padding=0, bias=False)
        if use_dropout:
            self.model = nn.Sequential(nn.Dropout(), conv)
        else:
            self.model = nn.Sequential(conv)

    def forward(self, x):
        """Apply the (dropout +) 1x1 convolution to ``x``."""
        return self.model(x)