Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| import lpips # LPIPS library for perceptual loss | |
class GeneratorLoss(nn.Module):
    """Weighted generator objective: L1 + LPIPS perceptual + adversarial.

    The adversarial term asks ``discriminator_model`` to score the generated
    image, concatenated with the target-age channel of ``source``, as real.
    The discriminator's outputs are treated as logits (``BCEWithLogitsLoss``).

    Args:
        discriminator_model: module scoring (image, age-channel) inputs.
        l1_weight: weight of the pixel-wise L1 term.
        perceptual_weight: weight of the LPIPS (VGG) term.
        adversarial_weight: weight of the adversarial term.
        device: device the LPIPS network is moved to.
    """

    def __init__(self, discriminator_model, l1_weight=1.0, perceptual_weight=1.0, adversarial_weight=0.05,
                 device="cpu"):
        super().__init__()
        self.discriminator_model = discriminator_model
        # Per-term weights used when combining the three losses in forward().
        self.l1_weight = l1_weight
        self.perceptual_weight = perceptual_weight
        self.adversarial_weight = adversarial_weight
        # Criteria; the LPIPS network uses a VGG backbone and lives on `device`.
        self.criterion_l1 = nn.L1Loss()
        self.criterion_adversarial = nn.BCEWithLogitsLoss()
        self.criterion_perceptual = lpips.LPIPS(net='vgg').to(device)

    def forward(self, output, target, source):
        """Compute the weighted generator loss and its individual components.

        Args:
            output: generated images.
            target: ground-truth images compared against ``output``.
            source: 4-D conditioning tensor; its channel 4 (the target-age
                channel) is appended to ``output`` for the discriminator.

        Returns:
            Tuple ``(generator_loss, l1_loss, perceptual_loss, adversarial_loss)``;
            the first is the weighted sum of the other three.
        """
        l1_loss = self.criterion_l1(output, target)
        # Reduce the LPIPS distances to a single scalar.
        perceptual_loss = self.criterion_perceptual(output, target).mean()

        # Ask the discriminator to classify (generated image, target age) as real.
        target_age = source[:, 4:5, :, :]
        fake_logits = self.discriminator_model(torch.cat((output, target_age), dim=1))
        real_labels = torch.ones_like(fake_logits)
        adversarial_loss = self.criterion_adversarial(fake_logits, real_labels)

        weighted_l1 = self.l1_weight * l1_loss
        weighted_perceptual = self.perceptual_weight * perceptual_loss
        weighted_adversarial = self.adversarial_weight * adversarial_loss
        generator_loss = weighted_l1 + weighted_perceptual + weighted_adversarial
        return generator_loss, l1_loss, perceptual_loss, adversarial_loss
class DiscriminatorLoss(nn.Module):
    """Weighted BCE loss for a conditional discriminator over (image, age) pairs.

    Each discriminator input is an image concatenated with one age channel
    along dim 1; the discriminator's outputs are treated as logits
    (``BCEWithLogitsLoss``). Only the target image paired with the target age
    is labelled real (1). The generated image paired with the target age is
    labelled fake (0). So are several "mock" mismatched pairs: source image
    with target age, target image with source age, and generated image with
    source age.

    Args:
        discriminator_model: module scoring (image, age-channel) inputs.
        fake_weight: weight of the generated-image (fake) term.
        real_weight: weight of the real-pair term.
        mock_weight: weight applied to each of the four mock terms.
    """

    def __init__(self, discriminator_model, fake_weight=1.0, real_weight=2.0, mock_weight=.5):
        super().__init__()
        self.discriminator_model = discriminator_model
        self.criterion_adversarial = nn.BCEWithLogitsLoss()
        self.fake_weight = fake_weight
        self.real_weight = real_weight
        self.mock_weight = mock_weight

    def forward(self, output, target, source):
        """Return the weighted discriminator loss as a scalar tensor.

        Args:
            output: generated images. Detached here, so this loss never
                back-propagates into the generator.
            target: ground-truth images.
            source: 4-D tensor whose channels 0-2 are the source image,
                channel 3 the source age and channel 4 the target age.
        """
        # This loss trains only the discriminator: cut the graph to the
        # generator so backward() neither fills generator grads nor consumes
        # the generator's autograd graph (harmless if the caller already
        # detached `output`).
        output = output.detach()

        source_img = source[:, :3, :, :]
        source_age = source[:, 3:4, :, :]
        target_age = source[:, 4:5, :, :]

        fake_input = torch.cat([output, target_age], dim=1)       # prediction img with target age
        real_input = torch.cat([target, target_age], dim=1)       # target img with target age
        mock_input1 = torch.cat([source_img, target_age], dim=1)  # source img with target age
        mock_input2 = torch.cat([target, source_age], dim=1)      # target img with source age
        mock_input3 = torch.cat([output, source_age], dim=1)      # prediction img with source age
        # NOTE(review): this fourth mock was previously commented as "target img
        # with target age". It is actually built from the source-age channel,
        # which makes it identical to mock_input2. ("Target img with target age"
        # is real_input, labelled real.) It is kept so the loss value is
        # unchanged, which means that pairing carries 2 * mock_weight in total.
        # Confirm the intended fourth pairing.
        mock_input4 = mock_input2

        # Same call order as before (matters if the discriminator has dropout
        # or batch-norm running statistics).
        fake_pred = self.discriminator_model(fake_input)
        real_pred = self.discriminator_model(real_input)
        mock_preds = [self.discriminator_model(x)
                      for x in (mock_input1, mock_input2, mock_input3, mock_input4)]

        fake_loss = self.criterion_adversarial(fake_pred, torch.zeros_like(fake_pred))
        real_loss = self.criterion_adversarial(real_pred, torch.ones_like(real_pred))
        mock_losses = [self.criterion_adversarial(pred, torch.zeros_like(pred))
                       for pred in mock_preds]

        discriminator_loss = self.fake_weight * fake_loss + self.real_weight * real_loss
        for mock_loss in mock_losses:
            discriminator_loss = discriminator_loss + self.mock_weight * mock_loss
        return discriminator_loss