Spaces:
Build error
Build error
| import time | |
| import numpy as np | |
| from skimage.color import rgb2lab, lab2rgb | |
| import matplotlib.pyplot as plt | |
| from fastai.vision.learner import create_body | |
| from fastai.vision.models.unet import DynamicUnet | |
| from torchvision.models import resnet18 | |
| from torchvision.models import mobilenet_v2 | |
| import torch | |
class AverageMeter:
    """Track a weighted running sum, total weight (count) and mean of a value."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set count, avg and sum back to 0.0."""
        self.count, self.avg, self.sum = [0.] * 3

    def update(self, val, count=1):
        """Accumulate ``val`` with weight ``count`` and refresh the mean.

        Args:
            val: value to record; it contributes ``count * val`` to the sum.
            count: weight of ``val`` (defaults to 1).
        """
        self.count += count
        self.sum += count * val
        # Avoid 0/0 (ZeroDivisionError) when only zero-weight updates have
        # been seen, e.g. an empty batch; avg is left unchanged in that case.
        if self.count:
            self.avg = self.sum / self.count
def build_res_unet(n_input=1, n_output=2, size=256):
    """Build a DynamicUnet generator on top of a pretrained ResNet-18 encoder.

    Args:
        n_input: number of input channels, forwarded to create_body as n_in.
        n_output: number of output channels of the U-Net.
        size: spatial size; the U-Net is built for (size, size) images.

    Returns:
        The DynamicUnet, moved to CUDA when available, otherwise to CPU.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    backbone = resnet18(pretrained=True)
    # cut=-2 is forwarded to fastai's create_body to trim the backbone's tail
    # (presumably the pooling/classifier head for ResNet-18).
    encoder = create_body(backbone, n_in=n_input, cut=-2)
    unet = DynamicUnet(encoder, n_output, (size, size))
    return unet.to(target)
def build_mobilenet_unet(n_input=1, n_output=2, size=256):
    """Build a DynamicUnet generator on the feature stack of a pretrained MobileNetV2.

    Args:
        n_input: number of input channels, forwarded to create_body as n_in.
        n_output: number of output channels of the U-Net.
        size: spatial size; the U-Net is built for (size, size) images.

    Returns:
        The DynamicUnet, moved to CUDA when available, otherwise to CPU.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Only the convolutional `features` stack is used as the encoder source;
    # cut=-2 is forwarded to fastai's create_body to drop its last entries.
    features = mobilenet_v2(pretrained=True).features
    encoder = create_body(features, pretrained=True, n_in=n_input, cut=-2)
    return DynamicUnet(encoder, n_output, (size, size)).to(target)
def create_loss_meters():
    """Return a new dict mapping each GAN loss name to a fresh AverageMeter.

    The keys are also the attribute names that update_losses reads from the
    model via getattr.
    """
    loss_names = (
        'loss_D_fake',
        'loss_D_real',
        'loss_D',
        'loss_G_GAN',
        'loss_G_L1',
        'loss_G',
    )
    return {name: AverageMeter() for name in loss_names}
def update_losses(model, loss_meter_dict, count):
    """Record the model's current loss values into the matching meters.

    For every key in ``loss_meter_dict``, the same-named attribute is read from
    ``model``, converted with ``.item()`` and passed to that meter's
    ``update`` with weight ``count``.
    """
    for name, meter in loss_meter_dict.items():
        current = getattr(model, name).item()
        meter.update(current, count=count)
def lab_to_rgb(L, ab):
    """Convert a batch of normalized L and ab tensors to an RGB numpy array.

    The normalization is undone first: ``L`` is mapped with ``(L + 1) * 50``
    and ``ab`` with ``ab * 110``, so the inputs are presumably scaled to about
    [-1, 1]. The two tensors are concatenated along dim 1 (channels) and moved
    to channels-last before each image is converted with skimage's lab2rgb.

    Args:
        L: lightness tensor, expected (batch, 1, H, W) — TODO confirm.
        ab: colour tensor, expected (batch, 2, H, W) — TODO confirm.

    Returns:
        ndarray of shape (batch, H, W, channels) holding the stacked lab2rgb
        outputs. For an empty batch, an empty float64 array of shape
        (0, H, W, 3) is returned.
    """
    L = (L + 1.) * 50.
    ab = ab * 110.
    # detach() so tensors that still require grad can be converted to numpy.
    lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).detach().cpu().numpy()
    if lab.shape[0] == 0:
        # np.stack raises on an empty sequence; return an empty batch instead.
        return np.empty((0, *lab.shape[1:3], 3), dtype=np.float64)
    return np.stack([lab2rgb(img) for img in lab], axis=0)
def visualize(model, data, save=True):
    """Plot grayscale input, predicted colours and ground truth for a batch.

    Runs ``model.net_G`` in eval mode without gradients on ``data``. It then
    draws a 3-row grid with up to five columns:
    - row 1: the L channel (gray);
    - row 2: the prediction built from ``model.fake_color``;
    - row 3: the ground truth built from ``model.ab``.

    Args:
        model: object exposing ``net_G``, ``setup_input``, ``forward`` and,
            after forward, the attributes ``L``, ``ab`` and ``fake_color``.
        data: batch passed to ``model.setup_input``.
        save: when True, also write the figure to
            ``colorization_<time.time()>.png`` in the working directory.
    """
    was_training = model.net_G.training
    model.net_G.eval()
    try:
        with torch.no_grad():
            model.setup_input(data)
            model.forward()
    finally:
        # Restore the generator's previous mode, even if forward() raised.
        model.net_G.train(was_training)
    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)
    n_cols = 5
    fig = plt.figure(figsize=(15, 8))
    # Clamp to the batch size: batches smaller than n_cols would otherwise
    # raise IndexError.
    for i in range(min(n_cols, len(L))):
        ax = plt.subplot(3, n_cols, i + 1)
        ax.imshow(L[i][0].cpu(), cmap='gray')
        ax.axis("off")
        ax = plt.subplot(3, n_cols, i + 1 + n_cols)
        ax.imshow(fake_imgs[i])
        ax.axis("off")
        ax = plt.subplot(3, n_cols, i + 1 + 2 * n_cols)
        ax.imshow(real_imgs[i])
        ax.axis("off")
    # Save before show(): show() may block on GUI backends, and some backends
    # dispose of the figure once it has been displayed.
    if save:
        fig.savefig(f"colorization_{time.time()}.png")
    plt.show()
def log_results(loss_meter_dict):
    """Print each meter's running average as ``<name>: <avg>`` (5 decimals)."""
    lines = (
        f"{name}: {meter.avg:.5f}"
        for name, meter in loss_meter_dict.items()
    )
    for line in lines:
        print(line)