| import warnings |
|
|
| warnings.filterwarnings("ignore") |
|
|
| import os |
| import sys |
| import glob |
| import time |
| import numpy as np |
| from PIL import Image |
| from pathlib import Path |
| from tqdm.notebook import tqdm |
| import matplotlib.pyplot as plt |
| from skimage.color import rgb2lab, lab2rgb |
|
|
| import torch |
| from torch import nn, optim |
| from torchvision import transforms |
| from torchvision.utils import make_grid |
| from torch.utils.data import Dataset, DataLoader |
|
|
|
|
class GANLoss(nn.Module):
    """Adversarial loss comparing discriminator outputs against a constant target.

    The target tensor is built by broadcasting a scalar "real" or "fake" label
    to the shape of the discriminator predictions, then fed to the underlying
    criterion together with the predictions.

    Args:
        gan_mode: ``"vanilla"`` uses ``nn.BCEWithLogitsLoss`` (so predictions
            are treated as raw logits); ``"lsgan"`` uses ``nn.MSELoss``.
        real_label: Target value used when ``target_is_real`` is true.
        fake_label: Target value used when ``target_is_real`` is false.

    Raises:
        ValueError: If ``gan_mode`` is not ``"vanilla"`` or ``"lsgan"``.
    """

    def __init__(self, gan_mode="vanilla", real_label=1.0, fake_label=0.0):
        super().__init__()
        # Registered as buffers (not plain attributes) so the labels move with
        # the module on .to(device) and are included in the state_dict.
        self.register_buffer("real_label", torch.tensor(real_label))
        self.register_buffer("fake_label", torch.tensor(fake_label))
        if gan_mode == "vanilla":
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == "lsgan":
            self.loss = nn.MSELoss()
        else:
            # Fail fast: otherwise ``self.loss`` would be missing and the error
            # would only surface as an AttributeError on the first forward pass.
            raise ValueError(
                f"Unsupported gan_mode {gan_mode!r}; expected 'vanilla' or 'lsgan'"
            )

    def get_labels(self, preds, target_is_real):
        """Return a target tensor shaped like ``preds``.

        Args:
            preds: Discriminator output tensor.
            target_is_real: Selects ``real_label`` when truthy, else ``fake_label``.

        Returns:
            The chosen label broadcast (via ``expand_as``) to ``preds``' shape.
        """
        labels = self.real_label if target_is_real else self.fake_label
        # Match the prediction dtype so integer labels (e.g. ``real_label=1``)
        # or reduced-precision predictions don't cause a dtype mismatch in the
        # underlying loss function.
        return labels.to(dtype=preds.dtype).expand_as(preds)

    def forward(self, preds, target_is_real):
        """Compute the loss of ``preds`` against the real or fake target.

        Implemented as ``forward`` (not ``__call__``) so that ``nn.Module``
        hooks still run when the instance is called.

        Args:
            preds: Discriminator output tensor.
            target_is_real: Whether the target should be the real label.

        Returns:
            The scalar loss tensor produced by the configured criterion.
        """
        labels = self.get_labels(preds, target_is_real)
        return self.loss(preds, labels)
|
|