| import os |
| import glob |
| import time |
| import numpy as np |
| from PIL import Image |
| from pathlib import Path |
| from tqdm.notebook import tqdm |
| import matplotlib.pyplot as plt |
| from skimage.color import rgb2lab, lab2rgb |
|
|
| import torch |
| from torch import nn, optim |
| from torchvision import transforms |
| from torchvision.utils import make_grid |
| from torch.utils.data import Dataset, DataLoader |
|
|
# Use the default CUDA device when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
| import requests |
| import gdown |
|
|
# Edge length (pixels) of the square image every input is resized to.
SIZE: int = 256
|
|
|
|
def download_from_drive(url, output):
    """Download a file from Google Drive using ``gdown``.

    Best-effort: failures are reported on stdout and signalled through the
    return value instead of propagating.

    Args:
        url: Google Drive URL handed to ``gdown.download``.
        output: destination path handed to ``gdown.download``.

    Returns:
        True when ``gdown.download`` completed and returned a result, False
        when it raised an exception or returned None.
    """
    try:
        result = gdown.download(url, output, quiet=False)
    except Exception as err:
        # Catch Exception rather than using a bare ``except:`` so Ctrl-C
        # (KeyboardInterrupt) and SystemExit are not silently swallowed.
        print(f"Error Occured in Downloading model from Gdrive: {err}")
        return False
    if result is None:
        # NOTE(review): gdown may report some failures (e.g. permission
        # problems) by returning None instead of raising.
        print("Error Occured in Downloading model from Gdrive")
        return False
    return True
|
|
|
|
class AverageMeter:
    """Running, count-weighted average of a scalar value such as a loss."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the accumulated count, sum and average."""
        self.count = 0.0
        self.avg = 0.0
        self.sum = 0.0

    def update(self, val, count=1):
        """Accumulate ``val`` with weight ``count`` and refresh ``avg``.

        Args:
            val: value to add, e.g. the mean loss of one batch.
            count: weight of ``val``, e.g. the batch size.
        """
        self.count += count
        self.sum += count * val
        # Guard against division by zero when the total weight is still zero
        # (e.g. update(val, count=0) on a fresh meter).
        self.avg = self.sum / self.count if self.count else 0.0
|
|
|
|
def create_loss_meters():
    """Build one fresh ``AverageMeter`` per tracked GAN loss.

    Returns:
        dict mapping each loss attribute name (discriminator fake/real/total,
        generator GAN/L1/total) to a newly created ``AverageMeter``.
    """
    loss_names = (
        "loss_D_fake",
        "loss_D_real",
        "loss_D",
        "loss_G_GAN",
        "loss_G_L1",
        "loss_G",
    )
    return {name: AverageMeter() for name in loss_names}
|
|
|
|
def update_losses(model, loss_meter_dict, count):
    """Push the model's current loss values into their matching meters.

    For every key of ``loss_meter_dict``, the attribute of the same name is
    read from ``model``, converted with ``.item()``, and passed to that
    meter's ``update`` with weight ``count``.
    """
    for name in loss_meter_dict:
        value = getattr(model, name).item()
        loss_meter_dict[name].update(value, count=count)
|
|
|
|
def lab_to_rgb(L, ab):
    """Convert a batch of normalized L / ab tensors back to RGB images.

    Undoes the normalization applied by ``create_lab_tensors``. L is mapped
    back with ``(L + 1) * 50`` and ab with ``ab * 110``. Each image is then
    converted with ``skimage.color.lab2rgb``.

    Args:
        L: batched tensor in channels-first layout, presumably (N, 1, H, W);
            concatenated with ``ab`` along dim 1.
        ab: batched tensor, presumably (N, 2, H, W).

    Returns:
        numpy array of shape (N, H, W, C) stacking the per-image results of
        ``lab2rgb``. An empty batch yields an empty (0, H, W, 3) array.
    """
    L = (L + 1.0) * 50.0
    ab = ab * 110.0
    # detach() lets tensors that still track gradients be converted to NumPy;
    # permute to channels-last, which is the layout lab2rgb works on.
    lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).detach().cpu().numpy()
    if lab.shape[0] == 0:
        # np.stack cannot stack an empty list, so build the empty batch directly.
        return np.empty((0, lab.shape[1], lab.shape[2], 3))
    return np.stack([lab2rgb(img) for img in lab], axis=0)
|
|
|
|
def visualize(model, data, save=True, n_images=5):
    """Plot grayscale inputs, predicted colorizations and ground-truth colors.

    The generator runs in eval mode without gradient tracking. Its previous
    train/eval mode is restored afterwards, even if the forward pass raises.
    The figure is a 3-row grid:
      - row 1: the L channel (gray);
      - row 2: the predicted colors;
      - row 3: the real colors.

    Args:
        model: object exposing ``net_G``, ``setup_input``, ``forward`` and the
            attributes ``L``, ``ab`` and ``fake_color`` read below.
        data: batch handed to ``model.setup_input``.
        save: when True the figure is also written to
            ``colorization_<unix time>.png``.
        n_images: maximum number of batch items to plot (default 5). It is
            capped at the batch size, so small batches do not raise IndexError.
    """
    was_training = model.net_G.training
    model.net_G.eval()
    try:
        with torch.no_grad():
            model.setup_input(data)
            model.forward()
    finally:
        model.net_G.train(was_training)

    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)

    n_cols = min(n_images, L.shape[0])
    fig = plt.figure(figsize=(15, 8))
    for i in range(n_cols):
        ax = plt.subplot(3, n_cols, i + 1)
        ax.imshow(L[i][0].cpu(), cmap="gray")
        ax.axis("off")
        ax = plt.subplot(3, n_cols, i + 1 + n_cols)
        ax.imshow(fake_imgs[i])
        ax.axis("off")
        ax = plt.subplot(3, n_cols, i + 1 + 2 * n_cols)
        ax.imshow(real_imgs[i])
        ax.axis("off")
    # Save before show(): some backends close or clear the figure once show()
    # returns, which would otherwise produce an empty image file.
    if save:
        fig.savefig(f"colorization_{time.time()}.png")
    plt.show()
|
|
|
|
def log_results(loss_meter_dict):
    """Print each meter's running average as ``name: value`` with 5 decimals."""
    for name, meter in loss_meter_dict.items():
        line = "{}: {:.5f}".format(name, meter.avg)
        print(line)
|
|
|
|
def create_lab_tensors(image, *, flip=False):
    """Turn one image into the normalized L / ab tensors the model consumes.

    The image is converted to RGB, resized to ``SIZE x SIZE`` (bicubic), and
    converted to CIE L*a*b* with ``rgb2lab``. L is then mapped from [0, 100]
    to [-1, 1] via ``L / 50 - 1``, and ab is divided by 110.

    Args:
        image: a path (``str`` or ``os.PathLike``), a NumPy array accepted by
            ``PIL.Image.fromarray``, or an object with a PIL-style
            ``convert`` method (e.g. a PIL image).
        flip: when True, apply a random horizontal flip (training-style
            augmentation). Defaults to False so inference is deterministic and
            never mirrors the input.

    Returns:
        dict with two tensors:
          - ``"L"``: shape (1, 1, SIZE, SIZE), with a batch dimension added;
          - ``"ab"``: shape (2, SIZE, SIZE). It has no batch dimension; this
            shape is kept unchanged for existing callers.
    """
    if isinstance(image, (str, os.PathLike)):
        # Context manager closes the file handle promptly; convert() returns a
        # new, fully loaded image that no longer needs the file.
        with Image.open(image) as src:
            img = src.convert("RGB")
    elif isinstance(image, np.ndarray):
        img = Image.fromarray(image).convert("RGB")
    else:
        img = image.convert("RGB")

    steps = [
        transforms.Resize(
            (SIZE, SIZE), interpolation=transforms.InterpolationMode.BICUBIC
        )
    ]
    if flip:
        steps.append(transforms.RandomHorizontalFlip())
    img = transforms.Compose(steps)(img)

    img_lab = rgb2lab(np.array(img)).astype("float32")
    # ToTensor only rearranges HWC -> CHW here; float input is not rescaled.
    img_lab = transforms.ToTensor()(img_lab)
    L = img_lab[[0], ...] / 50.0 - 1.0
    L = L.unsqueeze(0)
    ab = img_lab[[1, 2], ...] / 110.0
    return {"L": L, "ab": ab}
|
|
|
|
def predict_and_visualize_single_image(model, data, save=True):
    """Colorize the first image of ``data`` and plot it next to its L channel.

    The generator runs in eval mode without gradient tracking. Its previous
    train/eval mode is restored afterwards, so the call leaves no lasting
    side effect on the model.

    Args:
        model: object exposing ``net_G``, ``setup_input``, ``forward`` and the
            attributes ``L`` and ``fake_color`` read below.
        data: batch handed to ``model.setup_input``. Only item 0 is plotted.
        save: when True the figure is also written to
            ``colorization_<unix time>.png``.
    """
    was_training = model.net_G.training
    model.net_G.eval()
    try:
        with torch.no_grad():
            model.setup_input(data)
            model.forward()
    finally:
        model.net_G.train(was_training)

    fake_color = model.fake_color.detach()
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)

    fig, axs = plt.subplots(1, 2, figsize=(8, 4))
    axs[0].imshow(L[0][0].cpu(), cmap="gray")
    axs[0].set_title("Grey Image")
    axs[0].axis("off")

    axs[1].imshow(fake_imgs[0])
    axs[1].set_title("Colored Image")
    axs[1].axis("off")
    # Save before show(): some backends close or clear the figure once show()
    # returns, which would otherwise produce an empty image file.
    if save:
        fig.savefig(f"colorization_{time.time()}.png")
    plt.show()
|
|
|
|
def predict_color(model, image, save=False):
    """Colorize a single image with ``model`` and display the result.

    The image is first converted to L / ab tensors with ``create_lab_tensors``,
    then passed to ``predict_and_visualize_single_image``.

    Args:
        model: PyTorch grayscale-to-color model.
        image: path to an image file, or an in-memory image.
        save: forwarded to ``predict_and_visualize_single_image``; when True
            the figure is also written to disk.
    """
    lab_batch = create_lab_tensors(image)
    predict_and_visualize_single_image(model, lab_batch, save)
|
|
|
|
def load_model_with_cpu(model_class, file_path):
    """Build ``model_class()`` and fill it with the state dict at ``file_path``.

    Every tensor in the checkpoint is mapped onto the CPU while loading. This
    lets a checkpoint saved from a GPU be read on a machine without CUDA.

    Args:
        model_class: zero-argument callable (typically an ``nn.Module``
            subclass) producing the model to populate.
        file_path: path of the saved state dict.

    Returns:
        The constructed model with the loaded weights.
    """
    net = model_class()
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints (consider weights_only=True where the torch version allows).
    weights = torch.load(file_path, map_location="cpu")
    net.load_state_dict(weights)
    return net
|
|
|
|
def load_model_with_gpu(model_class, file_path):
    """Build ``model_class()`` and fill it with the state dict at ``file_path``.

    ``torch.load`` is called without ``map_location``, so each tensor is
    restored onto the device it was saved from (e.g. a CUDA device). This
    function does not move the model itself.

    Args:
        model_class: zero-argument callable (typically an ``nn.Module``
            subclass) producing the model to populate.
        file_path: path of the saved state dict.

    Returns:
        The constructed model with the loaded weights.
    """
    net = model_class()
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints (consider weights_only=True where the torch version allows).
    weights = torch.load(file_path)
    net.load_state_dict(weights)
    return net
|
|