"""Image colorization utilities.

A ResNet18-backed DynamicUnet generator predicts the a*b* channels of a
Lab-encoded image from its L (lightness) channel; helpers convert the
result back to RGB and run single-image inference.
"""

import os
import glob
import time
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from skimage.color import rgb2lab, lab2rgb
from tqdm.notebook import tqdm

import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.models.resnet import resnet18
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader

from fastai.vision.learner import create_body
from fastai.vision.models.unet import DynamicUnet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def init_model(model, device):
    """Move *model* to *device* and apply default weight initialization."""
    model = model.to(device)
    model = init_weights(model)
    return model


def init_weights(net, init='norm', gain=0.02):
    """Initialize conv and batch-norm layers of *net* in place.

    Parameters
    ----------
    net : nn.Module
        Network whose submodules are initialized via ``net.apply``.
    init : str
        One of 'norm' (gaussian), 'xavier' or 'kaiming'.
    gain : float
        Std for 'norm', gain for 'xavier', and std of BatchNorm weights.

    Returns
    -------
    nn.Module
        The same network, initialized.
    """
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and 'Conv' in classname:
            if init == 'norm':
                nn.init.normal_(m.weight.data, mean=0.0, std=gain)
            elif init == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init == 'kaiming':
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # BN weights centered at 1 so the layer starts near identity.
            nn.init.normal_(m.weight.data, 1., gain)
            nn.init.constant_(m.bias.data, 0.)

    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net


def lab_to_rgb(L, ab):
    """Convert a batch of (L, ab) tensors back to RGB images.

    L is expected in [-1, 1] (rescaled here to [0, 100]) and ab in
    [-1, 1] (rescaled to [-110, 110]) — the inverse of the scaling used
    at preprocessing time.

    Returns a float array of shape (B, H, W, 3).
    """
    L = (L + 1.) * 50.
    ab = ab * 110.
    Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    return np.stack([lab2rgb(img) for img in Lab], axis=0)


def build_res_unet(n_input=1, n_output=2, size=256):
    """Build the generator: a DynamicUnet over a pretrained ResNet18 body.

    n_input is the number of input channels (1 = L channel), n_output
    the number of predicted channels (2 = a*b*), size the square input
    resolution the U-Net is constructed for.
    """
    body = create_body(resnet18(), pretrained=True, n_in=n_input, cut=-2)
    # Uses the module-level `device` instead of recomputing it locally.
    net_G = DynamicUnet(body, n_output, (size, size)).to(device)
    return net_G


# NOTE(review): built at import time — other functions (load_model)
# depend on this module-level generator instance.
net_G = build_res_unet(n_input=1, n_output=2, size=256)


class GANLoss(nn.Module):
    """GAN objective wrapping BCE-with-logits ('vanilla') or MSE ('lsgan').

    Real/fake target tensors are registered as buffers and expanded to
    the discriminator's prediction shape on each call.
    """

    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()

    def get_labels(self, preds, target_is_real):
        """Return a label tensor shaped like *preds*."""
        if target_is_real:
            labels = self.real_label
        else:
            labels = self.fake_label
        return labels.expand_as(preds)

    def __call__(self, preds, target_is_real):
        """Compute the loss of *preds* against real/fake labels."""
        labels = self.get_labels(preds, target_is_real)
        return self.loss(preds, labels)


def load_model(model_class, file_path):
    """Instantiate *model_class* around the global net_G and load weights.

    First loads the full checkpoint at *file_path*, then overlays the
    pretrained ResNet18-U-Net generator weights (keys that match the
    model's state dict) on top of it.
    """
    model = model_class(net_G=net_G)
    model.load_state_dict(torch.load(file_path, map_location=device))
    # NOTE(review): hardcoded checkpoint path — consider parameterizing.
    resnet_weights = torch.load("./model/res18-unet.pt", map_location=device)
    # Checkpoints may store weights under a 'state_dict' key or directly.
    resnet_state_dict = (resnet_weights['state_dict']
                         if 'state_dict' in resnet_weights else resnet_weights)
    model_dict = model.state_dict()
    # Keep only the keys that exist in the target model before overlaying.
    filtered_resnet_state_dict = {k: v for k, v in resnet_state_dict.items()
                                  if k in model_dict}
    model_dict.update(filtered_resnet_state_dict)
    model.load_state_dict(model_dict)
    return model


def predict_color(model, image):
    """Colorize a PIL image and return the RGB result as a numpy array.

    The image is resized to 256x256; only the first channel is kept
    (the L channel of a grayscale input) and rescaled to [-1, 1].
    """
    img = image.resize((256, 256))
    # ToTensor gives [0, 1]; map to [-1, 1] as the generator expects.
    img = transforms.ToTensor()(img)[:1] * 2. - 1.
    return predict_and_return_image(model, img)


def predict_and_return_image(model, img):
    """Run the generator on a single L-channel tensor (C=1, H, W).

    Returns an (H, W, 3) RGB float array.
    """
    model.eval()
    with torch.no_grad():
        preds = model.net_G(img.unsqueeze(0).to(device))
    colorized = lab_to_rgb(img.unsqueeze(0), preds.cpu())[0]
    return colorized