DiffQRCode / diffqrcoder /losses /perceptual_loss.py
sayshara's picture
added diffqrcoder_wrapper
70be616
from collections import namedtuple
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torchvision.models import vgg16
from torchvision.transforms import Normalize
class VGGFeatureExtractor(nn.Module):
    """Multi-stage feature extractor built from torchvision's VGG16.

    The convolutional part of VGG16 (``vgg16(...).features``) is cut into
    four consecutive stages covering layer indices ``[0, 4)``, ``[4, 9)``,
    ``[9, 16)`` and ``[16, 23)``. ``forward`` returns the L2-normalised
    output of every stage as an ``Outputs`` namedtuple with fields
    ``layer0`` .. ``layer3``.

    Args:
        requires_grad: When ``False`` (default) every parameter of the
            module is frozen.
        pretrained_weights: Passed straight to ``vgg16(weights=...)``.
    """

    def __init__(
        self,
        requires_grad: bool = False,
        pretrained_weights: str = "DEFAULT",
    ) -> None:
        super().__init__()

        # Per-channel statistics applied to the input before it enters VGG.
        # NOTE(review): these are the usual ImageNet values, which presumes
        # RGB input scaled to [0, 1] -- confirm against callers.
        self.norm = Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

        # Half-open [start, end) index ranges into ``vgg16().features``.
        self.slice_indices = [(0, 4), (4, 9), (9, 16), (16, 23)]
        stage_count = len(self.slice_indices)

        # Empty containers first; ``_initialize_slices`` fills them in.
        self.slices = nn.ModuleList([nn.Sequential() for _ in range(stage_count)])
        self._initialize_slices(pretrained_weights)

        # Result type of ``forward``: one field per stage.
        self.features = namedtuple(
            "Outputs",
            [f"layer{idx}" for idx in range(stage_count)],
        )

        if not requires_grad:
            # Freeze every parameter registered on this module.
            self.requires_grad_(False)

    def lp_norm(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` to unit L2 norm along dimension 1."""
        return torch.nn.functional.normalize(x, p=2.0, dim=1)

    def _initialize_slices(self, pretrained_weights: str = "DEFAULT") -> None:
        """Copy the VGG16 layers of each index range into its stage.

        Every layer is registered under its original VGG index (as a
        string), so sub-module names mirror those of ``vgg16().features``.
        """
        vgg_layers = vgg16(weights=pretrained_weights).features
        for stage, (first, last) in zip(self.slices, self.slice_indices):
            for layer_idx in range(first, last):
                stage.add_module(str(layer_idx), vgg_layers[layer_idx])

    def forward(self, x: torch.Tensor) -> namedtuple:
        """Run ``x`` through all stages and collect each stage's output.

        Args:
            x: Input tensor; it is passed through ``self.norm`` first.

        Returns:
            An ``Outputs`` namedtuple holding the L2-normalised output of
            each stage, in stage order.
        """
        x = self.norm(x)
        collected = []
        for stage in self.slices:
            # NOTE(review): the unit-normalised tensor -- not the raw
            # activation -- is what the next stage consumes. Confirm this
            # is intended.
            x = self.lp_norm(stage(x))
            collected.append(x)
        return self.features(*collected)
class PerceptualLoss(nn.Module):
    """Perceptual distance between two images in VGG feature space.

    Both inputs are passed through the same ``VGGFeatureExtractor``; the
    loss is the mean, over all returned feature maps, of the per-layer
    mean-squared error between corresponding features.

    Args:
        requires_grad: Forwarded to ``VGGFeatureExtractor``; when ``False``
            (default) the extractor's parameters are frozen.
        pretrained_weights: Forwarded to ``VGGFeatureExtractor``.
    """

    def __init__(
        self,
        requires_grad: bool = False,
        pretrained_weights: str = "DEFAULT",
    ):
        super().__init__()
        self.extractor = VGGFeatureExtractor(
            pretrained_weights=pretrained_weights,
            requires_grad=requires_grad,
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the scalar perceptual loss between ``x`` and ``y``.

        Args:
            x: First input, passed to the extractor.
            y: Second input, passed to the extractor.

        Returns:
            A 0-dim tensor: the average of the per-layer MSE values. It
            stays attached to the autograd graph and on the same device as
            the extracted features.
        """
        layer_losses = [
            torch.nn.functional.mse_loss(fx, fy)
            for fx, fy in zip(self.extractor(x), self.extractor(y))
        ]
        # Use torch.stack rather than torch.tensor: torch.tensor copies the
        # values into a fresh leaf tensor, which cuts the autograd graph (no
        # gradient reaches x or y) and puts the result on the CPU.
        return torch.stack(layer_losses).mean()