Spaces:
Build error
Build error
| import gradio as gr | |
| import os | |
| import cv2 | |
| import shutil | |
| import sys | |
| from subprocess import call | |
| import torch | |
| import numpy as np | |
| from skimage import color | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import torch | |
| import dlib | |
| import uuid | |
# Random per-process identifier (uuid4).
uid = uuid.uuid4()

# Run the Space's setup script once, at import time. Setup is best-effort:
# a failure is reported on stderr instead of being silently ignored, but the
# app still continues to start.
_setup_status = os.system('bash setup.sh')
if _setup_status != 0:
    print(f"warning: 'bash setup.sh' exited with status {_setup_status}", file=sys.stderr)
def lab2rgb(L, AB):
    """Convert a normalized Lab tensor pair into an RGB numpy image.

    Only the first item of the batch is converted (the concatenated tensor is
    indexed with ``[0]`` and then transposed from CHW to HWC).

    Parameters:
        L (1-channel tensor): lightness, normalized to [-1, 1].
        AB (2-channel tensor): a/b chroma channels, normalized to [-1, 1].

    Returns:
        numpy.ndarray: float64 RGB image scaled to [0, 255] (not cast to uint8).
    """
    # Undo the [-1, 1] normalization: L -> [0, 100], a/b -> [-110, 110].
    lightness = (L + 1.0) * 50.0
    chroma = AB * 110.0
    first_item = torch.cat([lightness, chroma], dim=1)[0]
    lab_chw = first_item.data.cpu().float().numpy().astype(np.float64)
    lab_hwc = lab_chw.transpose(1, 2, 0)
    return color.lab2rgb(lab_hwc) * 255
def get_transform(model_name, params=None, grayscale=False, method=Image.BICUBIC, load_size=256):
    """Build a torchvision preprocessing pipeline for a colorization model.

    Args:
        model_name: Model selector string. Only ``"Pix2Pix Unet 256"`` adds a
            square resize step; every other name gets no resize.
        params: Unused; accepted for call compatibility.
        grayscale: If true, prepend a single-channel ``transforms.Grayscale``.
        method: Interpolation passed to ``transforms.Resize``.
        load_size: Side length of the square resize (default 256).

    Returns:
        transforms.Compose: the assembled (possibly empty) pipeline.
    """
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    if model_name == "Pix2Pix Unet 256":
        transform_list.append(transforms.Resize([load_size, load_size], method))
    return transforms.Compose(transform_list)
def inferRestoration(img, model_name):
    """Restore an image with the pix2pix U-Net-256 restoration model.

    Args:
        img: Image accepted by ``transforms.ToTensor`` (a PIL image or an
            HxWxC uint8 numpy array). It must have 3 channels, because the
            normalization below uses 3 means and stds.
        model_name: Currently ignored; the ``pix2pixRestoration_unet256``
            torch.hub entry point is always used.

    Returns:
        PIL.Image.Image: the restored image. The input is resized to 256x256;
        the output size is whatever the model returns (presumably 256x256).
    """
    model = torch.hub.load('manhkhanhad/ImageRestorationInfer', 'pix2pixRestoration_unet256')
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize([256, 256], Image.BICUBIC),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    batch = torch.unsqueeze(transform(img), 0)
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        output = model(batch)
    # Map [-1, 1] back to [0, 1]. Clamp so ToPILImage's float -> uint8
    # conversion (multiply by 255, cast to byte) cannot misbehave on
    # out-of-range values.
    restored = ((output[0].detach() + 1) / 2.0).clamp(0.0, 1.0)
    return transforms.ToPILImage()(restored)
def inferColorization(img, model_name):
    """Colorize a PIL image with the selected torch.hub colorization model.

    The image is converted to single-channel grayscale ('L'), normalized to
    [-1, 1], and fed to the model. The model output is mapped back to [0, 1].

    Args:
        img: PIL image (any mode; it is converted to 'L').
        model_name: One of "Pix2Pix Resnet 9block", "Pix2Pix Unet 256" or
            "Deoldify".

    Returns:
        PIL.Image.Image built from the model output. Its mode follows the
        output's channel count (presumably 3 -> RGB).

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    hub_entrypoints = {
        "Pix2Pix Resnet 9block": 'pix2pixColorization_resnet9b',
        "Pix2Pix Unet 256": 'pix2pixColorization_unet256',
        "Deoldify": 'DeOldifyColorization',
    }
    if model_name not in hub_entrypoints:
        raise ValueError(f"Unknown colorization model: {model_name!r}")
    model = torch.hub.load('manhkhanhad/ImageRestorationInfer', hub_entrypoints[model_name])
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    batch = torch.unsqueeze(transform(img.convert('L')), 0)
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        output = model(batch)
    # Map [-1, 1] back to [0, 1]. Clamp so ToPILImage's float -> uint8
    # conversion cannot misbehave on out-of-range model outputs.
    colorized = ((output[0].detach() + 1) / 2.0).clamp(0.0, 1.0)
    return transforms.ToPILImage()(colorized)
def colorizaition(image, model_name):
    """Wrap a numpy image as PIL and colorize it with ``model_name``.

    The misspelled function name is kept as-is because it is public.
    """
    return inferColorization(Image.fromarray(image), model_name)
def run_cmd(command):
    """Run ``command`` through the shell and return its exit status.

    A non-zero exit status is reported on stderr so that failures are not
    silently ignored. If the user presses Ctrl-C while the command runs, a
    message is printed and this process exits with status 1.

    Args:
        command: Shell command string. It is executed with ``shell=True``, so
            it must not contain untrusted input.

    Returns:
        int: the command's exit status (0 on success).
    """
    try:
        returncode = call(command, shell=True)
    except KeyboardInterrupt:
        print("Process interrupted")
        sys.exit(1)
    if returncode != 0:
        print(f"Command failed with exit status {returncode}: {command}", file=sys.stderr)
    return returncode
def _restore_bopbtl(image, max_side=600):
    """Restore ``image`` with the external ``run.py`` pipeline; return a PIL image.

    Each call works in its own scratch directory, so concurrent requests do
    not collide. The directory is always removed afterwards.
    """
    work_dir = f"Temp{uuid.uuid4()}"
    input_dir = f"{work_dir}/input"
    os.makedirs(input_dir)
    try:
        # Gradio supplies RGB(A) arrays, but cv2.imwrite expects BGR(A).
        if image.ndim == 3 and image.shape[2] == 3:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        elif image.ndim == 3 and image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGRA)

        # Downscale so the longest side is at most max_side, keeping the aspect
        # ratio. Square images are covered too; no side may shrink to 0.
        h, w = image.shape[:2]
        if max(h, w) > max_side:
            if h >= w:
                dsize = (max(1, int(w * max_side / float(h))), max_side)
            else:
                dsize = (max_side, max(1, int(h * max_side / float(w))))
            image = cv2.resize(image, dsize=dsize, interpolation=cv2.INTER_LANCZOS4)

        cv2.imwrite(f"{input_dir}/input_img.png", image)
        command = ("python run.py --input_folder "
                   + input_dir
                   + " --output_folder "
                   + work_dir
                   + " --GPU "
                   + "-1"
                   + " --with_scratch")
        run_cmd(command)

        output_path = f"{work_dir}/final_output/input_img.png"
        if not os.path.isfile(output_path):
            raise RuntimeError(f"BOPBTL restoration produced no output at {output_path}")
        # Image.open is lazy: copy the pixels into memory before the
        # directory is deleted in the finally-block.
        with Image.open(output_path) as restored:
            return restored.copy()
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)


def run(image, Restoration_mode, Colorizaition_mode):
    """Restore and then colorize an image from the Gradio UI.

    Args:
        image: HxWxC numpy image from ``gr.Image``.
        Restoration_mode: "BOPBTL" (external ``run.py`` pipeline) or "Pix2Pix".
        Colorizaition_mode: colorization model name passed to ``inferColorization``.

    Returns:
        PIL.Image.Image: the colorized result.

    Raises:
        ValueError: if no image is given or ``Restoration_mode`` is unknown.
        RuntimeError: if the BOPBTL pipeline produced no output file.
    """
    if image is None:
        raise ValueError("No input image provided")
    if Restoration_mode == "BOPBTL":
        result_restoration = _restore_bopbtl(image)
    elif Restoration_mode == "Pix2Pix":
        result_restoration = inferRestoration(image, Restoration_mode)
    else:
        raise ValueError(f"Unknown restoration mode: {Restoration_mode!r}")
    print("Restoration_mode", Restoration_mode)
    return inferColorization(result_restoration, Colorizaition_mode)
# Gradio UI: input image, restoration/colorization model selectors, a button
# that triggers `run`, and the output image.
with gr.Blocks() as app:
    with gr.Row():
        gr.Column()  # empty column, presumably a spacer to center the main column
        with gr.Column():
            im = gr.Image(label="Input Image")
            rad1 = gr.Radio(["BOPBTL", "Pix2Pix"], value="BOPBTL", label="Restoration model")
            rad2 = gr.Radio(["Deoldify", "Pix2Pix Resnet 9block", "Pix2Pix Unet 256"],
                            value="Deoldify", label="Colorization model")
            # The button's text is its first (`value`) argument; `label` does
            # not set it.
            im_btn = gr.Button("Restore")
            out_im = gr.Image(label="Restored Image")
        gr.Column()
    im_btn.click(run, [im, rad1, rad2], out_im)
app.queue(concurrency_count=100).launch(show_api=False)