| import os |
| import glob |
| import time |
| import numpy as np |
| from PIL import Image |
| from pathlib import Path |
| from tqdm.notebook import tqdm |
| import matplotlib.pyplot as plt |
| from skimage.color import rgb2lab, lab2rgb |
|
|
| import torch |
| from torch import nn, optim |
| from torchvision import transforms |
| from torchvision.utils import make_grid |
| from torch.utils.data import Dataset, DataLoader |
|
|
| from .Generator import UnetBlock, Unet |
| from .Discriminator import PatchDiscriminator |
| from .weights import init_weights |
| from .loss import GANLoss |
|
|
# Module-wide default device: CUDA whenever PyTorch reports it as available,
# otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
def init_model(model, device):
    """Place ``model`` on ``device`` and run the project's weight initialiser on it.

    Args:
        model: Module to prepare.
        device: Target passed straight to ``model.to``.

    Returns:
        Whatever ``init_weights`` returns for the moved model.
    """
    moved = model.to(device)
    return init_weights(moved)
|
|
|
|
class MainModel(nn.Module):
    """Conditional GAN that predicts two "ab" channels from a one-channel "L" input.

    Bundles a generator (``net_G``), a discriminator (``net_D``), their Adam
    optimizers and the two losses. The generator objective is the adversarial
    loss plus ``lambda_L1`` times the L1 distance between its output and the
    ground-truth ``ab`` tensor.

    NOTE(review): "L"/"ab" presumably denote the lightness and colour channels
    of Lab images (the file imports ``rgb2lab``) - confirm against the dataset.
    """

    def __init__(
        self,
        net_G=None,
        lr_G=2e-4,
        lr_D=2e-4,
        beta1=0.5,
        beta2=0.999,
        lambda_L1=100.0,
        device=None,
    ):
        """Build the networks, losses and optimizers.

        Args:
            net_G: Optional pre-built generator. When ``None`` a new
                ``Unet(input_c=1, output_c=2, n_down=8, num_filters=64)`` is
                created and passed through ``init_model``. A supplied network
                is only moved to the device; its weights are not re-initialised.
            lr_G: Adam learning rate for the generator.
            lr_D: Adam learning rate for the discriminator.
            beta1: Adam ``beta1`` used by both optimizers.
            beta2: Adam ``beta2`` used by both optimizers.
            lambda_L1: Weight of the L1 term in the generator loss.
            device: Optional ``torch.device`` or device string. Defaults to
                CUDA when available, otherwise CPU.
        """
        super().__init__()

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.lambda_L1 = lambda_L1

        if net_G is None:
            self.net_G = init_model(
                Unet(input_c=1, output_c=2, n_down=8, num_filters=64), self.device
            )
        else:
            self.net_G = net_G.to(self.device)
        # The discriminator is fed L concatenated with ab along dim 1,
        # hence 1 + 2 = 3 input channels.
        self.net_D = init_model(
            PatchDiscriminator(input_c=3, n_down=3, num_filters=64), self.device
        )
        self.GANcriterion = GANLoss(gan_mode="vanilla").to(self.device)
        self.L1criterion = nn.L1Loss()
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))

    def set_requires_grad(self, model, requires_grad=True):
        """Set ``requires_grad`` on every parameter of ``model``."""
        for p in model.parameters():
            p.requires_grad = requires_grad

    def setup_input(self, data):
        """Store the batch's ``"L"`` and ``"ab"`` tensors on ``self.device``.

        Args:
            data: Mapping with at least the keys ``"L"`` and ``"ab"``.
        """
        self.L = data["L"].to(self.device)
        self.ab = data["ab"].to(self.device)

    def forward(self):
        """Run the generator on ``self.L`` and store the result in ``self.fake_color``."""
        self.fake_color = self.net_G(self.L)

    def backward_D(self):
        """Compute the discriminator loss and backpropagate it.

        The loss is the mean of the "fake" term (generated ab, labelled False)
        and the "real" term (ground-truth ab, labelled True). The generated
        image is detached so this backward pass does not reach the generator.
        """
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image.detach())
        self.loss_D_fake = self.GANcriterion(fake_preds, False)
        real_image = torch.cat([self.L, self.ab], dim=1)
        real_preds = self.net_D(real_image)
        self.loss_D_real = self.GANcriterion(real_preds, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Compute the generator loss and backpropagate it.

        The loss is the adversarial term (generated image labelled True, so G
        is rewarded for fooling D) plus ``lambda_L1`` times the L1 distance
        between ``fake_color`` and ``ab``.
        """
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image)
        self.loss_G_GAN = self.GANcriterion(fake_preds, True)
        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize(self):
        """Run one training step: update the discriminator, then the generator.

        ``setup_input`` must have been called with the current batch first.
        """
        # Switch to training mode *before* the generator's forward pass, so
        # fake_color is produced in training mode even if the networks were
        # left in eval mode. Calling .train() after forward() cannot affect an
        # output that has already been computed.
        self.net_G.train()
        self.net_D.train()
        self.forward()

        # Discriminator update.
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()

        # Generator update. D's parameters are frozen so backward_G does not
        # accumulate gradients into them.
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()
|
|