| import torch |
| import torch.nn as nn |
| from torchvision import transforms |
| from PIL import Image |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import torchvision.models as models |
| from skimage.color import lab2rgb |
| import os |
|
|
class ColorizationNet(nn.Module):
    """Predict two colour channels from a single-channel (grayscale) image.

    The network is an encoder/decoder pair:

    * ``midlevel_resnet`` -- the first six child modules of a torchvision
      ResNet-18 whose first convolution has been collapsed to accept one
      input channel instead of three.
    * ``upsample`` -- a stack of 3x3 convolutions interleaved with three
      2x ``nn.Upsample`` stages, ending in a 2-channel output.

    Note: ``input_size`` is accepted by the constructor but is not used
    anywhere in the body.
    """

    def __init__(self, input_size=128):
        super().__init__()
        # Channel count the decoder expects from the truncated encoder.
        midlevel_channels = 128

        # --- Encoder -------------------------------------------------------
        backbone = models.resnet18(num_classes=365)
        # Sum the first conv's kernels over their input-channel axis so the
        # layer consumes a 1-channel image rather than a 3-channel one.
        single_channel_weight = backbone.conv1.weight.sum(dim=1).unsqueeze(1)
        backbone.conv1.weight = nn.Parameter(single_channel_weight)
        # Keep only the first six children as the mid-level feature extractor.
        encoder_layers = list(backbone.children())[0:6]
        self.midlevel_resnet = nn.Sequential(*encoder_layers)

        # --- Decoder -------------------------------------------------------
        # Layer order matters: nn.Sequential indexes modules positionally and
        # those indices are part of the state_dict keys.
        decoder_layers = [
            nn.Conv2d(midlevel_channels, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=2),
        ]
        self.upsample = nn.Sequential(*decoder_layers)

    def forward(self, input):
        """Run the encoder, then the decoder, and return the 2-channel map."""
        return self.upsample(self.midlevel_resnet(input))
|
|
|
|
|
|
def to_rgb(grayscale_input, ab_input, save_path, save_name):
    """Combine an L channel with predicted ab channels and save both images.

    Args:
        grayscale_input: 3-D tensor ``(C, H, W)`` holding the lightness
            channel. It is multiplied by 100 before the Lab->RGB conversion,
            so it is presumably normalised to ``[0, 1]`` -- confirm against
            the caller.
        ab_input: 3-D tensor with the two colour channels. It is resized
            bilinearly to ``(H, W)`` and mapped with ``x * 255 - 128``.
        save_path: Mapping with the keys ``'grayscale'`` and ``'colorized'``
            giving the output directories, or ``None`` to disable saving.
        save_name: File name used inside both directories, or ``None`` to
            disable saving.

    Returns:
        None. The only effect is the two image files written with
        ``plt.imsave`` when both ``save_path`` and ``save_name`` are given.
    """
    # Unpacking doubles as a rank check: a non-3-D input still fails loudly.
    _, height, width = grayscale_input.shape

    # Nothing is returned and nothing is saved in this case, so skip the
    # (comparatively expensive) interpolation and colour-space conversion.
    if save_path is None or save_name is None:
        return

    # .numpy() rejects tensors that require grad or live on an accelerator.
    grayscale_input = grayscale_input.detach().cpu()
    ab_input = ab_input.detach().cpu()

    # Bring the predicted colour channels to the grayscale resolution.
    ab_resized = torch.nn.functional.interpolate(
        ab_input.unsqueeze(0),
        size=(height, width),
        mode='bilinear',
        align_corners=False,
    ).squeeze(0)

    # Stack channels, then move them last: (C, H, W) -> (H, W, C).
    lab_image = torch.cat((grayscale_input, ab_resized), 0).numpy()
    lab_image = lab_image.transpose((1, 2, 0))
    lab_image[:, :, 0:1] = lab_image[:, :, 0:1] * 100
    lab_image[:, :, 1:3] = lab_image[:, :, 1:3] * 255 - 128
    rgb_image = lab2rgb(lab_image.astype(np.float64))

    gray_image = grayscale_input.squeeze().numpy()

    # os.path.join works whether or not the directory ends with a separator;
    # plain string concatenation silently produced a wrong path without one.
    gray_file = os.path.join(save_path['grayscale'], save_name)
    color_file = os.path.join(save_path['colorized'], save_name)
    plt.imsave(arr=gray_image, fname=gray_file, cmap='gray')
    plt.imsave(arr=rgb_image, fname=color_file)
|
|
|
|
def colorize_single_image(image_path, model, criterion, save_dir, epoch, use_gpu=True):
    """Colorize one image file with ``model`` and save the result.

    Args:
        image_path: Path of the image to load; it is converted to a
            single-channel ("L") image before being fed to the model.
        model: Network called on the ``(1, 1, H, W)`` grayscale tensor. It is
            switched to eval mode and, when CUDA is used, moved to the GPU.
        criterion: Unused; kept so existing callers do not break.
        save_dir: Root output directory. The ``gray/`` and ``color/``
            sub-directories are created beneath it.
        epoch: Value embedded in the output file name.
        use_gpu: Run on CUDA when it is also available.

    Returns:
        None. Images are written through ``to_rgb`` and a message is printed.
    """
    model.eval()

    # Close the file handle deterministically instead of waiting for garbage
    # collection; convert() returns an independent in-memory image.
    with Image.open(image_path) as source_image:
        gray_image = source_image.convert("L")
    input_gray = transforms.ToTensor()(gray_image).unsqueeze(0)

    if use_gpu and torch.cuda.is_available():
        input_gray = input_gray.cuda()
        model = model.cuda()

    with torch.no_grad():
        output_ab = model(input_gray)

    # The trailing separators are kept deliberately: the directory strings
    # are handed to to_rgb, which may concatenate them with the file name.
    save_paths = {
        'grayscale': os.path.join(save_dir, 'gray/'),
        'colorized': os.path.join(save_dir, 'color/'),
    }
    # makedirs creates missing parents, so save_dir itself is covered too.
    for directory in save_paths.values():
        os.makedirs(directory, exist_ok=True)

    save_name = f'colorized-epoch-{epoch}.jpg'
    to_rgb(
        input_gray[0].cpu(),
        ab_input=output_ab[0].detach().cpu(),
        save_path=save_paths,
        save_name=save_name,
    )

    print(f'Colorized image saved in {save_paths["colorized"]}')
|
|
| |
|
|
def run_example(image_path, save_dir, model_path='colorization_md1.pth'):
    """Load a ``ColorizationNet`` checkpoint and colorize a single image.

    Args:
        image_path: Path of the image to colorize.
        save_dir: Directory that receives the output images.
        model_path: Checkpoint file holding the model's ``state_dict``.
            Defaults to the previously hard-coded ``'colorization_md1.pth'``.

    Returns:
        None.
    """
    use_gpu = torch.cuda.is_available()

    model = ColorizationNet()
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source (newer PyTorch releases offer weights_only=True).
    # map_location='cpu' keeps every tensor on the CPU regardless of the
    # device the checkpoint was saved from.
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()

    # Only passed through; colorize_single_image does not use it.
    criterion = nn.MSELoss()

    # No outer torch.no_grad() needed: colorize_single_image already wraps
    # its forward pass in one.
    colorize_single_image(image_path, model, criterion, save_dir, epoch=0, use_gpu=use_gpu)
|
|
if __name__ == "__main__":
    # Script entry point: colorize the bundled example image and write the
    # outputs under ./results.
    image_path, save_dir = 'example_image.jpg', 'results'
    run_example(image_path, save_dir)
|
|