import torch

# Binary cross-entropy criteria. The discriminator and generator objectives
# each get their own BCELoss instance.
disc_loss_criterion = torch.nn.BCELoss()
gen_loss_criterion = torch.nn.BCELoss()

# Fill values used to build the BCE target tensors: real_label marks samples
# that should be scored as real, fake_label marks generated samples.
real_label = 1
fake_label = 0
def discriminator_loss(gen_images, real_images, disc_net):
    """Return the discriminator's BCE loss averaged over real and generated batches.

    Args:
        gen_images: Generated images. They are detached before being scored,
            so this loss does not backpropagate into whatever produced them.
        real_images: Real images, scored against ``real_label``.
        disc_net: Callable applied to each image batch. ``BCELoss`` requires
            its output to lie in [0, 1] -- presumably the network ends in a
            sigmoid; confirm against the model definition.

    Returns:
        The mean of the real-batch loss and the generated-batch loss.
    """
    real_scores = disc_net(real_images)
    gen_scores = disc_net(gen_images.detach())

    # Build the targets from the discriminator's own outputs (the same way
    # generator_loss does) instead of assuming a (batch, 1) output with the
    # images' dtype/device. BCELoss rejects a target whose size differs from
    # the input, so this keeps working for (N,), (N, 1), (N, 1, 1, 1), ...
    real_targets = real_scores.new_full(real_scores.shape, real_label)
    gen_targets = gen_scores.new_full(gen_scores.shape, fake_label)

    realloss = disc_loss_criterion(real_scores, real_targets)
    genloss = disc_loss_criterion(gen_scores, gen_targets)
    return (genloss + realloss) / 2
def generator_loss(gen_images, disc_net):
    """Return the generator's BCE loss for a batch of generated images.

    The discriminator's scores for ``gen_images`` are compared against a
    target of the same shape filled with ``real_label``, so the loss is low
    when the discriminator rates the generated images as real.

    Args:
        gen_images: Generated images; not detached, so gradients can flow
            back through them.
        disc_net: Callable that scores the image batch.

    Returns:
        The value of ``gen_loss_criterion`` for the scores and the
        all-``real_label`` target.
    """
    scores = disc_net(gen_images)
    return gen_loss_criterion(
        scores,
        scores.new_full(scores.shape, real_label),
    )