| import os |
| import math |
| import numpy as np |
| import onnxruntime as ort |
| from PIL import Image |
| import gradio as gr |
|
|
| |
| |
# Path to the Real-ESRGAN 4x ONNX model, expected alongside this script.
MODEL_PATH = "Real-ESRGAN-x4plus.onnx"


# onnxruntime session options: cap both thread pools at 4 to keep CPU usage
# bounded on shared/containerized hosts.
sess_opts = ort.SessionOptions()
sess_opts.intra_op_num_threads = 4
sess_opts.inter_op_num_threads = 4
# Lazily-initialized singleton InferenceSession; created by load_model().
session = None
|
|
def load_model():
    """Return the shared ONNX inference session, creating it on first use.

    The session is cached in the module-level ``session`` global so the
    model is loaded from disk at most once.

    Raises:
        FileNotFoundError: if the model file is not present at MODEL_PATH.
    """
    global session
    if session is not None:
        return session
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model not found. Please ensure '{MODEL_PATH}' is uploaded.")
    session = ort.InferenceSession(MODEL_PATH, sess_options=sess_opts, providers=["CPUExecutionProvider"])
    return session
|
|
| |
| |
def process_tensor(sess, tensor_np):
    """Run one HWC float32 tile through the model; return the upscaled HWC array.

    Converts HWC -> NCHW (batch of 1) for the ONNX session, then converts the
    single output back to HWC.
    """
    chw = np.transpose(tensor_np, (2, 0, 1))
    batched = chw[np.newaxis, ...]
    feed = {sess.get_inputs()[0].name: batched}
    result = sess.run(None, feed)[0]
    # result[0] drops the batch axis (equivalent to squeeze on axis 0).
    return np.transpose(result[0], (1, 2, 0))
|
|
def _pad_reflect_safe(arr, pad_spec):
    """np.pad with mode="reflect", degrading to mode="edge" when unsafe.

    numpy raises ValueError for reflect padding when any pad width is >= the
    length of its axis, which happens for images smaller than the tile stride
    (or smaller than the halo).  For such tiny inputs we fall back to edge
    padding; for every input that worked before, behavior is unchanged.
    """
    reflect_ok = all(
        before < arr.shape[axis] and after < arr.shape[axis]
        for axis, (before, after) in enumerate(pad_spec)
    )
    return np.pad(arr, pad_spec, mode="reflect" if reflect_ok else "edge")


def upscale(input_img: Image.Image, progress=gr.Progress()):
    """Upscale *input_img* 4x with the Real-ESRGAN ONNX model, tile by tile.

    The image is split into overlapping tiles (a 16 px halo on each side) so
    arbitrarily large images fit in memory; the halo is cropped from each
    upscaled tile, which hides seam artifacts at tile borders.

    Args:
        input_img: Source image, or None (returns None immediately).
        progress: Gradio progress reporter (injected by the UI).

    Returns:
        A PIL RGB image 4x the size of the input, or None if no input.
    """
    if input_img is None:
        return None

    sess = load_model()
    SCALE = 4  # Real-ESRGAN-x4plus is a fixed 4x model.

    # NOTE(review): assumes an NCHW model input layout — confirm if swapping
    # in a different ONNX export.
    input_shape = sess.get_inputs()[0].shape
    expected_h = input_shape[2]
    expected_w = input_shape[3]

    # Static models dictate the tile size; dynamic-shape models get a
    # CPU-friendly default tile.
    is_static = isinstance(expected_h, int) and isinstance(expected_w, int)
    if is_static:
        BLOCK_H = expected_h
        BLOCK_W = expected_w
    else:
        BLOCK_H = 432
        BLOCK_W = 432

    img_rgb = input_img.convert("RGB")
    arr = np.array(img_rgb).astype(np.float32) / 255.0
    h_orig, w_orig, _ = arr.shape

    pad_size = 16                      # halo overlap, cropped from every tile
    step_h = BLOCK_H - 2 * pad_size    # stride = the "valid" interior of a tile
    step_w = BLOCK_W - 2 * pad_size

    # 1) Pad on the bottom/right so the image is a multiple of the stride.
    # 2) Add the halo on all four sides.
    # Both pads use the safe reflect helper: plain mode="reflect" would raise
    # ValueError for images smaller than the stride / halo.
    pad_h_req = (step_h - (h_orig % step_h)) % step_h
    pad_w_req = (step_w - (w_orig % step_w)) % step_w
    arr_div = _pad_reflect_safe(arr, ((0, pad_h_req), (0, pad_w_req), (0, 0)))
    arr_padded = _pad_reflect_safe(arr_div, ((pad_size, pad_size), (pad_size, pad_size), (0, 0)))

    h_div, w_div, _ = arr_div.shape
    out_h_div = h_div * SCALE
    out_w_div = w_div * SCALE
    out_arr = np.zeros((out_h_div, out_w_div, 3), dtype=np.float32)

    # h_div/w_div are exact multiples of the stride by construction.
    tiles_y = h_div // step_h
    tiles_x = w_div // step_w
    total_tiles = tiles_y * tiles_x
    current_tile = 0

    for i in range(tiles_y):
        for j in range(tiles_x):
            current_tile += 1
            progress(current_tile / total_tiles, desc=f"Processing strict tile {current_tile}/{total_tiles}...")

            # Tile window in the halo-padded image: stride apart, BLOCK wide,
            # so adjacent tiles overlap by 2*pad_size.
            y_start = i * step_h
            y_end = y_start + BLOCK_H
            x_start = j * step_w
            x_end = x_start + BLOCK_W
            tile = arr_padded[y_start:y_end, x_start:x_end, :]

            up_tile = process_tensor(sess, tile)

            # Crop the (scaled) halo so only the seam-free interior is kept.
            crop_start = pad_size * SCALE
            crop_end_h = up_tile.shape[0] - (pad_size * SCALE)
            crop_end_w = up_tile.shape[1] - (pad_size * SCALE)
            valid_up_tile = up_tile[crop_start:crop_end_h, crop_start:crop_end_w, :]

            # Paste the interior at its stride-aligned position in the output.
            out_y_start = i * step_h * SCALE
            out_y_end = out_y_start + valid_up_tile.shape[0]
            out_x_start = j * step_w * SCALE
            out_x_end = out_x_start + valid_up_tile.shape[1]
            out_arr[out_y_start:out_y_end, out_x_start:out_x_end, :] = valid_up_tile

    # Discard the divisibility padding: keep only the true 4x output area.
    final_out = out_arr[0:h_orig * SCALE, 0:w_orig * SCALE, :]

    progress(0.95, desc="Finalizing output...")
    final_out = np.clip(final_out, 0.0, 1.0)
    return Image.fromarray((final_out * 255.0).round().astype(np.uint8))
|
|
| |
| |
# Gradio UI: two-column layout — input image + trigger button on the left,
# 4x result preview on the right.
with gr.Blocks(title="Real-ESRGAN 4x Upscaler") as demo:
    gr.Markdown(
        """
        # π Minimalist 4x AI Upscaler
        Upload an image to process it through the high-fidelity Real-ESRGAN model. Optimized for strict CPU execution.
        """
    )

    with gr.Row():
        with gr.Column():
            # Left column: upload widget and the button that starts the run.
            inp_image = gr.Image(type="pil", label="Original Image")
            btn_upscale = gr.Button("Upscale 4x", variant="primary")
        with gr.Column():
            # Right column: the upscaled output.
            out_preview = gr.Image(type="pil", label="4x Output Result")

    # Wire the button to the tiled upscaling pipeline defined above.
    btn_upscale.click(fn=upscale, inputs=inp_image, outputs=out_preview)
|
|
if __name__ == "__main__":
    # Bind to all interfaces on Gradio's default port so the app is reachable
    # from outside a container or remote host.
    demo.launch(server_name="0.0.0.0", server_port=7860)