|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torchvision.models import vgg19, VGG19_Weights |
|
|
from torchvision import transforms |
|
|
|
|
|
class PerceptualLoss(nn.Module):
    """
    Calculates the VGG perceptual loss.

    Uses features from the VGG19 network pretrained on ImageNet.
    Compares features from specific layers for the generated and target images.
    """

    # ImageNet channel statistics the pretrained VGG19 weights expect.
    _IMAGENET_MEAN = [0.485, 0.456, 0.406]
    _IMAGENET_STD = [0.229, 0.224, 0.225]

    def __init__(self, feature_layers=None, use_l1=True, device='cpu'):
        """
        Args:
            feature_layers (list of int, optional): Indices of VGG19 feature layers to use.
                Defaults correspond to layers before pool1, pool2, pool3, pool4.
                Specifically: relu1_1, relu2_1, relu3_1, relu4_1 in many implementations.
                VGG19 structure: layer indices relate to `features` module.
            use_l1 (bool): If True, use L1 loss between features. If False, use L2 (MSE) loss.
            device (str): 'cuda' or 'cpu'.
        """
        super().__init__()

        try:
            weights = VGG19_Weights.IMAGENET1K_V1
            self.vgg = vgg19(weights=weights).features
        except AttributeError:
            # Older torchvision (< 0.13) without the Weights-enum API.
            print("Warning: Using older torchvision VGG19 loading method. Consider upgrading torchvision.")
            self.vgg = vgg19(pretrained=True).features

        # FIX: the previous code used `weights.transforms()` on the modern path.
        # That is the *classification* preset (resize to 256 + center-crop to 224),
        # so it silently rescaled/cropped the images being compared, and it behaved
        # differently from the legacy fallback, which only normalized. For a
        # perceptual loss on [0, 1] image tensors the correct preprocessing is
        # ImageNet normalization only, applied identically on both paths.
        self.preprocess = transforms.Normalize(mean=self._IMAGENET_MEAN,
                                               std=self._IMAGENET_STD)

        # Freeze VGG: it is a fixed feature extractor. Gradients must still flow
        # through the *inputs*, so forward() is deliberately not under no_grad().
        self.vgg.eval()
        for param in self.vgg.parameters():
            param.requires_grad = False

        self.vgg = self.vgg.to(device)
        self.device = device

        if feature_layers is None:
            # relu1_1, relu2_1, relu3_1, relu4_1 in torchvision's vgg19().features.
            self.feature_layers = {1, 6, 11, 20}
        else:
            self.feature_layers = set(feature_layers)

        self.loss_fn = nn.L1Loss() if use_l1 else nn.MSELoss()

        print(f"PerceptualLoss: Using VGG19 features from layers: {sorted(list(self.feature_layers))}")
        print(f"PerceptualLoss: Using {'L1' if use_l1 else 'L2'} distance.")

    def forward(self, generated, target):
        """
        Compute the perceptual loss.

        Args:
            generated (torch.Tensor): The generated image tensor (B, C, H, W). Values [0, 1].
            target (torch.Tensor): The target (ground truth) image tensor (B, C, H, W). Values [0, 1].

        Returns:
            torch.Tensor: The calculated perceptual loss as a scalar tensor
                (zero if no feature layers are configured).
        """
        generated = generated.to(self.device)
        target = target.to(self.device)

        generated_norm = self.preprocess(generated)
        target_norm = self.preprocess(target)

        # FIX: accumulate into a scalar *tensor* (previously a Python float 0.0),
        # so callers can always rely on .item() / .backward(), even when
        # feature_layers is empty.
        loss = torch.zeros((), device=generated.device)

        # Stop as soon as the deepest requested layer has been consumed; with no
        # requested layers (-1) the VGG pass is skipped entirely.
        max_needed_layer = max(self.feature_layers) if self.feature_layers else -1

        for layer_idx, layer in enumerate(self.vgg):
            if layer_idx > max_needed_layer:
                break
            generated_norm = layer(generated_norm)
            target_norm = layer(target_norm)
            if layer_idx in self.feature_layers:
                loss = loss + self.loss_fn(generated_norm, target_norm)

        return loss
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build the loss module and run it on random data.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Two random batches of 96x96 RGB "images" in [0, 1] stand in for
    # generator output and ground truth.
    batch_shape = (2, 3, 96, 96)
    fake_images = torch.rand(*batch_shape).to(device)
    real_images = torch.rand(*batch_shape).to(device)

    criterion = PerceptualLoss(device=device, use_l1=True)
    loss_val_l1 = criterion(fake_images, real_images)

    print(f"\nCalculated Perceptual Loss (L1, default layers): {loss_val_l1.item()}")

    # Feature-distance losses are sums of L1 terms, hence non-negative.
    assert loss_val_l1.item() >= 0, "Loss should be non-negative"
    print("\nPerceptualLoss definition test successful!")