| | import gradio as gr |
| | import os |
| | import cv2 |
| | import shutil |
| | import sys |
| | from subprocess import call |
| | import torch |
| | import numpy as np |
| | from skimage import color |
| | import torchvision.transforms as transforms |
| | from PIL import Image |
| | import torch |
| |
|
# Environment bootstrap: both commands run through the shell, in this order,
# every time the module is executed and before any function below is defined.
for _setup_command in ("pip install dlib", 'bash setup.sh'):
    os.system(_setup_command)
| |
|
def lab2rgb(L, AB):
    """Convert a batch of normalised Lab tensors to one RGB numpy image.

    Only the first item of the batch is converted.

    Parameters:
        L (1-channel tensor array): L channel images (range: [-1, 1], torch tensor array)
        AB (2-channel tensor array): ab channel images (range: [-1, 1], torch tensor array)

    Returns:
        rgb (RGB numpy image): rgb output image scaled to the range [0, 255].
        The array is floating point; it is not rounded or cast here.
    """
    # Undo the [-1, 1] normalisation: L goes back to [0, 100], ab to [-110, 110].
    lightness = (L + 1.0) * 50.0
    chroma = AB * 110.0
    # Join along dim 1 (assumed to be the channel axis of a batch-first tensor).
    batch_lab = torch.cat([lightness, chroma], dim=1)
    first_lab = batch_lab[0].data.cpu().float().numpy()
    # Move the leading axis to the end so the channels come last for skimage.
    lab_channels_last = np.transpose(first_lab.astype(np.float64), (1, 2, 0))
    return color.lab2rgb(lab_channels_last) * 255
| |
|
def get_transform(model_name, params=None, grayscale=False, method=Image.BICUBIC):
    """Build the image preprocessing pipeline for a colorization model.

    Parameters:
        model_name (str): display name of the selected model. Only
            "Pix2Pix Unet 256" adds a resize step.
        params: unused; kept so existing callers remain valid.
        grayscale (bool): when true, reduce the image to a single channel first.
        method: interpolation mode handed to ``transforms.Resize``.

    Returns:
        transforms.Compose: the composed steps. The composition may be empty.
    """
    load_size = 256
    steps = []
    if grayscale:
        steps.append(transforms.Grayscale(1))
    # NOTE(review): the source indentation was lost; this check is assumed to be
    # independent of the ``grayscale`` flag rather than nested under it - confirm.
    if model_name == "Pix2Pix Unet 256":
        # Presumably this network needs a fixed 256x256 input - confirm.
        steps.append(transforms.Resize([load_size, load_size], method))
    return transforms.Compose(steps)
| |
|
def inferRestoration(img, model_name):
    """Restore an image with the Pix2Pix U-Net 256 restoration model.

    The model is fetched through ``torch.hub`` on every call.

    Parameters:
        img: 3-channel image accepted by ``transforms.ToTensor`` (the visible
            caller passes the numpy array received from the UI).
        model_name: unused; the restoration model is fixed. Kept for callers.

    Returns:
        PIL image built from the first item of the model output.
    """
    model = torch.hub.load('manhkhanhad/ImageRestorationInfer', 'pix2pixRestoration_unet256')
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize([256, 256], Image.BICUBIC),
        # Maps each channel from [0, 1] to [-1, 1].
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    batch = torch.unsqueeze(preprocess(img), 0)

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        restored = model(batch)[0].detach()

    # Map back from [-1, 1] to [0, 1]. Clamp because ToPILImage casts to bytes,
    # where a value slightly outside the range may wrap around instead of
    # saturating.
    restored = ((restored + 1) / 2.0).clamp(0.0, 1.0)
    return transforms.ToPILImage()(restored)
| |
|
def inferColorization(img, model_name):
    """Colorize a PIL image with the selected model.

    The model is fetched through ``torch.hub`` on every call.

    Parameters:
        img (PIL image): image to colorize. The Pix2Pix path passes it to
            ``color.rgb2lab``, so it is expected to be RGB there.
        model_name (str): one of "Pix2Pix Resnet 9block", "Pix2Pix Unet 256"
            or "Deoldify".

    Returns:
        PIL image: the colorized result.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names
            (previously this surfaced later as an unbound ``model`` variable).
    """
    hub_entries = {
        "Pix2Pix Resnet 9block": 'pix2pixColorization_resnet9b',
        "Pix2Pix Unet 256": 'pix2pixColorization_unet256',
        "Deoldify": 'DeOldifyColorization',
    }
    if model_name not in hub_entries:
        raise ValueError(
            "Unknown colorization model %r; expected one of %s"
            % (model_name, sorted(hub_entries))
        )
    model = torch.hub.load('manhkhanhad/ImageRestorationInfer', hub_entries[model_name])

    # NOTE(review): the source indentation was lost. The grayscale pipeline and
    # its early return are assumed to belong to the Deoldify branch only, which
    # leaves the Lab pipeline below reachable for the Pix2Pix models - confirm.
    if model_name == "Deoldify":
        preprocess = transforms.Compose([
            transforms.ToTensor(),
            # Maps the single channel from [0, 1] to [-1, 1].
            transforms.Normalize((0.5,), (0.5,)),
        ])
        gray_batch = torch.unsqueeze(preprocess(img.convert('L')), 0)
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            colorized = model(gray_batch)[0].detach()
        # Back to [0, 1]; clamp so the byte cast in ToPILImage cannot wrap.
        colorized = ((colorized + 1) / 2.0).clamp(0.0, 1.0)
        return transforms.ToPILImage()(colorized)

    # Pix2Pix models: feed the normalised L channel, receive the ab channels.
    prepared = get_transform(model_name)(img)
    lab = color.rgb2lab(np.array(prepared)).astype(np.float32)
    lab_t = transforms.ToTensor()(lab)
    # Channel 0 is L; dividing by 50 and subtracting 1 is the inverse of the
    # rescaling done in lab2rgb.
    lightness = torch.unsqueeze(lab_t[[0], ...] / 50.0 - 1.0, 0)
    with torch.no_grad():
        ab = model(lightness)
    rgb = lab2rgb(lightness, ab).astype(np.uint8)
    return Image.fromarray(rgb)
| | |
def colorizaition(image, model_name):
    """Colorize an image given as an array.

    Parameters:
        image: array accepted by ``Image.fromarray``.
        model_name (str): colorization model name forwarded to
            ``inferColorization``.

    Returns:
        Whatever ``inferColorization`` returns for the converted image.
    """
    pil_image = Image.fromarray(image)
    return inferColorization(pil_image, model_name)
| |
|
| |
|
def run_cmd(command):
    """Run a shell command and wait for it to finish.

    Parameters:
        command (str): command line interpreted by the shell (``shell=True``),
            so it must be built from trusted values only.

    Returns:
        int: the exit status of the command. Earlier versions discarded it and
        returned None; callers that ignore the result are unaffected.

    Raises:
        SystemExit: with status 1 when the wait is interrupted from the keyboard.
    """
    try:
        return call(command, shell=True)
    except KeyboardInterrupt:
        print("Process interrupted")
        sys.exit(1)
| |
|
def run(image, Restoration_mode, Colorizaition_mode):
    """Restore an image with the chosen method, then colorize the result.

    Parameters:
        image: image array received from the UI; assumed to be RGB, which is
            what the Gradio image input delivers - confirm for other callers.
        Restoration_mode (str): "BOPBTL" (runs ``run.py`` in a subprocess using
            the ``Temp`` working directory) or "Pix2Pix".
        Colorizaition_mode (str): model name forwarded to ``inferColorization``.

    Returns:
        The colorized image returned by ``inferColorization``.

    Raises:
        ValueError: if ``Restoration_mode`` is not a supported value.
        FileNotFoundError: if the BOPBTL run did not produce an output image.
    """
    if Restoration_mode == "BOPBTL":
        # Start from a clean working directory; makedirs creates both levels.
        if os.path.isdir("Temp"):
            shutil.rmtree("Temp")
        os.makedirs("Temp/input")
        try:
            # cv2.imwrite stores 3-channel data as BGR, so swap the channels of
            # an RGB array first; otherwise red and blue end up exchanged.
            to_write = image
            if getattr(image, "ndim", 0) == 3 and image.shape[2] == 3:
                to_write = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            cv2.imwrite("Temp/input/input_img.png", to_write)

            command = ("python run.py --input_folder "
                       + "Temp/input"
                       + " --output_folder "
                       + "Temp"
                       + " --GPU "
                       + "-1"
                       + " --with_scratch")
            run_cmd(command)

            output_path = "Temp/final_output/input_img.png"
            if not os.path.isfile(output_path):
                raise FileNotFoundError(
                    "BOPBTL restoration produced no output at " + output_path
                )
            # Image.open is lazy: copy the pixels into memory and close the
            # file before the directory that contains it is deleted.
            with Image.open(output_path) as restored_file:
                result_restoration = restored_file.copy()
        finally:
            # Remove the working directory even when restoration fails.
            shutil.rmtree("Temp", ignore_errors=True)
    elif Restoration_mode == "Pix2Pix":
        result_restoration = inferRestoration(image, Restoration_mode)
    else:
        raise ValueError(
            "Unknown restoration mode %r; expected 'BOPBTL' or 'Pix2Pix'"
            % (Restoration_mode,)
        )
    print("Restoration_mode", Restoration_mode)

    return inferColorization(result_restoration, Colorizaition_mode)
| | |
# Each row pre-fills the three inputs in order: image path, restoration mode,
# colorization mode.
examples = [
    ['example/1.jpeg', "BOPBTL", "Deoldify"],
    ['example/2.jpg', "BOPBTL", "Deoldify"],
    ['example/3.jpg', "BOPBTL", "Deoldify"],
    ['example/4.jpg', "BOPBTL", "Deoldify"],
]

# Build the UI around ``run`` and start serving immediately.
# NOTE: ``iface`` is bound to the value returned by ``launch()``, not to the
# Interface object itself.
iface = gr.Interface(
    fn=run,
    inputs=[
        gr.inputs.Image(),
        gr.inputs.Radio(["BOPBTL", "Pix2Pix"]),
        gr.inputs.Radio(["Deoldify", "Pix2Pix Resnet 9block", "Pix2Pix Unet 256"]),
    ],
    outputs="image",
    examples=examples,
).launch(debug=True, share=False)