import functools
import os

import cv2
import gradio as gr
import numpy as np
import PIL

from src.deoldify import device
from src.deoldify.device_id import DeviceId
from src.deoldify.visualize import *
from src.app_utils import get_model_bin
|
|
# Force DeOldify to run on the CPU. This executes once at import time,
# before load_model() builds any colorizer.
device.set(device=DeviceId.CPU)
|
|
# Downloadable DeOldify generator weights, keyed by the lower-cased model option.
# Each entry is (download URL, weights filename, `artistic` flag for get_image_colorizer).
_MODEL_SPECS = {
    "artistic": (
        "https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth",
        "ColorizeArtistic_gen.pth",
        True,
    ),
    "stable": (
        # dl=1 makes Dropbox serve the raw file; dl=0 returns an HTML preview page.
        "https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=1",
        "ColorizeStable_gen.pth",
        False,
    ),
}


@functools.lru_cache(maxsize=len(_MODEL_SPECS))
def _build_colorizer(model_dir, option_key):
    """Fetch the weights for a validated option key and build its colorizer (cached)."""
    model_url, weights_name, artistic = _MODEL_SPECS[option_key]
    get_model_bin(model_url, os.path.join(model_dir, weights_name))
    return get_image_colorizer(artistic=artistic)


def load_model(model_dir, option):
    """Return a DeOldify image colorizer for the requested model variant.

    The generator weights are fetched with ``get_model_bin`` into ``model_dir``
    and the colorizer is built with ``get_image_colorizer``. The result is
    cached per ``(model_dir, option)``, so repeated requests reuse the loaded
    model instead of rebuilding it on every call.

    Args:
        model_dir: Directory the weights file is downloaded into.
        option: Model variant, case-insensitive: ``'artistic'`` or ``'stable'``.

    Returns:
        The colorizer returned by ``get_image_colorizer``.

    Raises:
        ValueError: If ``option`` is not a known model variant.
    """
    option_key = option.lower()
    if option_key not in _MODEL_SPECS:
        # Previously fell through and failed with UnboundLocalError on `colorizer`.
        raise ValueError(
            f"Unknown model option {option!r}; expected one of: "
            f"{', '.join(sorted(_MODEL_SPECS))}"
        )
    # NOTE(review): the cached colorizer is shared across requests; this assumes it
    # is safe to reuse (Gradio's queue serializes calls by default) -- confirm.
    # get_image_colorizer() is called without a root folder, so it presumably loads
    # weights from its own default location; keep model_dir consistent with that.
    return _build_colorizer(model_dir, option_key)
|
|
def resize_img(input_img, max_size):
    """Downscale an image so its longer side is at most ``max_size`` pixels.

    The aspect ratio is preserved. Images already within the limit are returned
    as an unmodified copy, so the result never aliases ``input_img``.

    Args:
        input_img: Image array indexed as ``(height, width[, channels])``.
        max_size: Maximum allowed length, in pixels, of the longer side.

    Returns:
        The resized image, or a copy of the input when no resize is needed.
    """
    img_height, img_width = input_img.shape[:2]
    if max(img_height, img_width) <= max_size:
        return input_img.copy()

    if img_height > img_width:
        new_height = max_size
        new_width = img_width * (max_size / img_height)
    else:
        # Landscape or square: the width is the side that must shrink to max_size.
        # (The previous code swapped width and height here, distorting the image.)
        new_width = max_size
        new_height = img_height * (max_size / img_width)

    # cv2.resize takes dsize as (width, height); clamp so very thin images keep
    # at least one pixel instead of producing an invalid 0-sized target.
    target_size = (max(1, int(new_width)), max(1, int(new_height)))
    return cv2.resize(input_img, target_size)
|
|
def colorize_image(input_image, colorizer, img_size=800, render_factor=35):
    """Colorize a PIL image with a DeOldify colorizer.

    Args:
        input_image: Source PIL image; it is converted to RGB first, so
            grayscale, palette and RGBA inputs are accepted.
        colorizer: Object exposing ``plot_transformed_pil_image`` (the app
            passes the result of ``load_model``).
        img_size: Longest side, in pixels, the image is downscaled to before
            inference (see ``resize_img``).
        render_factor: Forwarded to ``plot_transformed_pil_image``; the default
            of 35 matches the previously hard-coded value.

    Returns:
        Whatever ``colorizer.plot_transformed_pil_image`` returns for the
        resized image.
    """
    # Import the submodule explicitly: a bare ``import PIL`` does not guarantee
    # that ``PIL.Image`` has been loaded.
    from PIL import Image

    img_rgb = np.array(input_image.convert("RGB"))
    resized_pil_img = Image.fromarray(resize_img(img_rgb, img_size))
    return colorizer.plot_transformed_pil_image(
        resized_pil_img, render_factor=render_factor, compare=False
    )
|
|
def app(input_image, model='Artistic'):
    """Gradio callback: colorize the uploaded image with the chosen DeOldify model.

    Args:
        input_image: Uploaded PIL image, or ``None`` when the user submits
            without providing one.
        model: Model option forwarded to ``load_model`` (``'Artistic'`` or
            ``'Stable'``, case-insensitive).

    Returns:
        The colorized image produced by ``colorize_image``.

    Raises:
        gr.Error: If no image was provided; Gradio shows the message in the UI.
    """
    if input_image is None:
        # Without this guard the call crashes in `.convert` with a generic error.
        raise gr.Error("Please upload an image to colorize.")

    colorizer = load_model('models/', model)
    return colorize_image(input_image, colorizer)
|
|
|
|
|
|
title = "<span style='color: #191970;'>Aiconvert.online</span>"

# Single-image colorizer UI: one uploaded PIL image in, the colorized PIL image out.
# gr.Image replaces the deprecated gr.inputs.Image (removed in Gradio 4).
demo = gr.Interface(
    app,
    gr.Image(type="pil", label="Input"),
    gr.Image(type="pil", label="Output", show_share_button=False),
    title=title,
    css="footer{display:none !important;}",  # hide the default Gradio footer
    theme=gr.themes.Base(),
    allow_flagging="never",  # string form; the boolean value is deprecated
)

# Enable request queuing via .queue(), which replaces the deprecated
# `enable_queue` Interface argument.
demo.queue().launch()
|
|