import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg19, VGG19_Weights
from torchvision import transforms


class PerceptualLoss(nn.Module):
    """VGG19 perceptual (feature-reconstruction) loss.

    Runs both the generated and target images through an ImageNet-pretrained
    VGG19 and sums a distance (L1 or MSE) between the feature maps at a set
    of chosen layers. The VGG weights are frozen: gradients flow only into
    the input images, never into the loss network.
    """

    # Standard ImageNet channel statistics; inputs are expected in [0, 1].
    _IMAGENET_MEAN = [0.485, 0.456, 0.406]
    _IMAGENET_STD = [0.229, 0.224, 0.225]

    def __init__(self, feature_layers=None, use_l1=True, device='cpu'):
        """
        Args:
            feature_layers (iterable of int, optional): Indices into
                ``vgg19(...).features`` whose outputs are compared.
                Defaults to {1, 6, 11, 20} — the ReLU outputs commonly
                labelled relu1_1, relu2_1, relu3_1, relu4_1.
            use_l1 (bool): If True, use L1 distance between features;
                otherwise use L2 (MSE).
            device (str): 'cuda' or 'cpu'.
        """
        super().__init__()

        # Load the pretrained feature extractor. Fall back to the legacy
        # `pretrained=True` API on older torchvision versions.
        try:
            self.vgg = vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features
        except (AttributeError, TypeError):
            print("Warning: Using older torchvision VGG19 loading method. "
                  "Consider upgrading torchvision.")
            self.vgg = vgg19(pretrained=True).features

        # NOTE: we deliberately do NOT use `weights.transforms()` here.
        # That is the ImageNet *classification* preprocessing (resize to 256,
        # center-crop to 224), which would silently rescale/crop arbitrary-size
        # image batches and distort the loss. Perceptual loss conventionally
        # applies only the ImageNet channel normalization.
        self.preprocess = transforms.Normalize(mean=self._IMAGENET_MEAN,
                                               std=self._IMAGENET_STD)

        self.vgg.eval()  # inference mode statistics (no-op for VGG, but explicit)
        for param in self.vgg.parameters():
            param.requires_grad = False  # freeze: VGG is a fixed loss network
        self.vgg = self.vgg.to(device)
        self.device = device

        # VGG19 `features` module indices of the usual ReLU taps:
        # relu1_1: 1, relu2_1: 6, relu3_1: 11, relu4_1: 20, relu5_1: 29.
        if feature_layers is None:
            # ReLU outputs before pooling stages 1-4.
            self.feature_layers = {1, 6, 11, 20}
        else:
            self.feature_layers = set(feature_layers)

        self.loss_fn = nn.L1Loss() if use_l1 else nn.MSELoss()

        print(f"PerceptualLoss: Using VGG19 features from layers: "
              f"{sorted(self.feature_layers)}")
        print(f"PerceptualLoss: Using {'L1' if use_l1 else 'L2'} distance.")

    def forward(self, generated, target):
        """
        Compute the perceptual loss between two image batches.

        Args:
            generated (torch.Tensor): Generated images (B, C, H, W), values
                in [0, 1].
            target (torch.Tensor): Ground-truth images (B, C, H, W), values
                in [0, 1].

        Returns:
            torch.Tensor: Scalar tensor with the summed per-layer feature
            distance (zero tensor if ``feature_layers`` is empty).
        """
        generated = generated.to(self.device)
        target = target.to(self.device)

        # ImageNet channel normalization expected by VGG.
        gen_feat = self.preprocess(generated)
        tgt_feat = self.preprocess(target)

        # Accumulate in a tensor so the return type is consistent even when
        # no layers are selected (the original returned a Python float then).
        loss = generated.new_zeros(())
        if not self.feature_layers:
            return loss

        last_needed = max(self.feature_layers)
        for idx, layer in enumerate(self.vgg):
            # Advance both feature stacks one layer at a time; a single pass
            # serves every tap point.
            gen_feat = layer(gen_feat)
            tgt_feat = layer(tgt_feat)
            if idx in self.feature_layers:
                loss = loss + self.loss_fn(gen_feat, tgt_feat)
            if idx >= last_needed:
                break  # no deeper layers are needed
        return loss


# --- Example usage (smoke test of the definition) ---
if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Dummy images in [0, 1]: (batch, channels, height, width).
    # Generated and target must share a shape.
    dummy_generated = torch.rand(2, 3, 96, 96).to(device)
    dummy_target = torch.rand(2, 3, 96, 96).to(device)

    # Default layers: {1, 6, 11, 20} (relu1_1, relu2_1, relu3_1, relu4_1).
    perceptual_loss_l1 = PerceptualLoss(device=device, use_l1=True)

    # Example with different layers and L2 distance:
    # perceptual_loss_l2 = PerceptualLoss(feature_layers={35}, device=device, use_l1=False)

    loss_val_l1 = perceptual_loss_l1(dummy_generated, dummy_target)
    # loss_val_l2 = perceptual_loss_l2(dummy_generated, dummy_target)

    print(f"\nCalculated Perceptual Loss (L1, default layers): {loss_val_l1.item()}")
    # print(f"Calculated Perceptual Loss (L2, layer 35): {loss_val_l2.item()}")

    assert loss_val_l1.item() >= 0, "Loss should be non-negative"
    print("\nPerceptualLoss definition test successful!")