Spaces:
Sleeping
Sleeping
File size: 2,320 Bytes
70be616 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
from collections import namedtuple
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torchvision.models import vgg16
from torchvision.transforms import Normalize
class VGGFeatureExtractor(nn.Module):
    """Expose intermediate VGG16 activations as a named tuple.

    The ``features`` stack of torchvision's ``vgg16`` is cut into four
    consecutive slices covering layer indices [0, 4), [4, 9), [9, 16) and
    [16, 23). ``forward`` runs the input through the slices in order and
    returns the L2-normalised output of each one.

    Args:
        requires_grad: When False (the default), every parameter of this
            module is frozen.
        pretrained_weights: Forwarded unchanged as ``weights=`` to
            ``torchvision.models.vgg16``.
    """

    def __init__(
        self,
        requires_grad: bool = False,
        pretrained_weights: str = "DEFAULT",
    ) -> None:
        super().__init__()
        # Fixed per-channel statistics applied to the input before the first slice.
        self.norm = Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )
        # Half-open [start, end) index ranges into ``vgg16(...).features``.
        self.slice_indices = [(0, 4), (4, 9), (9, 16), (16, 23)]
        # Empty containers first; ``_initialize_slices`` fills them with VGG layers.
        self.slices = nn.ModuleList([nn.Sequential() for _ in self.slice_indices])
        self._initialize_slices(pretrained_weights)
        # Result type of ``forward``: one field per slice (layer0, layer1, ...).
        field_names = [f"layer{i}" for i, _ in enumerate(self.slice_indices)]
        self.features = namedtuple("Outputs", field_names)
        if not requires_grad:
            # Freeze every parameter registered on this module.
            self.requires_grad_(False)

    def lp_norm(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled to unit L2 norm along dimension 1."""
        return F.normalize(x, p=2.0, dim=1)

    def _initialize_slices(self, pretrained_weights: str = "DEFAULT") -> None:
        """Populate ``self.slices`` with the matching VGG16 feature layers.

        Each layer is registered under its original index within
        ``vgg16(...).features`` so sub-module names stay aligned with VGG.
        """
        vgg_layers = vgg16(weights=pretrained_weights).features
        for container, (first, stop) in zip(self.slices, self.slice_indices):
            for layer_idx in range(first, stop):
                container.add_module(str(layer_idx), vgg_layers[layer_idx])

    def forward(self, x: torch.Tensor) -> namedtuple:
        """Return the normalised activation of every slice for input ``x``.

        Args:
            x: Input tensor; presumably an image batch shaped (N, 3, H, W)
                -- TODO confirm against callers.

        Returns:
            An ``Outputs`` named tuple with one tensor per slice, in order.
        """
        hidden = self.norm(x)
        collected = []
        for block in self.slices:
            # NOTE(review): the *normalised* activation is what the next slice
            # receives, not the raw one -- confirm this is intended.
            hidden = self.lp_norm(block(hidden))
            collected.append(hidden)
        return self.features(*collected)
class PerceptualLoss(nn.Module):
    """Mean of per-layer MSE between the VGG features of two inputs.

    Both inputs are passed through the same ``VGGFeatureExtractor``; the
    mean-squared error is computed for each pair of corresponding feature
    maps and the per-layer values are averaged into one scalar.

    Args:
        requires_grad: Forwarded to ``VGGFeatureExtractor``; when False the
            extractor's parameters are frozen.
        pretrained_weights: Forwarded to ``VGGFeatureExtractor``.
    """

    def __init__(
        self,
        requires_grad: bool = False,
        pretrained_weights: str = "DEFAULT",
    ):
        super().__init__()
        self.extractor = VGGFeatureExtractor(
            pretrained_weights=pretrained_weights,
            requires_grad=requires_grad,
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the scalar perceptual loss between ``x`` and ``y``.

        Args:
            x: First input, passed to the extractor as-is.
            y: Second input, passed to the extractor as-is.

        Returns:
            A 0-dim tensor holding the average of the per-layer MSE values.
            It stays attached to the autograd graph and lives on the same
            device as the extracted features.
        """
        layer_losses = [
            torch.nn.functional.mse_loss(fx, fy)
            for fx, fy in zip(self.extractor(x), self.extractor(y))
        ]
        # torch.stack keeps each scalar's autograd history and device.
        # Building the result with torch.tensor([...]) would copy the values
        # into a fresh leaf tensor, so backward() could never reach x, y or
        # the extractor weights.
        return torch.mean(torch.stack(layer_losses))
|