File size: 6,066 Bytes
b311ae5 f6182a3 b311ae5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 | from __future__ import absolute_import
import inspect
import os
from collections import namedtuple
import torch
import torch.nn as nn
import torchvision
# Taken from https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py
# Module-wide device: CUDA when torch reports it available, otherwise the CPU.
# LPIPS below uses it as ``map_location`` when loading its weight file.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
def spatial_average(in_tens, keepdim=True):
    """Average ``in_tens`` over axes 2 and 3.

    These are the spatial (H, W) axes for NCHW feature maps. With
    ``keepdim=True`` (the default) the reduced axes are kept with size 1.
    """
    return torch.mean(in_tens, dim=(2, 3), keepdim=keepdim)
class vgg16(torch.nn.Module):
    """VGG-16 feature extractor split into five sequential stages.

    The ``features`` part of ``torchvision.models.vgg16`` is cut at layer
    indices 4, 9, 16, 23 and 30. ``forward`` returns the output of every
    stage as a ``VggOutputs`` namedtuple with the fields relu1_2, relu2_2,
    relu3_3, relu4_3 and relu5_3.

    Args:
        requires_grad: when False (the default) every parameter is frozen.
        pretrained: forwarded to ``torchvision.models.vgg16``.
    """

    # Defined once at class level so every forward() call returns instances
    # of the same type, rather than creating a new namedtuple class per call.
    VggOutputs = namedtuple(
        "VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])

    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        # Load pretrained vgg model from torchvision.
        # NOTE(review): torchvision >= 0.13 deprecates ``pretrained=`` in favour
        # of ``weights=``; kept as-is for compatibility with older releases.
        vgg_pretrained_features = torchvision.models.vgg16(
            pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        # Copy features[start:end] into each stage. The original layer
        # indices are kept as submodule names so state_dict keys are stable.
        stages = (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5)
        bounds = (0, 4, 9, 16, 23, 30)
        for stage, start, end in zip(stages, bounds[:-1], bounds[1:]):
            for x in range(start, end):
                stage.add_module(str(x), vgg_pretrained_features[x])
        # Freeze vgg model
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Run ``X`` through the five stages and return every stage output."""
        h_relu1_2 = self.slice1(X)
        h_relu2_2 = self.slice2(h_relu1_2)
        h_relu3_3 = self.slice3(h_relu2_2)
        h_relu4_3 = self.slice4(h_relu3_3)
        h_relu5_3 = self.slice5(h_relu4_3)
        return self.VggOutputs(h_relu1_2, h_relu2_2,
                               h_relu3_3, h_relu4_3, h_relu5_3)
# Learned perceptual metric
class LPIPS(nn.Module):
    """Learned Perceptual Image Patch Similarity (LPIPS) on a VGG-16 backbone.

    Both images go through a frozen VGG-16. At each of its five stages the
    activations are unit-normalised along dim 1, differenced and squared.
    The result is weighted by a learned 1x1 convolution (``NetLinLayer``) and
    spatially averaged. The five per-stage scores are summed.

    Args:
        net: only selects the weight file ``weights/v<version>/<net>.pth``,
            resolved relative to the file defining this class. The backbone
            is always VGG-16.
        version: weight-file version directory, e.g. ``'0.1'``.
        use_dropout: put an ``nn.Dropout`` in front of each 1x1 conv.
            The module is put in eval mode at construction, so dropout is
            inactive unless ``.train()`` is called on it later (e.g. via a
            parent module).
    """

    def __init__(self, net='vgg', version='0.1', use_dropout=True):
        super(LPIPS, self).__init__()
        self.version = version
        # Imagenet normalization
        self.scaling_layer = ScalingLayer()
        ########################
        # Instantiate vgg model; chns are the channel counts of its 5 stages.
        self.chns = [64, 128, 256, 512, 512]
        self.L = len(self.chns)
        self.net = vgg16(pretrained=True, requires_grad=False)
        # Add 1x1 convolutional layers, one per stage. They are registered
        # both as lin0..lin4 and inside ``self.lins`` (same module objects), so
        # state_dict keys exist under both prefixes. Keep this layout: the
        # weight file is matched against these key names.
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.lins = nn.ModuleList(
            [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4])
        ########################
        # Load the weights of trained LPIPS model
        model_path = os.path.abspath(
            os.path.join(inspect.getfile(self.__init__), '..',
                         'weights/v%s/%s.pth' % (version, net)))
        print('Loading model from: %s' % model_path)
        state_dict = torch.load(model_path, map_location=device)
        # strict=False lets the file omit the backbone / scaling entries
        # (presumably it stores only the lin heads). The downside is that a
        # head-key mismatch is ignored silently: use_dropout decides whether the
        # conv sits at index 0 or 1 of NetLinLayer.model, and a mismatch would
        # leave the heads randomly initialised. Detect that case and warn.
        incompatible = self.load_state_dict(state_dict, strict=False)
        missing = set(incompatible.missing_keys)
        unloaded = [
            'lin%d.%s' % (kk, name)
            for kk in range(self.L)
            for name, _ in self.lins[kk].named_parameters()
            if 'lin%d.%s' % (kk, name) in missing
            and 'lins.%d.%s' % (kk, name) in missing
        ]
        if unloaded:
            import warnings
            warnings.warn(
                'LPIPS weight file %s provided no values for %s; these '
                'layers keep their random initialisation'
                % (model_path, ', '.join(unloaded)))
        ########################
        # Freeze all parameters
        self.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, in0, in1, normalize=False):
        """Return the LPIPS distance between ``in0`` and ``in1``.

        Args:
            in0, in1: image batches in [-1, 1], or in [0, 1] when
                ``normalize`` is True. Assumed NCHW with 3 channels
                (ScalingLayer uses 3 per-channel constants); TODO confirm.
            normalize: map [0, 1] inputs to [-1, 1] before scoring.

        Returns:
            The sum of the five per-stage scores. Each score is a 1-channel
            map averaged with keepdim=True, so for NCHW input the result is
            (N, 1, 1, 1).
        """
        # Scale the inputs to -1 to +1 range if needed
        if normalize:
            in0 = 2 * in0 - 1
            in1 = 2 * in1 - 1
        # ImageNet normalization, then VGG stage outputs. Call the module
        # rather than .forward() so any registered hooks are honoured.
        outs0 = self.net(self.scaling_layer(in0))
        outs1 = self.net(self.scaling_layer(in1))
        res = []
        for kk in range(self.L):
            # Unit-normalise along the channel axis before differencing.
            feat0 = torch.nn.functional.normalize(outs0[kk], dim=1)
            feat1 = torch.nn.functional.normalize(outs1[kk], dim=1)
            diff = (feat0 - feat1) ** 2
            # 1x1 convolution followed by spatial average.
            res.append(spatial_average(self.lins[kk](diff), keepdim=True))
        # Aggregate the per-stage results.
        return sum(res)
class ScalingLayer(nn.Module):
    """Normalise images in [-1, 1] with ImageNet channel statistics.

    The ImageNet mean (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225)
    are defined for inputs in [0, 1]. For the [-1, 1] range they become
    shift = 2 * mean - 1 and scale = 2 * std, which are the constants
    registered below as buffers.
    """

    def __init__(self):
        super(ScalingLayer, self).__init__()
        per_channel_shift = torch.tensor([-.030, -.088, -.188])
        per_channel_scale = torch.tensor([.458, .448, .450])
        # Shape (1, C, 1, 1) so the constants broadcast over NCHW batches.
        self.register_buffer('shift', per_channel_shift.view(1, -1, 1, 1))
        self.register_buffer('scale', per_channel_scale.view(1, -1, 1, 1))

    def forward(self, inp):
        centred = inp - self.shift
        return centred / self.scale
class NetLinLayer(nn.Module):
    """A single linear layer, implemented as a bias-free 1x1 convolution.

    Args:
        chn_in: number of input channels.
        chn_out: number of output channels (default 1).
        use_dropout: if True, an ``nn.Dropout()`` precedes the conv inside
            ``self.model``, which puts the conv at index 1 instead of 0.
    """

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        conv = nn.Conv2d(chn_in, chn_out, kernel_size=1, stride=1,
                         padding=0, bias=False)
        stages = (nn.Dropout(), conv) if use_dropout else (conv,)
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        return self.model(x)
|