File size: 6,489 Bytes
31677e7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 | from __future__ import absolute_import
from collections import namedtuple
import numpy as np
import torch
import torch.nn
import torch.nn as nn
import torch.nn.init as init
import torchvision
from torch.autograd import Variable
# Taken from https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py
# Module-wide default device, used as ``map_location`` when loading weights.
# Preference: CUDA if available, otherwise CPU; Apple MPS overrides both when
# the backend reports itself as available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ``torch.backends.mps`` is absent from PyTorch builds that predate the MPS
# backend, so look it up defensively instead of raising AttributeError at
# import time.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
    print("Using mps")
def spatial_average(in_tens, keepdim=True):
    """Average ``in_tens`` over its dimensions 2 and 3.

    Args:
        in_tens: tensor with at least four dimensions (presumably laid out as
            (N, C, H, W) -- confirm against callers).
        keepdim: when true, the two reduced dimensions are kept with size 1.

    Returns:
        The tensor of means over dimensions 2 and 3.
    """
    spatial_dims = [2, 3]
    averaged = in_tens.mean(spatial_dims, keepdim=keepdim)
    return averaged
class vgg16(torch.nn.Module):
    """VGG-16 feature extractor split into five consecutive stages.

    The ``features`` stack of torchvision's VGG-16 is cut into five
    ``Sequential`` sub-modules (``slice1`` .. ``slice5``).  Each layer keeps
    its original index as its sub-module name, so state-dict keys look like
    ``slice1.0.weight``.

    Args:
        requires_grad: when false (the default) every parameter is frozen.
        pretrained: forwarded to ``torchvision.models.vgg16``.
    """

    # Built once for the class rather than on every forward pass, so all
    # calls return instances of the same tuple type.
    _OUTPUT_TYPE = namedtuple(
        "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
    )

    # Half-open [start, stop) index ranges into ``features``, one per stage.
    _SLICE_BOUNDS = ((0, 4), (4, 9), (9, 16), (16, 23), (23, 30))

    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        # Load the (optionally pretrained) convolutional stack from torchvision.
        vgg_pretrained_features = torchvision.models.vgg16(
            pretrained=pretrained
        ).features
        self.N_slices = len(self._SLICE_BOUNDS)
        for slice_number, (start, stop) in enumerate(self._SLICE_BOUNDS, start=1):
            stage = torch.nn.Sequential()
            for layer_idx in range(start, stop):
                stage.add_module(str(layer_idx), vgg_pretrained_features[layer_idx])
            setattr(self, "slice%d" % slice_number, stage)
        # Freeze the backbone unless the caller asked for gradients.
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Run ``X`` through the five stages in order.

        Returns:
            A ``VggOutputs`` namedtuple holding the output of each stage, in
            order: ``relu1_2``, ``relu2_2``, ``relu3_3``, ``relu4_3``,
            ``relu5_3``.
        """
        stage_outputs = []
        h = X
        for stage in (
            self.slice1,
            self.slice2,
            self.slice3,
            self.slice4,
            self.slice5,
        ):
            h = stage(h)
            stage_outputs.append(h)
        return self._OUTPUT_TYPE(*stage_outputs)
# Learned perceptual metric
class LPIPS(nn.Module):
    """Learned perceptual distance between two image batches (VGG-16 backbone).

    Both inputs are rescaled by ``ScalingLayer``, passed through the frozen
    ``vgg16`` feature extractor, unit-normalised along the channel dimension
    and compared layer by layer.  The squared differences are weighted by one
    1x1 convolution per layer (``lin0`` .. ``lin4``), averaged over
    dimensions 2 and 3, and summed over the layers.

    Args:
        net: name used to locate the weight file ``weights/v<version>/<net>.pth``
            next to this module.  The backbone itself is always ``vgg16``.
        version: weight version string, also used to build the weight path.
        use_dropout: whether each ``NetLinLayer`` places a ``Dropout`` in
            front of its 1x1 convolution.  This changes the convolution's
            state-dict key (``model.1.weight`` vs ``model.0.weight``), so it
            has to match the checkpoint being loaded.
    """

    def __init__(self, net="vgg", version="0.1", use_dropout=True):
        super(LPIPS, self).__init__()
        self.version = version
        # Input rescaling applied before the backbone.
        self.scaling_layer = ScalingLayer()

        # Channel count of each of the five backbone outputs.
        self.chns = [64, 128, 256, 512, 512]
        self.L = len(self.chns)
        self.net = vgg16(pretrained=True, requires_grad=False)

        # One 1x1 convolution per backbone output.  They are registered both
        # as individual attributes and in a ModuleList, so each parameter is
        # reachable under two state-dict keys (``linK.*`` and ``lins.K.*``).
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.lins = nn.ModuleList(
            [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        )

        # Load the trained weights shipped next to this file.
        import inspect
        import os
        import warnings

        model_path = os.path.abspath(
            os.path.join(
                inspect.getfile(self.__init__),
                "..",
                "weights/v%s/%s.pth" % (version, net),
            )
        )
        print("Loading model from: %s" % model_path)
        # strict=False tolerates keys that are absent from the checkpoint, so
        # the result is inspected below to make sure the lin layers were
        # actually populated.
        load_result = self.load_state_dict(
            torch.load(model_path, map_location=device), strict=False
        )
        unloaded = self._unloaded_lin_keys(load_result.missing_keys)
        if unloaded:
            warnings.warn(
                "LPIPS linear-layer weights were not found in %s and keep "
                "their random initialisation: %s (check that use_dropout "
                "matches the checkpoint)" % (model_path, ", ".join(unloaded)),
                RuntimeWarning,
                stacklevel=2,
            )

        # Inference only: switch to eval mode and freeze every parameter.
        self.eval()
        for param in self.parameters():
            param.requires_grad = False

    def _unloaded_lin_keys(self, missing_keys):
        """Return the ``linK.*`` keys whose parameters no checkpoint key filled.

        A parameter counts as loaded when either of its two aliases
        (``linK.<name>`` or ``lins.K.<name>``) was present in the checkpoint.
        """
        missing = set(missing_keys)
        unloaded = []
        for idx, layer in enumerate(self.lins):
            for param_name, _ in layer.named_parameters():
                direct_key = "lin%d.%s" % (idx, param_name)
                list_key = "lins.%d.%s" % (idx, param_name)
                if direct_key in missing and list_key in missing:
                    unloaded.append(direct_key)
        return unloaded

    def forward(self, in0, in1, normalize=False):
        """Compute the perceptual distance between ``in0`` and ``in1``.

        Args:
            in0: first image batch (assumed (N, 3, H, W) -- the scaling
                buffers carry three channel entries).
            in1: second image batch, same layout as ``in0``.
            normalize: set to true when the inputs are in [0, 1]; they are
                then mapped to [-1, +1] before scaling.

        Returns:
            The sum over all layers of the spatially averaged, linearly
            weighted squared feature differences (dimensions 2 and 3 are kept
            with size 1).
        """
        if normalize:
            in0 = 2 * in0 - 1
            in1 = 2 * in1 - 1

        in0_input = self.scaling_layer(in0)
        in1_input = self.scaling_layer(in1)

        # Call the module rather than .forward() so registered hooks run.
        outs0 = self.net(in0_input)
        outs1 = self.net(in1_input)

        val = 0
        for kk in range(self.L):
            # Unit-normalise along the channel dimension before comparing.
            feat0 = torch.nn.functional.normalize(outs0[kk], dim=1)
            feat1 = torch.nn.functional.normalize(outs1[kk], dim=1)
            sq_diff = (feat0 - feat1) ** 2
            # 1x1 convolution, then mean over dimensions 2 and 3.
            val = val + self.lins[kk](sq_diff).mean([2, 3], keepdim=True)
        return val
class ScalingLayer(nn.Module):
    """Per-channel affine rescaling applied to images before the backbone.

    The usual ImageNet statistics for inputs in [0, 1] are
    mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].  The
    buffers below equal ``2 * mean - 1`` and ``2 * std``, i.e. the same
    normalization expressed for inputs in [-1, +1].
    """

    def __init__(self):
        super(ScalingLayer, self).__init__()
        # Both buffers are reshaped to (1, 3, 1, 1) so they broadcast over
        # every dimension except the channel one.
        shift = torch.Tensor([-0.030, -0.088, -0.188]).view(1, -1, 1, 1)
        scale = torch.Tensor([0.458, 0.448, 0.450]).view(1, -1, 1, 1)
        self.register_buffer("shift", shift)
        self.register_buffer("scale", scale)

    def forward(self, inp):
        """Return ``(inp - shift) / scale`` with per-channel broadcasting."""
        centered = inp - self.shift
        return centered / self.scale
class NetLinLayer(nn.Module):
    """A single bias-free 1x1 convolution, optionally preceded by dropout.

    Args:
        chn_in: number of input channels.
        chn_out: number of output channels (1 by default).
        use_dropout: when true an ``nn.Dropout`` is placed before the
            convolution, which moves the convolution to index 1 of
            ``self.model`` (index 0 otherwise).
    """

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        stages = []
        if use_dropout:
            stages.append(nn.Dropout())
        stages.append(nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the (dropout +) 1x1 convolution stack to ``x``."""
        return self.model(x)
|