Spaces:
Sleeping
Sleeping
| import os | |
| import glob | |
| import time | |
| import numpy as np | |
| from PIL import Image | |
| from pathlib import Path | |
| from tqdm.notebook import tqdm | |
| import matplotlib.pyplot as plt | |
| from skimage.color import rgb2lab, lab2rgb | |
| import torch | |
| from torch import nn, optim | |
| from torchvision import transforms | |
| from torchvision.utils import make_grid | |
| from torch.utils.data import Dataset, DataLoader | |
def init_model(model, device):
    """Place *model* on *device*, then re-initialize it with ``init_weights``.

    Uses the default ``init_weights`` settings and returns whatever that call
    returns (the module itself).
    """
    return init_weights(model.to(device))
def init_weights(net, init='norm', gain=0.02):
    """Re-initialize the convolution and BatchNorm2d parameters of *net* in place.

    Every submodule is visited through ``net.apply``:

    * modules whose class name contains ``'Conv'`` have their weight drawn
      according to *init* and their bias (when present) set to 0;
    * modules whose class name contains ``'BatchNorm2d'`` have their weight
      drawn from N(1, gain) and their bias set to 0. Layers built with
      ``affine=False`` have no weight/bias and are skipped.

    Args:
        net: the ``nn.Module`` to initialize.
        init: ``'norm'`` (normal, mean 0, std ``gain``), ``'xavier'``
            (Xavier-normal with ``gain``) or ``'kaiming'`` (Kaiming-normal,
            ``a=0``, ``mode='fan_in'``; ``gain`` is not used for this scheme).
        gain: std / gain used as described above.

    Returns:
        The same ``net`` object.

    Raises:
        ValueError: if *init* is not one of the supported scheme names.
    """
    if init not in ('norm', 'xavier', 'kaiming'):
        # Fail loudly instead of leaving conv weights untouched while still
        # reporting a successful initialization.
        raise ValueError(
            f"unknown init scheme {init!r}; expected 'norm', 'xavier' or 'kaiming'"
        )

    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and 'Conv' in classname:
            if init == 'norm':
                nn.init.normal_(m.weight.data, mean=0.0, std=gain)
            elif init == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            else:  # 'kaiming'
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            if getattr(m, 'bias', None) is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # affine=False batch-norm layers set weight/bias to None.
            if m.weight is not None:
                nn.init.normal_(m.weight.data, 1., gain)
            if m.bias is not None:
                nn.init.constant_(m.bias.data, 0.)

    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net
| from fastai.vision.learner import create_body | |
| from torchvision.models.resnet import resnet18 | |
| from fastai.vision.models.unet import DynamicUnet | |
# Module-wide compute device: CUDA when PyTorch reports it available, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def lab_to_rgb(L, ab):
    """Convert a batch of normalized Lab tensors into an array of RGB images.

    Args:
        L: lightness tensor normalized to [-1, 1]; rescaled here to [0, 100]
            via ``(L + 1) * 50``.
        ab: chroma tensor normalized so that ``ab * 110`` gives Lab a/b values.
            ``L`` and ``ab`` are concatenated on dim 1 and permuted with
            ``(0, 2, 3, 1)``, so both are assumed to be (batch, channels, H, W)
            with 1 + 2 = 3 channels in total (required by ``lab2rgb``) —
            confirm at call sites.

    Returns:
        ``np.ndarray`` of shape (batch, H, W, 3): each image converted with
        ``skimage.color.lab2rgb`` and stacked along axis 0.
    """
    # Detach and move to the CPU *before* concatenating: the two inputs may be
    # on different devices or require grad, and ``.numpy()`` needs a plain
    # CPU tensor.
    L = (L.detach().cpu() + 1.) * 50.
    ab = ab.detach().cpu() * 110.
    lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).numpy()
    return np.stack([lab2rgb(img) for img in lab], axis=0)
def build_res_unet(n_input=1, n_output=2, size=256):
    """Create a fastai ``DynamicUnet`` generator on a torchvision ResNet-18 encoder.

    Args:
        n_input: number of input channels passed to ``create_body`` as ``n_in``.
        n_output: number of output channels of the U-Net.
        size: side length of the square input size given to ``DynamicUnet``.

    Returns:
        The U-Net, moved to CUDA when available and to the CPU otherwise.

    NOTE(review): ``resnet18()`` is constructed without torchvision weights;
    what ``pretrained=True`` changes inside fastai's ``create_body`` depends on
    the installed fastai version — confirm.
    """
    target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    encoder = create_body(resnet18(), pretrained=True, n_in=n_input, cut=-2)
    unet = DynamicUnet(encoder, n_output, (size, size))
    return unet.to(target)
# Build the ResNet-18 U-Net generator once, at import time (1 input channel,
# 2 output channels, 256x256); load_model passes this same instance to every
# model class it constructs.
net_G = build_res_unet(size=256, n_output=2, n_input=1)
class GANLoss(nn.Module):
    """Adversarial loss that compares predictions against a constant target label.

    Args:
        gan_mode: ``'vanilla'`` uses ``nn.BCEWithLogitsLoss`` (predictions are
            raw logits); ``'lsgan'`` uses ``nn.MSELoss``.
        real_label: target value used when ``target_is_real`` is true.
        fake_label: target value used when ``target_is_real`` is false.

    Raises:
        ValueError: if ``gan_mode`` is not ``'vanilla'`` or ``'lsgan'``.
    """

    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        # Registered as buffers so they follow the module through ``.to()``
        # and are part of its state_dict.
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        else:
            # Previously left ``self.loss`` undefined, failing only at call time.
            raise ValueError(
                f"unknown gan_mode {gan_mode!r}; expected 'vanilla' or 'lsgan'"
            )

    def get_labels(self, preds, target_is_real):
        """Return a target tensor with the shape, dtype and device of ``preds``."""
        labels = self.real_label if target_is_real else self.fake_label
        # Convert the scalar first (cheap), then broadcast without copying, so
        # the loss never sees a dtype or device mismatch (e.g. int labels,
        # float64 preds, or preds on another device than the buffers).
        return labels.to(preds).expand_as(preds)

    def forward(self, preds, target_is_real):
        """Compute the loss of ``preds`` against the real or fake label.

        Implemented as ``forward`` (not ``__call__``) so that calling the
        instance goes through ``nn.Module.__call__`` and its hooks.
        """
        labels = self.get_labels(preds, target_is_real)
        return self.loss(preds, labels)
def load_model(model_class, file_path, pretrained_path="./model/res18-unet.pt"):
    """Build ``model_class`` around the module-level ``net_G`` and load weights into it.

    Steps:
      1. ``model_class(net_G=net_G)`` wraps the shared module-level generator.
      2. The state dict at *file_path* is loaded with ``load_state_dict``
         (strict by default, so every key must match).
      3. If *pretrained_path* is not None, the entries of that checkpoint whose
         keys also exist in the model's state dict are copied over the result
         of step 2; other keys are ignored.

    Args:
        model_class: callable accepting a ``net_G`` keyword argument and
            returning an ``nn.Module``.
        file_path: path to a state dict for ``model_class``.
        pretrained_path: optional second checkpoint (defaults to the path that
            was previously hard-coded). It may be a raw state dict or a dict
            holding one under ``'state_dict'``. Pass None to skip step 3.

    Returns:
        The model with the loaded weights.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model = model_class(net_G=net_G)
    model.load_state_dict(torch.load(file_path, map_location=device))

    if pretrained_path is not None:
        extra = torch.load(pretrained_path, map_location=device)
        if 'state_dict' in extra:
            extra = extra['state_dict']
        # NOTE(review): overlapping keys overwrite the weights just loaded
        # from file_path — confirm this precedence is intended.
        model_dict = model.state_dict()
        model_dict.update({k: v for k, v in extra.items() if k in model_dict})
        model.load_state_dict(model_dict)
    return model
def predict_color(model, image):
    """Colorize a PIL image with *model* and return the RGB result.

    The image is converted to RGB and resized to 256x256. Its Lab lightness
    (0..100 from ``rgb2lab``) is then normalized to [-1, 1] with ``L / 50 - 1``.
    This is the inverse of the ``(L + 1) * 50`` rescaling that ``lab_to_rgb``
    applies to this same tensor when rebuilding the colorized image.

    Args:
        model: object with a ``net_G`` generator, as consumed by
            ``predict_and_return_image``.
        image: a ``PIL.Image.Image`` in any mode PIL can convert to RGB.

    Returns:
        The RGB array produced by ``predict_and_return_image``.
    """
    img = image.convert("RGB").resize((256, 256))
    # Use the true Lab lightness channel. ``ToTensor()(img)[:1]`` would pick
    # the *red* channel of an RGB image, which is not what lab_to_rgb decodes.
    lightness = rgb2lab(np.asarray(img))[..., 0].astype(np.float32)
    tensor = torch.from_numpy(lightness / 50. - 1.).unsqueeze(0)  # (1, 256, 256)
    return predict_and_return_image(model, tensor)
def predict_and_return_image(model, img):
    """Run ``model.net_G`` on one normalized lightness tensor and return RGB.

    Args:
        model: object with ``eval()`` and a ``net_G`` generator; it is left in
            eval mode afterwards.
        img: lightness tensor for a single image without a batch axis (one is
            added here). It may live on any device.

    Returns:
        The first (only) image of the batch returned by ``lab_to_rgb``.
    """
    model.eval()
    batch = img.unsqueeze(0)
    with torch.no_grad():
        preds = model.net_G(batch.to(device))
        # Move both halves to the CPU so lab_to_rgb can concatenate them even
        # when the caller passed a tensor that already lives on the GPU.
        return lab_to_rgb(batch.cpu(), preds.cpu())[0]