| import torch |
| import torch.nn as nn |
|
|
| from utils import interact |
|
|
| import torch.cuda.amp as amp |
|
|
class Adversarial(nn.modules.loss._Loss):
    """GAN adversarial loss that also performs the discriminator update.

    In training mode, calling the module first runs ``gan_k`` optimisation
    steps on the discriminator ``D``. Each step uses BCE-with-logits with
    fake samples labelled 0 and real samples labelled 1, and goes through
    this module's own AMP ``GradScaler``. The call then returns the generator
    loss ``BCE(D(fake), 1)``. That loss is only returned: running its
    backward pass and stepping the generator is left to the caller.

    Args:
        args: configuration object. ``args.init_scale`` and ``args.amp`` are
            read to configure the ``GradScaler`` and ``autocast``.
        model: wrapper whose ``model.model.D`` is the discriminator.
        optimizer: object whose ``optimizer.D`` steps the discriminator.
        gan_k: number of discriminator updates per training call
            (default 1).
    """

    def __init__(self, args, model, optimizer, *, gan_k=1):
        super().__init__()
        self.args = args
        self.model = model.model
        self.optimizer = optimizer
        # Dedicated scaler for the discriminator steps; it is a pass-through
        # when args.amp is False.
        self.scaler = amp.GradScaler(
            init_scale=self.args.init_scale,
            enabled=self.args.amp,
        )
        self.gan_k = gan_k
        self.BCELoss = nn.BCEWithLogitsLoss()

    def forward(self, fake, real, training=False):
        """Return the generator adversarial loss, updating D first if training.

        Args:
            fake: generator output. Gradients reach the generator through
                the returned loss (fake is not detached for this part).
            real: ground-truth samples shown to D as the "real" class.
            training: when True, run ``self.gan_k`` discriminator steps
                before computing the generator loss.

        Returns:
            BCE-with-logits loss of ``D(fake)`` against all-ones targets.
        """
        if training:
            # D must not push gradients into the generator, hence detach().
            self._update_discriminator(fake.detach(), real)

        # Targets are built from D(fake) itself. They therefore always match
        # its shape and dtype, and they exist even when no discriminator step
        # ran (gan_k == 0 or evaluation).
        d_fake_bp = self.model.D(fake)
        return self.BCELoss(d_fake_bp, torch.ones_like(d_fake_bp))

    def _update_discriminator(self, fake_detach, real):
        """Run ``self.gan_k`` AMP-scaled BCE steps on the discriminator."""
        for _ in range(self.gan_k):
            # Also clears any D gradients left by a previous generator backward.
            self.optimizer.D.zero_grad()
            with amp.autocast(self.args.amp):
                d_fake = self.model.D(fake_detach)
                d_real = self.model.D(real)
                loss_d = (
                    self.BCELoss(d_fake, torch.zeros_like(d_fake))
                    + self.BCELoss(d_real, torch.ones_like(d_real))
                )
            # Standard GradScaler recipe: scale -> backward -> step -> update.
            self.scaler.scale(loss_d).backward()
            self.scaler.step(self.optimizer.D)
            self.scaler.update()