import functools

import cv2
import gradio as gr
import numpy as np
import onnxruntime
| |
|
| |
|
def pre_process(img: np.ndarray) -> np.ndarray:
    """Convert an HxWxC image into a 1x3xHxW float32 batch for ONNX input.

    Only the first three channels are kept (an alpha channel, if present,
    is dropped). Note: ``np.ndarray`` is the correct annotation here —
    ``np.array`` is a factory function, not a type.

    Args:
        img: Image array with at least three channels in the last axis.

    Returns:
        float32 array of shape (1, 3, H, W).
    """
    # HWC -> CHW, then add a leading batch dimension.
    img = np.transpose(img[:, :, 0:3], (2, 0, 1))
    img = np.expand_dims(img, axis=0).astype(np.float32)
    return img
| |
|
| |
|
def post_process(img: np.ndarray) -> np.ndarray:
    """Convert a 1xCxHxW network output back to an HxWxC uint8 image.

    Squeezes the batch dimension, moves channels last, and reverses the
    channel axis (RGB <-> BGR). ``np.ndarray`` replaces the incorrect
    ``np.array`` annotation (a function, not a type).

    Args:
        img: Model output of shape (1, C, H, W).

    Returns:
        uint8 array of shape (H, W, C) with channel order reversed.
    """
    img = np.squeeze(img)
    # CHW -> HWC, flip channel order, truncate to uint8 for display/saving.
    img = np.transpose(img, (1, 2, 0))[:, :, ::-1].astype(np.uint8)
    return img
| |
|
| |
|
@functools.lru_cache(maxsize=None)
def _load_session(model_path: str) -> "onnxruntime.InferenceSession":
    """Create (and cache) a single-threaded ONNX Runtime session for a model file."""
    options = onnxruntime.SessionOptions()
    # Single-threaded execution: keeps resource use predictable in a
    # shared hosting environment (e.g. a Spaces demo).
    options.intra_op_num_threads = 1
    options.inter_op_num_threads = 1
    return onnxruntime.InferenceSession(model_path, options)


def inference(model_path: str, img_array: np.ndarray) -> np.ndarray:
    """Run the ONNX model at *model_path* on a pre-processed image batch.

    Sessions are cached per model path, so repeated upscales with the same
    model no longer re-load the model file from disk on every call.

    Args:
        model_path: Path to the ``.ort`` model file.
        img_array: float32 batch of shape (1, 3, H, W) from ``pre_process``.

    Returns:
        The first model output (the upscaled image batch).
    """
    ort_session = _load_session(model_path)
    ort_inputs = {ort_session.get_inputs()[0].name: img_array}
    ort_outs = ort_session.run(None, ort_inputs)
    return ort_outs[0]
| |
|
| |
|
def convert_pil_to_cv2(image) -> np.ndarray:
    """Convert a PIL image to an OpenCV-style (BGR) numpy array.

    Color images have their channel order reversed (RGB -> BGR). Grayscale
    images produce a 2-D array, which is returned as-is; the previous
    implementation unconditionally indexed a third axis and raised
    IndexError on grayscale input, making the ``img.ndim == 2`` handling
    in ``upscale`` unreachable.

    Args:
        image: PIL image (or anything ``np.array`` accepts).

    Returns:
        A contiguous numpy copy, 2-D for grayscale or HxWxC with the
        channel axis reversed otherwise.
    """
    open_cv_image = np.array(image)
    if open_cv_image.ndim == 2:
        # No channel axis to reverse; upscale() converts to BGR later.
        return open_cv_image.copy()
    return open_cv_image[:, :, ::-1].copy()
| |
|
| |
|
def upscale(image, model):
    """Upscale a PIL *image* with the selected ONNX super-resolution model.

    Grayscale input is promoted to BGR. RGBA input has its alpha channel
    upscaled separately (the model expects 3 channels) and re-attached to
    the upscaled color result.

    Args:
        image: Input PIL image from the Gradio widget.
        model: Model name; resolved to ``models/<model>.ort``.

    Returns:
        The upscaled image as an HxWxC uint8 numpy array.

    Raises:
        ValueError: if the image has an unsupported channel count
            (previously this fell through and raised UnboundLocalError,
            e.g. for 2-channel "LA" images).
    """
    model_path = f"models/{model}.ort"
    img = convert_pil_to_cv2(image)
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    if img.shape[2] == 4:
        # Run the alpha channel through the model on its own, replicated
        # to 3 channels, then collapse it back to a single channel.
        alpha = img[:, :, 3]
        alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2BGR)
        alpha_output = post_process(inference(model_path, pre_process(alpha)))
        alpha_output = cv2.cvtColor(alpha_output, cv2.COLOR_BGR2GRAY)

        # Upscale the color channels and re-attach the upscaled alpha.
        img = img[:, :, 0:3]
        image_output = post_process(inference(model_path, pre_process(img)))
        image_output = cv2.cvtColor(image_output, cv2.COLOR_BGR2BGRA)
        image_output[:, :, 3] = alpha_output
    elif img.shape[2] == 3:
        image_output = post_process(inference(model_path, pre_process(img)))
    else:
        raise ValueError(f"Unsupported number of channels: {img.shape[2]}")

    return image_output
| |
|
| |
|
# Pin the input/output/preview image panes to a fixed height in the UI.
css = ".output-image, .input-image, .image-preview {height: 480px !important} "
model_choices = ["modelx2", "modelx2 25 JXL", "modelx4", "minecraft_modelx4"]

# Widgets: a PIL image input plus a radio selector for the upscaler model.
image_input = gr.Image(type="pil", label="Input Image")
model_input = gr.Radio(
    model_choices,
    type="value",
    value=None,
    label="Choose Upscaler",
)

gr.Interface(
    fn=upscale,
    inputs=[image_input, model_input],
    outputs="image",
    title="Image Upscaling 🦆",
    description=(
        "Model: [Anchor-based Plain Net for Mobile Image Super-Resolution]"
        "(https://arxiv.org/abs/2105.09750). Repository: [SR Mobile PyTorch]"
        "(https://github.com/w11wo/sr_mobile_pytorch)"
    ),
    allow_flagging="never",
    css=css,
).launch()
| |
|