| import torch
|
| import torch.nn as nn
|
|
|
| from utils import interact
|
|
|
| import torch.cuda.amp as amp
|
|
|
class Adversarial(nn.modules.loss._Loss):
    """Adversarial (GAN) loss that also trains its own discriminator.

    When called with ``training=True`` it first performs ``self.gan_k``
    discriminator updates (real vs. detached fake, binary cross-entropy on
    logits, optionally under mixed precision). It then returns the
    generator's adversarial loss. No generator backward pass happens here;
    the returned tensor is handed back to the caller.
    """

    def __init__(self, args, model, optimizer):
        """
        Args:
            args: Config namespace. Must provide ``init_scale`` (initial
                GradScaler scale) and ``amp`` (enables mixed precision).
            model: Wrapper whose ``.model`` attribute exposes the
                discriminator as ``.D``.
            optimizer: Wrapper exposing the discriminator optimizer as ``.D``.
        """
        super().__init__()
        self.args = args
        self.model = model.model
        self.optimizer = optimizer
        # Dedicated loss scaler for the discriminator's optimizer steps.
        self.scaler = amp.GradScaler(
            init_scale=self.args.init_scale,
            enabled=self.args.amp,
        )

        # Number of discriminator updates per training call.
        self.gan_k = 1

        self.BCELoss = nn.BCEWithLogitsLoss()

    def forward(self, fake, real, training=False):
        """Return the generator's adversarial loss, updating D first if asked.

        Args:
            fake: Generator output. It is NOT detached for the returned loss,
                so gradients can flow back into the generator.
            real: Ground-truth samples. They are only used to train the
                discriminator.
            training: If True, run ``self.gan_k`` discriminator update steps
                before computing the generator loss.

        Returns:
            Mean ``BCEWithLogitsLoss(D(fake), 1)``, i.e. the generator is
            rewarded for making D classify ``fake`` as real.
        """
        if training:
            # D must not propagate gradients into the generator during its
            # own update, hence the detach.
            self._update_discriminator(fake.detach(), real)

        d_fake_bp = self.model.D(fake)
        # Build the target from the tensor it is compared against, so that
        # shape, dtype and device always match. This also avoids a
        # separate D(real) forward pass just to obtain a shape.
        label_real = torch.ones_like(d_fake_bp)
        return self.BCELoss(d_fake_bp, label_real)

    def _update_discriminator(self, fake_detach, real):
        """Run ``self.gan_k`` optimisation steps of D on real vs. fake samples."""
        for _ in range(self.gan_k):
            # Clear stale gradients on D. Gradients from a generator backward
            # through D(fake) also accumulate on D's parameters.
            self.optimizer.D.zero_grad()
            with amp.autocast(enabled=self.args.amp):
                d_fake = self.model.D(fake_detach)
                d_real = self.model.D(real)
                loss_d = (
                    self.BCELoss(d_fake, torch.zeros_like(d_fake))
                    + self.BCELoss(d_real, torch.ones_like(d_real))
                )
            # Backward and optimizer step run outside autocast, as the AMP
            # docs recommend.
            self.scaler.scale(loss_d).backward()
            self.scaler.step(self.optimizer.D)
            self.scaler.update()