| """This file contains the definition of the perceptual loss.""" | |
| import torch | |
| from torchvision import models | |
| from torchvision.models.feature_extraction import create_feature_extractor | |
class PerceptualLoss(torch.nn.Module):
    """Perceptual loss computed with a frozen ImageNet-pretrained classifier.

    Both images are resized to 224x224 and normalized with the ImageNet
    mean/std. They are then passed through the frozen network, and the loss is
    the mean-squared error between the network outputs for the two images.
    """

    def __init__(
        self,
        model_name: str = "resnet50",
        compute_perceptual_loss_on_logits: bool = True,
    ):
        """Initialize the perceptual loss.

        Args:
            model_name -> str: The backbone to use, either "resnet50" or
                "convnext_s" (torchvision ImageNet-1k pretrained weights).
            compute_perceptual_loss_on_logits -> bool: If True, the loss is the
                MSE between the classifier logits only. If False, the loss is
                the MSE between the last feature maps plus the MSE between the
                logits.

        Raises:
            ValueError: If ``model_name`` is not a supported backbone.
        """
        super().__init__()
        if model_name == "resnet50":
            model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
            return_nodes = {"layer4": "features", "fc": "logits"}
        elif model_name == "convnext_s":
            model = models.convnext_small(
                weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1
            )
            return_nodes = {"features": "features", "classifier": "logits"}
        else:
            # Fail early with a clear message instead of an UnboundLocalError
            # on `model` further down.
            raise ValueError(
                f"Unsupported model_name {model_name!r}; "
                "expected 'resnet50' or 'convnext_s'."
            )

        if compute_perceptual_loss_on_logits:
            # The full classifier returns the logits tensor directly.
            self.model = model
        else:
            # Returns a dict with both the last feature map and the logits.
            self.model = create_feature_extractor(model, return_nodes=return_nodes)
        self.compute_perceptual_loss_on_logits = compute_perceptual_loss_on_logits

        # ImageNet normalization statistics, shaped (1, 3, 1, 1) for broadcasting.
        self.register_buffer(
            "mean", torch.tensor([0.485, 0.456, 0.406])[None, :, None, None]
        )
        self.register_buffer(
            "std", torch.tensor([0.229, 0.224, 0.225])[None, :, None, None]
        )

        # The network is only a fixed feature extractor: freeze its weights and
        # run it in inference mode (see `train` below for why this matters).
        self.requires_grad_(False)
        self.model.eval()

    def train(self, mode: bool = True) -> "PerceptualLoss":
        """Set the training flag while keeping the frozen network in eval mode.

        ``nn.Module.train``/``eval`` recurse into children. Without this
        override, a parent module calling ``.train()`` would switch the backbone
        back to training behaviour: batch-statistic BatchNorm in ResNet-50, and
        random stochastic depth in ConvNeXt. That would make the loss noisy and
        would corrupt the pretrained running statistics.
        """
        super().train(mode)
        self.model.eval()
        return self

    def _preprocess(self, images: torch.Tensor) -> torch.Tensor:
        """Resize images to 224x224 and apply ImageNet mean/std normalization."""
        images = torch.nn.functional.interpolate(
            images, size=224, mode="bilinear", antialias=True, align_corners=False
        )
        return (images - self.mean) / self.std

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the perceptual loss.

        Args:
            input -> torch.Tensor: Images of shape (B, 3, H, W). The values are
                presumably in [0, 1], since the ImageNet normalization expects
                that range. TODO: confirm with callers.
            target -> torch.Tensor: Reference images with the same layout.

        Returns:
            loss -> torch.Tensor: Scalar perceptual loss.
        """
        outputs_input = self.model(self._preprocess(input))
        outputs_target = self.model(self._preprocess(target))

        if self.compute_perceptual_loss_on_logits:
            return torch.nn.functional.mse_loss(
                outputs_input, outputs_target, reduction="mean"
            )

        feature_loss = torch.nn.functional.mse_loss(
            outputs_input["features"],
            outputs_target["features"],
            reduction="mean",
        )
        logit_loss = torch.nn.functional.mse_loss(
            outputs_input["logits"], outputs_target["logits"], reduction="mean"
        )
        return feature_loss + logit_loss
if __name__ == "__main__":
    # Smoke test: compare two random images with values uniformly in [0, 1].
    # torch.rand is used rather than randn().clamp_(0, 1), which would set
    # roughly half of all pixels to exactly 0.
    loss_fn = PerceptualLoss()
    reconstruction = torch.rand(2, 3, 256, 256)
    reference = torch.rand(2, 3, 256, 256)
    loss = loss_fn(reconstruction, reference)
    print(loss)