imageColorization / inference.py
shlok123's picture
Updated Inference.py (#1)
c010c34 verified
import os
import glob
import time
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb
import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
def init_model(model, device):
    """Move ``model`` to ``device``, run ``init_weights`` on it, and return it.

    Args:
        model: Module to place on ``device`` and re-initialize.
        device: Target passed straight to ``model.to``.

    Returns:
        The value returned by ``init_weights`` for the moved model.
    """
    return init_weights(model.to(device))
def init_weights(net, init='norm', gain=0.02):
    """Re-initialize the convolution and BatchNorm2d layers of ``net`` in place.

    Every sub-module visited by ``net.apply`` is inspected by class name:

    * classes whose name contains ``'Conv'`` and that carry a weight get the
      weight re-drawn according to ``init`` and any bias set to zero;
    * classes whose name contains ``'BatchNorm2d'`` get their weight drawn
      from a normal distribution with mean 1 and std ``gain`` and their bias
      set to zero. Layers without affine parameters (weight/bias is ``None``)
      are left untouched instead of raising.

    Args:
        net: Module whose sub-modules are re-initialized.
        init: One of ``'norm'``, ``'xavier'`` or ``'kaiming'``.
        gain: Std for ``'norm'`` and for BatchNorm2d weights; gain for
            ``'xavier'``. Not used by ``'kaiming'``.

    Returns:
        The same ``net`` object.

    Raises:
        ValueError: If ``init`` is not one of the supported names.
    """
    supported = ('norm', 'xavier', 'kaiming')
    if init not in supported:
        # Fail loudly: otherwise no conv weight is touched while the message
        # below still reports that the requested initialization ran.
        raise ValueError(
            f"unsupported init {init!r}; expected one of {supported}"
        )

    def init_func(m):
        classname = m.__class__.__name__
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if weight is not None and 'Conv' in classname:
            if init == 'norm':
                nn.init.normal_(weight.data, mean=0.0, std=gain)
            elif init == 'xavier':
                nn.init.xavier_normal_(weight.data, gain=gain)
            else:  # 'kaiming' -- validated above
                nn.init.kaiming_normal_(weight.data, a=0, mode='fan_in')
            if bias is not None:
                nn.init.constant_(bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # With affine=False both parameters are None; skip them.
            if weight is not None:
                nn.init.normal_(weight.data, 1., gain)
            if bias is not None:
                nn.init.constant_(bias.data, 0.)

    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net
from fastai.vision.learner import create_body
from torchvision.models.resnet import resnet18
from fastai.vision.models.unet import DynamicUnet
# Use the CUDA device when torch reports one is available, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def lab_to_rgb(L, ab):
    """
    Convert a batch of L and ab tensors into a stacked array of RGB images.

    ``L`` is shifted by 1 and scaled by 50, ``ab`` is scaled by 110; the two
    are concatenated along dim 1, reordered with ``permute(0, 2, 3, 1)``
    (presumably channels-first to channels-last -- confirm with callers),
    and each item of the batch is passed through ``lab2rgb``.

    Returns the per-image results stacked along a new leading axis.
    """
    lightness = (L + 1.) * 50.
    chroma = ab * 110.
    lab_batch = torch.cat([lightness, chroma], dim=1)
    lab_batch = lab_batch.permute(0, 2, 3, 1).cpu().numpy()
    return np.stack([lab2rgb(lab_img) for lab_img in lab_batch], axis=0)
def build_res_unet(n_input=1, n_output=2, size=256):
    """Build a ``DynamicUnet`` on top of a ``resnet18`` body.

    Args:
        n_input: Forwarded to ``create_body`` as ``n_in``.
        n_output: Second positional argument of ``DynamicUnet``.
        size: Side length; ``DynamicUnet`` receives ``(size, size)``.

    Returns:
        The network, moved to CUDA when available and to the CPU otherwise.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder = create_body(resnet18(), pretrained=True, n_in=n_input, cut=-2)
    unet = DynamicUnet(encoder, n_output, (size, size))
    return unet.to(target)
# Module-level generator, built once when this file is imported/executed.
# load_model() hands this same object to the model class it constructs.
net_G = build_res_unet(n_input=1, n_output=2, size=256)
class GANLoss(nn.Module):
    """Adversarial loss that builds its own target tensor from a boolean flag.

    The real/fake label values are stored as buffers and expanded to the
    shape of the predictions on every call, so callers only say whether the
    predictions should be treated as real or fake.

    Args:
        gan_mode: ``'vanilla'`` uses ``nn.BCEWithLogitsLoss``; ``'lsgan'``
            uses ``nn.MSELoss``.
        real_label: Target value used when ``target_is_real`` is true.
        fake_label: Target value used when ``target_is_real`` is false.

    Raises:
        ValueError: If ``gan_mode`` is not one of the supported modes.
    """

    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        else:
            # Fail at construction rather than with an AttributeError on the
            # first loss computation.
            raise ValueError(
                f"unsupported gan_mode {gan_mode!r}; "
                "expected 'vanilla' or 'lsgan'"
            )

    def get_labels(self, preds, target_is_real):
        """Return the real or fake label buffer expanded to ``preds``' shape."""
        labels = self.real_label if target_is_real else self.fake_label
        return labels.expand_as(preds)

    def forward(self, preds, target_is_real):
        """Compute the loss of ``preds`` against the selected label.

        Defined as ``forward`` (not ``__call__``) so that invoking the module
        goes through ``nn.Module.__call__`` and registered hooks still run;
        ``loss_fn(preds, flag)`` keeps working exactly as before.
        """
        labels = self.get_labels(preds, target_is_real)
        return self.loss(preds, labels)
def load_model(model_class, file_path, resnet_weights_path="./model/res18-unet.pt"):
    """Build ``model_class`` around the module-level ``net_G`` and load weights.

    The checkpoint at ``file_path`` is loaded first as the complete model
    state. Entries from the checkpoint at ``resnet_weights_path`` whose keys
    also exist in the model's state dict are then written over it.

    Args:
        model_class: Callable invoked as ``model_class(net_G=net_G)``; the
            result must support ``load_state_dict`` / ``state_dict``.
        file_path: Path of the full-model checkpoint.
        resnet_weights_path: Path of the second checkpoint. It may be a bare
            state dict or a mapping with a ``'state_dict'`` entry.

    Returns:
        The constructed model with both sets of weights applied.
    """
    # NOTE(review): torch.load unpickles the file contents -- only point this
    # at checkpoints from a trusted source.
    model = model_class(net_G=net_G)
    model.load_state_dict(torch.load(file_path, map_location=device))

    resnet_weights = torch.load(resnet_weights_path, map_location=device)
    resnet_state_dict = (
        resnet_weights['state_dict']
        if 'state_dict' in resnet_weights
        else resnet_weights
    )

    # Only keys the model already has are taken over.
    # NOTE(review): if the second checkpoint stores bare generator keys
    # (without the prefix the model uses for its generator), nothing will
    # match and this step is a no-op -- confirm against the checkpoint.
    model_dict = model.state_dict()
    model_dict.update(
        {k: v for k, v in resnet_state_dict.items() if k in model_dict}
    )
    model.load_state_dict(model_dict)
    return model
def predict_color(model, image):
    """Colorize ``image`` with ``model`` and return the generated picture.

    ``image`` must offer ``resize`` (presumably a PIL image -- confirm with
    callers). It is resized to 256x256, converted with
    ``transforms.ToTensor``, and only the first channel is kept.

    Returns whatever ``predict_and_return_image`` produces for that input.
    """
    resized = image.resize((256, 256))
    to_tensor = transforms.ToTensor()
    first_channel = to_tensor(resized)[:1]
    # Rescale so the values lie between -1 and 1.
    scaled = first_channel * 2. - 1.
    return predict_and_return_image(model, scaled)
def predict_and_return_image(model, img):
    """Run ``model.net_G`` on one image and return the colorized result.

    The model is switched to eval mode, ``img`` gets a leading batch
    dimension, and the generator runs on ``device`` with gradients disabled.
    The input and the prediction (moved back to the CPU) are combined by
    ``lab_to_rgb``; the first item of that batch is returned.
    """
    model.eval()
    batch = img.unsqueeze(0)
    with torch.no_grad():
        predicted_ab = model.net_G(batch.to(device))
    rgb_batch = lab_to_rgb(batch, predicted_ab.cpu())
    return rgb_batch[0]