from __future__ import absolute_import

import inspect
import os
from collections import namedtuple

import torch
import torch.nn as nn
import torchvision
# Taken from https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.backends.mps.is_available():
device = torch.device("mps")
print("Using mps")
def spatial_average(in_tens, keepdim=True):
    # Average each feature map over its spatial dimensions (H, W)
    return in_tens.mean([2, 3], keepdim=keepdim)
class vgg16(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(vgg16, self).__init__()
        # Load the pretrained VGG-16 feature extractor from torchvision
        # (`weights=` replaces the deprecated `pretrained=` flag in torchvision >= 0.13)
        weights = torchvision.models.VGG16_Weights.IMAGENET1K_V1 if pretrained else None
        vgg_pretrained_features = torchvision.models.vgg16(weights=weights).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(4):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(4, 9):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(9, 16):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(16, 23):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(23, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
# Freeze vgg model
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
        # Run the input through the five slices, capturing each stage's activation
h = self.slice1(X)
h_relu1_2 = h
h = self.slice2(h)
h_relu2_2 = h
h = self.slice3(h)
h_relu3_3 = h
h = self.slice4(h)
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
vgg_outputs = namedtuple(
"VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
)
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
return out
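
# For reference (an addition, not in the upstream file): for a (N, 3, 224, 224)
# input, the five returned activations have shapes
#   relu1_2: (N,  64, 224, 224)
#   relu2_2: (N, 128, 112, 112)
#   relu3_3: (N, 256,  56,  56)
#   relu4_3: (N, 512,  28,  28)
#   relu5_3: (N, 512,  14,  14)
# which is why LPIPS.chns below is [64, 128, 256, 512, 512].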
# Learned perceptual metric
class LPIPS(nn.Module):
def __init__(self, net="vgg", version="0.1", use_dropout=True):
super(LPIPS, self).__init__()
self.version = version
        # ImageNet normalization (see ScalingLayer below)
self.scaling_layer = ScalingLayer()
########################
# Instantiate vgg model
self.chns = [64, 128, 256, 512, 512]
self.L = len(self.chns)
self.net = vgg16(pretrained=True, requires_grad=False)
        # Add 1x1 convolutional layers (one learned linear weighting per VGG stage)
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
self.lins = nn.ModuleList(self.lins)
########################
        # Load the weights of the trained LPIPS model
model_path = os.path.abspath(
os.path.join(
inspect.getfile(self.__init__),
"..",
"weights/v%s/%s.pth" % (version, net),
)
)
print("Loading model from: %s" % model_path)
self.load_state_dict(torch.load(model_path, map_location=device), strict=False)
########################
# Freeze all parameters
self.eval()
for param in self.parameters():
param.requires_grad = False
########################
    def forward(self, in0, in1, normalize=False):
        # Scale inputs from [0, 1] to [-1, +1] if requested
        if normalize:  # set this flag when the inputs are in [0, 1]
########################
        # Apply ImageNet normalization to both inputs
in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)
########################
        # Get VGG features for image0 and image1
        outs0, outs1 = self.net(in0_input), self.net(in1_input)
feats0, feats1, diffs = {}, {}, {}
########################
        # Unit-normalize each layer's features along the channel dimension,
        # then take the squared difference
        for kk in range(self.L):
            feats0[kk] = torch.nn.functional.normalize(outs0[kk], dim=1)
            feats1[kk] = torch.nn.functional.normalize(outs1[kk], dim=1)
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
########################
        # 1x1 convolution followed by a spatial average of the squared differences
        res = [
            spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
            for kk in range(self.L)
        ]
        # Sum the per-layer distances into the final LPIPS value
        val = sum(res)
        return val
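
# For reference (matches the computation above): the LPIPS distance is
#   d(x0, x1) = sum_l mean_{h,w} || w_l * (f0_l[:, h, w] - f1_l[:, h, w]) ||_2^2
# where f0, f1 are channel-unit-normalized VGG activations and w_l are the
# learned 1x1-conv weights in lin0..lin4.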
class ScalingLayer(nn.Module):
def __init__(self):
super(ScalingLayer, self).__init__()
        # ImageNet normalization, remapped for inputs in [-1, +1]:
        # with mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225],
        # shift = 2 * mean - 1 and scale = 2 * std, so that
        # (x - shift) / scale == (x01 - mean) / std for x = 2 * x01 - 1
self.register_buffer(
"shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
)
self.register_buffer(
"scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
)
def forward(self, inp):
return (inp - self.shift) / self.scale
class NetLinLayer(nn.Module):
"""A single linear layer which does a 1x1 conv"""
def __init__(self, chn_in, chn_out=1, use_dropout=False):
super(NetLinLayer, self).__init__()
        # Optional dropout followed by a 1x1 conv with learned per-channel weights
        layers = [nn.Dropout()] if use_dropout else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False)]
self.model = nn.Sequential(*layers)
def forward(self, x):
out = self.model(x)
return out
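
# A minimal usage sketch (an addition, not part of the upstream file): it
# assumes the pretrained LPIPS weights exist at weights/v0.1/vgg.pth next to
# this file, as LPIPS.__init__ above expects.
if __name__ == "__main__":
    lpips_fn = LPIPS(net="vgg").to(device)
    # Two random "images" in [0, 1]; normalize=True remaps them to [-1, +1]
    img0 = torch.rand(1, 3, 224, 224, device=device)
    img1 = torch.rand(1, 3, 224, 224, device=device)
    with torch.no_grad():
        dist = lpips_fn(img0, img1, normalize=True)
    print("LPIPS distance: %.4f" % dist.item())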