|
|
import utility |
|
|
from types import SimpleNamespace |
|
|
|
|
|
from model import common |
|
|
from loss import discriminator |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import torch.optim as optim |
|
|
|
|
|
class Adversarial(nn.Module):
    """Adversarial loss that owns and trains its own discriminator.

    Each call to ``forward`` first performs ``gan_k`` optimisation steps on the
    internal discriminator ``self.dis`` (using detached generator output) and
    then returns the generator-side adversarial loss, which the caller is
    expected to back-propagate into the generator.

    ``gan_type`` selects the objective:

    * ``'GAN'``  -- standard GAN, binary cross-entropy on logits.
    * any string containing ``'WGAN'`` -- Wasserstein critic loss. Exactly
      ``'WGAN'`` additionally clips discriminator weights to [-1, 1] after
      each step; a type that also contains ``'GP'`` (e.g. ``'WGAN_GP'``)
      adds a gradient penalty instead.
    * ``'RGAN'`` -- relativistic average GAN.
    """

    def __init__(self, args, gan_type):
        """Build the discriminator and its optimizer.

        Args:
            args: option namespace. This class reads ``gan_k`` directly and,
                for ``'WGAN_GP'``, ``weight_decay``, ``decay`` and ``gamma``.
                The whole object is also passed to
                ``discriminator.Discriminator`` and (for non-WGAN_GP types)
                to ``utility.make_optimizer``, which may read more fields.
            gan_type: objective selector; see the class docstring.
        """
        super(Adversarial, self).__init__()
        self.gan_type = gan_type
        # Number of discriminator updates per forward() call.
        self.gan_k = args.gan_k
        self.dis = discriminator.Discriminator(args)
        if gan_type == 'WGAN_GP':
            # WGAN-GP uses a fixed Adam configuration (beta1=0, beta2=0.9,
            # lr=1e-5); only weight_decay/decay/gamma come from ``args``
            # (presumably regularisation and LR-schedule settings consumed
            # by utility.make_optimizer -- confirm there).
            optim_dict = {
                'optimizer': 'ADAM',
                'betas': (0, 0.9),
                'epsilon': 1e-8,
                'lr': 1e-5,
                'weight_decay': args.weight_decay,
                'decay': args.decay,
                'gamma': args.gamma
            }
            optim_args = SimpleNamespace(**optim_dict)
        else:
            optim_args = args

        self.optimizer = utility.make_optimizer(optim_args, self.dis)

    def forward(self, fake, real):
        """Update the discriminator, then return the generator's adversarial loss.

        Args:
            fake: generator output. Gradients reach it through the returned
                loss; the discriminator updates only see a detached copy.
            real: ground-truth batch matching ``fake``. The first dimension
                is treated as the batch dimension.

        Returns:
            Scalar tensor: the generator-side adversarial loss.

        Side effects:
            Runs ``gan_k`` steps of ``self.optimizer`` on ``self.dis`` and
            stores the mean discriminator loss over those steps in
            ``self.loss`` (0 when ``gan_k`` is 0).

        Raises:
            ValueError: if ``gan_type`` is not a supported objective.
        """
        self.loss = 0
        fake_detach = fake.detach()
        d_real = None
        for _ in range(self.gan_k):
            self.optimizer.zero_grad()
            d_fake = self.dis(fake_detach)
            d_real = self.dis(real)
            loss_d = self._discriminator_loss(d_fake, d_real, fake_detach, real)

            self.loss += loss_d.item()
            loss_d.backward()
            self.optimizer.step()

            if self.gan_type == 'WGAN':
                # Weight clipping keeps the critic (roughly) Lipschitz.
                with torch.no_grad():
                    for p in self.dis.parameters():
                        p.clamp_(-1, 1)

        if self.gan_k > 0:
            self.loss /= self.gan_k

        # Not detached: the generator loss must flow back into ``fake``.
        d_fake_bp = self.dis(fake)
        return self._generator_loss(d_fake_bp, d_real, real)

    def _discriminator_loss(self, d_fake, d_real, fake_detach, real):
        """Return the discriminator loss for one update step."""
        if self.gan_type == 'GAN':
            return self.bce(d_real, d_fake)
        if 'WGAN' in self.gan_type:
            loss_d = (d_fake - d_real).mean()
            if 'GP' in self.gan_type:
                loss_d = loss_d + self._gradient_penalty(fake_detach, real)
            return loss_d
        if self.gan_type == 'RGAN':
            better_real = d_real - d_fake.mean(dim=0, keepdim=True)
            better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
            return self.bce(better_real, better_fake)
        raise ValueError('Unsupported gan_type: {}'.format(self.gan_type))

    def _gradient_penalty(self, fake_detach, real, weight=10):
        """Return weight * mean((||grad D(hat)||_2 - 1)^2) on random interpolates.

        ``hat`` mixes ``fake_detach`` and ``real`` with one random coefficient
        per sample.
        """
        # One coefficient per sample, broadcast over every other dimension.
        # Drawing one value per element would give a tensor whose shape does
        # not broadcast against the batch.
        shape = (fake_detach.size(0),) + (1,) * (fake_detach.dim() - 1)
        epsilon = torch.rand(
            shape, device=fake_detach.device, dtype=fake_detach.dtype
        )
        hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
        hat.requires_grad_(True)
        d_hat = self.dis(hat)
        # create_graph=True so the penalty itself can be back-propagated
        # into the discriminator parameters.
        gradients = torch.autograd.grad(
            outputs=d_hat.sum(), inputs=hat,
            retain_graph=True, create_graph=True
        )[0]
        gradients = gradients.reshape(gradients.size(0), -1)
        gradient_norm = gradients.norm(2, dim=1)
        return weight * gradient_norm.sub(1).pow(2).mean()

    def _generator_loss(self, d_fake_bp, d_real, real):
        """Return the generator-side adversarial loss."""
        if self.gan_type == 'GAN':
            label_real = torch.ones_like(d_fake_bp)
            return F.binary_cross_entropy_with_logits(d_fake_bp, label_real)
        if 'WGAN' in self.gan_type:
            return -d_fake_bp.mean()
        if self.gan_type == 'RGAN':
            if d_real is None:
                # gan_k == 0: no discriminator pass on ``real`` happened yet.
                with torch.no_grad():
                    d_real = self.dis(real)
            # Detach: the graph behind d_real references discriminator weights
            # that optimizer.step() has since modified in place, so
            # back-propagating through it would fail. The generator loss only
            # needs its value.
            d_real = d_real.detach()
            better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True)
            better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True)
            return self.bce(better_fake, better_real)
        raise ValueError('Unsupported gan_type: {}'.format(self.gan_type))

    def state_dict(self, *args, **kwargs):
        """Return the discriminator's state dict merged with the optimizer's.

        The optimizer contributes its top-level keys (e.g. ``'state'`` and
        ``'param_groups'``) unprefixed.

        NOTE(review): when a parent module's ``state_dict`` recurses into this
        module, PyTorch passes ``destination``/``prefix`` and may rely on the
        destination dict rather than this return value, in which case the
        optimizer entries would not be saved. Confirm for the PyTorch version
        in use.
        """
        state_discriminator = self.dis.state_dict(*args, **kwargs)
        state_optimizer = self.optimizer.state_dict()

        return dict(**state_discriminator, **state_optimizer)

    def bce(self, real, fake):
        """Binary cross-entropy on logits.

        Pushes ``real`` logits toward label 1 and ``fake`` logits toward
        label 0. Returns the sum of the two mean losses.
        """
        label_real = torch.ones_like(real)
        label_fake = torch.zeros_like(fake)
        bce_real = F.binary_cross_entropy_with_logits(real, label_real)
        bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)
        bce_loss = bce_real + bce_fake
        return bce_loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|