| import gradio as gr |
|
|
| import torch |
|
|
| import skimage |
| import skimage.io |
| from skimage.transform import rescale, resize |
| from skimage import io, color |
|
|
| import cv2 |
| from colorizer import normalize_lab_channels, torch_normalized_lab_to_rgb |
|
|
# Load the fully pickled colorizer network and switch it to inference mode.
# map_location='cpu': colorize_img builds its input tensors on the CPU (via
# torch.from_numpy, never moved to a device), so the model must live there too.
# This also lets a checkpoint saved on a GPU load on a CPU-only host.
# NOTE(security): torch.load unpickles arbitrary objects - only load trusted
# checkpoints. On PyTorch >= 2.6 the default weights_only=True rejects a pickled
# nn.Module; pass weights_only=False explicitly there.
model = torch.load('colorizer.pth', map_location='cpu')
model = model.eval()
|
|
def colorize_img(img_float32, model, res=512, border=0.2, apply_blur=True):
    """Colorize an image with *model* and return the result as an RGB array.

    The picture is framed with a white border, resized to ``res`` x ``res`` and
    converted to normalized LAB. Its L channel, replicated into all three input
    channels, is fed to the model. The model's prediction replaces channels 1:
    (a/b) of the resized image's LAB tensor, which is then converted back to RGB.

    Args:
        img_float32: Image array with values presumably in [0, 1] (the border
            fill of 1.0 is white). Accepted layouts: HxW or HxWx1 grayscale,
            HxWx3 RGB, HxWx4 RGBA.
        model: Callable taking a (1, 3, res, res) tensor. Its output must be
            assignable to the (1, 2, res, res) a/b slice.
        res: Side length of the square working resolution.
        border: Padding added on every side, as a fraction of the image height.
        apply_blur: If True, Gaussian-blur (sigma=1) the model's input copy.
            The L channel kept in the output is never blurred.

    Returns:
        Numpy array from ``torch_normalized_lab_to_rgb``, squeezed and permuted
        to channels-last. Presumably (res, res, 3) with the white border
        included. TODO confirm the helper's output shape.
    """
    # Explicit import: `import skimage` alone does not guarantee that the
    # `skimage.filters` submodule is loaded.
    from skimage.filters import gaussian

    img = img_float32
    # rgb2lab (and gaussian with channel_axis=-1) need exactly three channels.
    if img.ndim == 2:
        img = color.gray2rgb(img)
    elif img.ndim == 3 and img.shape[-1] == 1:
        img = color.gray2rgb(img[..., 0])
    elif img.ndim == 3 and img.shape[-1] == 4:
        img = color.rgba2rgb(img)

    # White frame around the picture. `pad` is in pixels; `border` is a fraction.
    pad = int(img.shape[0] * border)
    padded = cv2.copyMakeBorder(img, pad, pad, pad, pad,
                                cv2.BORDER_CONSTANT, value=(1., 1., 1.))
    img_resized = resize(padded, (res, res), anti_aliasing=True)

    # The model sees an (optionally) blurred copy. The output keeps the sharp L.
    if apply_blur:
        model_rgb = gaussian(img_resized, sigma=1, channel_axis=-1)
    else:
        model_rgb = img_resized

    # Normalized LAB of the sharp image, as NCHW; its a/b get overwritten below.
    lab_full = torch.from_numpy(normalize_lab_channels(color.rgb2lab(img_resized)))
    lab_full = lab_full.permute(2, 0, 1).unsqueeze(dim=0)

    # Model input: the L channel of the (blurred) copy, replicated into 3 channels.
    lab_in = torch.from_numpy(normalize_lab_channels(color.rgb2lab(model_rgb)))
    lab_in = lab_in.permute(2, 0, 1).unsqueeze(dim=0)
    x = lab_in[:, :1, :, :].repeat(1, 3, 1, 1)

    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        x_hat_ab = model(x)

    x_hat = lab_full.clone()
    x_hat[:, 1:, :, :] = x_hat_ab

    colored_img = torch_normalized_lab_to_rgb(x_hat)
    return colored_img.detach().cpu().squeeze().permute(1, 2, 0).numpy()
|
|
def process_image(img):
    """Gradio callback: colorize an uploaded image with the module-level model.

    Args:
        img: Image array from the Gradio input, presumably uint8 in [0, 255].
            ``None`` when nothing has been uploaded.

    Returns:
        The array produced by ``colorize_img`` for ``img / 255`` as float32,
        or ``None`` when *img* is ``None``.
    """
    if img is None:
        # Input cleared / nothing uploaded: show an empty output instead of
        # failing on `None / 255`.
        return None
    return colorize_img((img / 255).astype('float32'), model)
|
|
# --- Gradio UI ---------------------------------------------------------------
image = gr.inputs.Image()
title = "Colorizer"
description = "A model that colorizes b&w images."
enable_queue = True

# Sample images listed under the interface (paths relative to the working dir).
examples = ['ka0001.jpg', 'ka0003.jpg', 'ka0009.jpg', 'ka0010.jpg']
# Minimum panel height; presumably chosen to match colorize_img's default res=512.
css = ".h-60 {min-height: 512px !important;}"

# Image in -> colorized image out.
demo = gr.Interface(fn=process_image,
                    inputs=image,
                    outputs=gr.outputs.Image(),
                    title=title,
                    description=description,
                    css=css,
                    examples=examples)

# Start the server only when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch(debug=True, enable_queue=enable_queue)