Spaces:
Sleeping
Sleeping
File size: 4,211 Bytes
8eec341 c010c34 8eec341 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import os
import glob
import time
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb
import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
def init_model(model, device, init='norm', gain=0.02):
    """Move *model* to *device*, then initialise its weights with ``init_weights``.

    Args:
        model: The ``nn.Module`` to prepare.
        device: Target device passed to ``model.to``.
        init: Initialisation scheme forwarded to ``init_weights``
            (default ``'norm'``, the previously fixed behaviour).
        gain: Scale forwarded to ``init_weights`` (default ``0.02``, as before).

    Returns:
        The module returned by ``init_weights`` (the moved, initialised model).
    """
    model = model.to(device)
    return init_weights(model, init=init, gain=gain)
def init_weights(net, init='norm', gain=0.02):
    """Initialise conv and BatchNorm2d layers of *net* in place.

    Every submodule visited by ``net.apply`` is handled as follows:

    * class name contains ``'Conv'`` and it has a non-None ``weight``: the
      weight is drawn according to *init* and the bias (if any) is zeroed;
    * class name contains ``'BatchNorm2d'``: weight ~ N(1, gain) and bias = 0,
      skipping either one when it is None (``affine=False``).

    Args:
        net: The ``nn.Module`` to initialise.
        init: ``'norm'`` (N(0, gain)), ``'xavier'`` (Xavier-normal with *gain*)
            or ``'kaiming'`` (Kaiming-normal, ``fan_in``; *gain* unused).
        gain: Std-dev for ``'norm'`` and BatchNorm weights; gain for ``'xavier'``.

    Returns:
        The same *net* object.

    Raises:
        ValueError: If *init* is not one of the supported schemes.
    """
    if init not in ('norm', 'xavier', 'kaiming'):
        # Fail before touching any layer: an unknown scheme used to leave conv
        # weights untouched while still reporting a successful initialisation.
        raise ValueError(
            f"unsupported init method {init!r}; expected 'norm', 'xavier' or 'kaiming'"
        )

    def init_func(m):
        classname = m.__class__.__name__
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if 'Conv' in classname and weight is not None:
            if init == 'norm':
                nn.init.normal_(weight, mean=0.0, std=gain)
            elif init == 'xavier':
                nn.init.xavier_normal_(weight, gain=gain)
            else:  # 'kaiming' (validated above)
                nn.init.kaiming_normal_(weight, a=0, mode='fan_in')
            if bias is not None:
                nn.init.constant_(bias, 0.0)
        elif 'BatchNorm2d' in classname:
            # BatchNorm2d(affine=False) has weight/bias set to None.
            if weight is not None:
                nn.init.normal_(weight, 1.0, gain)
            if bias is not None:
                nn.init.constant_(bias, 0.0)

    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net
from fastai.vision.learner import create_body
from torchvision.models.resnet import resnet18
from fastai.vision.models.unet import DynamicUnet
# Module-wide compute device: the GPU when CUDA is visible, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def lab_to_rgb(L, ab):
    """Convert a batch of normalised Lab tensors into RGB NumPy images.

    Args:
        L: Lightness tensor, presumably ``(N, 1, H, W)`` scaled to ``[-1, 1]``;
            it is mapped back with ``(L + 1) * 50``.
        ab: Colour tensor, presumably ``(N, 2, H, W)`` scaled to ``[-1, 1]``;
            it is mapped back with ``ab * 110``.

    Returns:
        ``np.ndarray`` stacking, along axis 0, one ``skimage.color.lab2rgb``
        result per image (channels-last, i.e. ``(N, H, W, C)``).
    """
    # detach(): .numpy() refuses tensors that require grad.
    # cpu() on both inputs first so torch.cat works even if they live on
    # different devices (e.g. L on CPU, predictions on CUDA).
    L = (L.detach().cpu() + 1.) * 50.
    ab = ab.detach().cpu() * 110.
    lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).numpy()
    return np.stack([lab2rgb(img) for img in lab], axis=0)
def build_res_unet(n_input=1, n_output=2, size=256, device=None):
    """Build a fastai ``DynamicUnet`` generator on a ResNet-18 encoder body.

    Args:
        n_input: Number of input channels of the encoder body.
        n_output: Number of output channels of the U-Net.
        size: Square spatial size ``(size, size)`` given to ``DynamicUnet``.
        device: Device to move the network to. ``None`` (default) selects CUDA
            when available, otherwise CPU, which matches the previous behaviour.

    Returns:
        The ``DynamicUnet`` module, moved to *device*.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): resnet18() is constructed without torchvision weights; whether
    # ``pretrained=True`` here actually loads ImageNet weights depends on the
    # installed fastai ``create_body`` version. Confirm if it matters.
    # ``cut=-2`` presumably strips the pooling/classifier head. Verify.
    body = create_body(resnet18(), pretrained=True, n_in=n_input, cut=-2)
    return DynamicUnet(body, n_output, (size, size)).to(device)
# Shared generator built once at import time (1 input channel -> 2 output
# channels, 256x256); load_model passes this instance to the model class.
net_G = build_res_unet(size=256, n_input=1, n_output=2)
class GANLoss(nn.Module):
    """Adversarial loss comparing discriminator outputs to constant targets.

    Args:
        gan_mode: ``'vanilla'`` uses ``nn.BCEWithLogitsLoss`` (expects raw
            logits); ``'lsgan'`` uses ``nn.MSELoss``.
        real_label: Target value used when ``target_is_real`` is true.
        fake_label: Target value used when ``target_is_real`` is false.

    Raises:
        ValueError: If *gan_mode* is not ``'vanilla'`` or ``'lsgan'``.
    """

    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        # Registered as buffers so they follow the module on .to(device).
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        else:
            # Previously self.loss was left undefined, failing only at call time.
            raise ValueError(
                f"unsupported gan_mode {gan_mode!r}; expected 'vanilla' or 'lsgan'"
            )
        self.gan_mode = gan_mode

    def get_labels(self, preds, target_is_real):
        """Return the real/fake target broadcast to the shape of *preds*."""
        labels = self.real_label if target_is_real else self.fake_label
        return labels.expand_as(preds)

    def forward(self, preds, target_is_real):
        """Compute the loss of *preds* against the real or fake target.

        Implemented as ``forward`` (not ``__call__``) so ``nn.Module`` hooks run;
        calling the instance directly works exactly as before.
        """
        labels = self.get_labels(preds, target_is_real)
        return self.loss(preds, labels)
def load_model(model_class, file_path, resnet_path="./model/res18-unet.pt"):
    """Instantiate *model_class* around the shared ``net_G`` and load weights.

    The checkpoint at *file_path* is loaded strictly first. Then every entry of
    the checkpoint at *resnet_path* whose key already exists in the model's
    ``state_dict`` is overlaid on top.

    Args:
        model_class: Callable accepting a ``net_G=`` keyword (the module-level
            generator instance is passed).
        file_path: Path to the full model ``state_dict``.
        resnet_path: Path to a second checkpoint, either a raw ``state_dict`` or
            a dict wrapping one under ``'state_dict'``. The default is the path
            that was previously hard-coded.

    Returns:
        The instantiated model with both checkpoints applied.
    """
    # NOTE(review): torch.load unpickles arbitrary objects. Only load trusted
    # checkpoint files (recent PyTorch supports weights_only=True).
    model = model_class(net_G=net_G)
    model.load_state_dict(torch.load(file_path, map_location=device))

    resnet_weights = torch.load(resnet_path, map_location=device)
    resnet_state_dict = (
        resnet_weights['state_dict'] if 'state_dict' in resnet_weights else resnet_weights
    )
    model_dict = model.state_dict()
    # Only keys the model already has are overlaid. NOTE(review): if the keys in
    # resnet_path lack the model's prefix (e.g. 'net_G.'), nothing is overlaid.
    # Confirm.
    model_dict.update({k: v for k, v in resnet_state_dict.items() if k in model_dict})
    model.load_state_dict(model_dict)
    return model
def predict_color(model, image):
    """Colourise a PIL image with *model*.

    The image is converted to RGB and resized to 256x256. Its CIELAB lightness
    (L*, from ``rgb2lab``) is then scaled with ``L / 50 - 1``, which is the
    inverse of the ``(L + 1) * 50`` mapping ``lab_to_rgb`` applies when
    decoding.

    Args:
        model: Object with ``eval()`` and a ``net_G`` generator, as used by
            ``predict_and_return_image``.
        image: A ``PIL.Image.Image`` in any mode ``convert("RGB")`` supports.

    Returns:
        The colourised image returned by ``predict_and_return_image``.
    """
    # Previously the first ToTensor channel was used. For RGB input that is the
    # red channel (and palette indices for 'P' images), not lightness.
    rgb = np.array(image.convert("RGB").resize((256, 256)))
    lab = rgb2lab(rgb).astype("float32")
    # ToTensor leaves float arrays unscaled and returns CHW; keep only L: (1, H, W).
    L = transforms.ToTensor()(lab)[:1] / 50. - 1.
    return predict_and_return_image(model, L)
def predict_and_return_image(model, img):
    """Run ``model.net_G`` on one normalised L channel and return the RGB result.

    Args:
        model: Object with ``eval()`` and a ``net_G`` attribute. It is switched
            to eval mode and left there.
        img: L-channel tensor, presumably ``(1, H, W)`` in ``[-1, 1]``. TODO confirm.

    Returns:
        The first (only) image of ``lab_to_rgb`` for this batch, presumably an
        ``(H, W, 3)`` ``np.ndarray``.
    """
    model.eval()
    with torch.no_grad():
        batch = img.unsqueeze(0)  # add the batch dimension once
        preds = model.net_G(batch.to(device))
        # Decode on the CPU so L and the predicted ab share a device.
        colorized = lab_to_rgb(batch.cpu(), preds.cpu())[0]
    return colorized