# OxO_Image-Repair / loss.py
# Author: Gordon-H — "Upload 13 files" (commit fd5c0a6, verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg19, VGG19_Weights
from torchvision import transforms
class PerceptualLoss(nn.Module):
    """
    VGG19 perceptual (feature-reconstruction) loss.

    Runs both the generated and the target image through an
    ImageNet-pretrained, frozen VGG19 ``features`` module and accumulates
    an L1 (or MSE) distance between the activations of selected layers.

    Args:
        feature_layers (iterable of int, optional): Indices into
            ``vgg19().features`` whose outputs are compared. Defaults to
            {1, 6, 11, 20} — the ReLU outputs (relu1_1, relu2_1, relu3_1,
            relu4_1) preceding each of the first four pooling stages.
        use_l1 (bool): If True, use L1 loss between features; otherwise MSE.
        device (str or torch.device): Device the frozen VGG runs on.
    """

    # ImageNet channel statistics used by all torchvision VGG weights.
    _IMAGENET_MEAN = [0.485, 0.456, 0.406]
    _IMAGENET_STD = [0.229, 0.224, 0.225]

    def __init__(self, feature_layers=None, use_l1=True, device='cpu'):
        super(PerceptualLoss, self).__init__()
        # Load the pretrained VGG19 feature extractor.
        try:
            # Modern torchvision API (>= 0.13).
            self.vgg = vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features
        except (AttributeError, TypeError):
            # Older torchvision: no VGG19_Weights enum / no `weights=` kwarg
            # (the latter raises TypeError, hence the second exception type).
            print("Warning: Using older torchvision VGG19 loading method. Consider upgrading torchvision.")
            self.vgg = vgg19(pretrained=True).features
        # FIX: previously the modern path used `weights.transforms()`, i.e.
        # the ImageNet *classification* preprocessing, which also resizes
        # to 256 and center-crops to 224. That silently rescaled/cropped
        # both images — wrong for a perceptual loss (spatial detail is
        # exactly what we compare) and inconsistent with the fallback
        # path, which only normalized. Both paths now share
        # normalization-only preprocessing with ImageNet statistics.
        self.preprocess = transforms.Normalize(
            mean=self._IMAGENET_MEAN, std=self._IMAGENET_STD
        )
        self.vgg.eval()  # VGG is a fixed feature extractor, never trained
        for param in self.vgg.parameters():
            param.requires_grad = False  # freeze; gradients flow only w.r.t. inputs
        self.vgg = self.vgg.to(device)
        self.device = device
        # Layer indices to tap. Common choices are the ReLU outputs right
        # before each max-pool: relu1_1=1, relu2_1=6, relu3_1=11,
        # relu4_1=20 (relu5_1=29 is sometimes added as well).
        if feature_layers is None:
            self.feature_layers = {1, 6, 11, 20}
        else:
            self.feature_layers = set(feature_layers)
        self.loss_fn = nn.L1Loss() if use_l1 else nn.MSELoss()
        print(f"PerceptualLoss: Using VGG19 features from layers: {sorted(list(self.feature_layers))}")
        print(f"PerceptualLoss: Using {'L1' if use_l1 else 'L2'} distance.")

    def forward(self, generated, target):
        """
        Compute the perceptual loss between two image batches.

        Args:
            generated (torch.Tensor): Generated images (B, C, H, W),
                values expected in [0, 1].
            target (torch.Tensor): Ground-truth images (B, C, H, W),
                values expected in [0, 1].

        Returns:
            torch.Tensor: Scalar loss — the sum of feature distances over
            all selected layers (0.0 if ``feature_layers`` is empty).
        """
        # Ensure inputs are on the same device as the frozen VGG.
        generated = generated.to(self.device)
        target = target.to(self.device)
        # Normalize with ImageNet statistics as expected by VGG.
        generated_norm = self.preprocess(generated)
        target_norm = self.preprocess(target)
        loss = 0.0
        current_layer_idx = 0
        # Deepest layer we need; lets us stop the forward pass early.
        max_needed_layer = max(self.feature_layers) if self.feature_layers else 0
        for layer in self.vgg:
            # Advance both activations through the current VGG layer.
            generated_norm = layer(generated_norm)
            target_norm = layer(target_norm)
            if current_layer_idx in self.feature_layers:
                loss += self.loss_fn(generated_norm, target_norm)
            if current_layer_idx >= max_needed_layer:
                break  # no deeper features are needed
            current_layer_idx += 1
        return loss
# --- Example Usage (for testing the definition) ---
if __name__ == '__main__':
    # Smoke test for the PerceptualLoss definition.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Random image batches in [0, 1]; generated and target must share a shape.
    fake_output = torch.rand(2, 3, 96, 96).to(device)
    ground_truth = torch.rand(2, 3, 96, 96).to(device)

    # Default layers {1, 6, 11, 20} correspond to relu1_1 … relu4_1.
    criterion = PerceptualLoss(device=device, use_l1=True)

    result = criterion(fake_output, ground_truth)
    print(f"\nCalculated Perceptual Loss (L1, default layers): {result.item()}")

    assert result.item() >= 0, "Loss should be non-negative"
    print("\nPerceptualLoss definition test successful!")