# app.py — Real-ESRGAN 4x upscaler (Gradio Space)
# Uploaded by James040 — commit 8531227 (verified)
import os
import math
import numpy as np
import onnxruntime as ort
from PIL import Image
import gradio as gr
# ----------------------------------------------------------------------
# 1) Model Configuration
# Path to the exported Real-ESRGAN 4x ONNX weights, expected next to app.py.
MODEL_PATH = "Real-ESRGAN-x4plus.onnx"
# ----------------------------------------------------------------------
# 2) Optimized CPU Inference Session
# Limit ONNX Runtime to 4 threads for both intra- and inter-op parallelism
# (keeps CPU usage bounded on shared hosts).
sess_opts = ort.SessionOptions()
sess_opts.intra_op_num_threads = 4
sess_opts.inter_op_num_threads = 4
# Lazily created by load_model() on first use; None until then.
session = None
def load_model():
    """Return the cached global ONNX session, creating it on first call.

    Raises FileNotFoundError when the model file is missing from disk.
    """
    global session
    if session is not None:
        return session
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model not found. Please ensure '{MODEL_PATH}' is uploaded.")
    # CPU-only provider; sess_opts caps thread counts at module level.
    session = ort.InferenceSession(MODEL_PATH, sess_options=sess_opts, providers=["CPUExecutionProvider"])
    return session
# β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
# 3) Auto-Sensing Strict Overlapping Tile Engine
def process_tensor(sess, tensor_np):
    """Run one HWC float tile through the ONNX session; return HWC output."""
    feed_key = sess.get_inputs()[0].name
    # HWC -> NCHW: channels first, then prepend the batch axis.
    batched = tensor_np.transpose(2, 0, 1)[None, :, :, :]
    outputs = sess.run(None, {feed_key: batched})
    # NCHW -> HWC: drop the batch axis and move channels last again.
    return outputs[0][0].transpose(1, 2, 0)
def _pad_reflect_safe(arr, pad_widths):
    """Reflect-pad `arr`, falling back to edge replication when reflect is illegal.

    np.pad(mode="reflect") raises ValueError whenever a pad width is >= the
    length of its axis — which happens for images smaller than one tile
    stride. Edge replication is always valid and visually close; it is used
    only in that case, so behavior is unchanged for normally-sized images.
    """
    for (before, after), size in zip(pad_widths, arr.shape):
        if before >= size or after >= size:
            return np.pad(arr, pad_widths, mode="edge")
    return np.pad(arr, pad_widths, mode="reflect")


def upscale(input_img: Image.Image, progress=gr.Progress()):
    """Upscale `input_img` 4x with the Real-ESRGAN ONNX model, tile by tile.

    The image is split into overlapping tiles sized to the model's expected
    input (or 432x432 for dynamic-shape models), each tile is run through the
    network, the upscaled overlap is cropped away, and the interiors are
    stitched back together so tile seams are invisible.

    Returns a PIL.Image 4x the original size, or None when no image is given.
    """
    if input_img is None:
        return None
    sess = load_model()
    SCALE = 4  # Real-ESRGAN-x4plus always upscales by exactly 4x.

    # --- AUTO-DETECT STRICT SIZING ---
    # ONNX reports NCHW; dims 2/3 are ints for static models and strings/None
    # for dynamic ones.
    input_shape = sess.get_inputs()[0].shape
    expected_h = input_shape[2]
    expected_w = input_shape[3]
    is_static = isinstance(expected_h, int) and isinstance(expected_w, int)
    if is_static:
        # The model demands a specific size (like 128x128) — lock it in.
        BLOCK_H = expected_h
        BLOCK_W = expected_w
    else:
        # Dynamic shape: a 432x432 block keeps peak memory within 16GB RAM.
        BLOCK_H = 432
        BLOCK_W = 432

    # Overlap between tiles hides seam artifacts. Clamp it so the effective
    # stride (BLOCK - 2*pad) stays positive even for very small static blocks;
    # for any block larger than 32px this is the original pad of 16.
    pad_size = min(16, (BLOCK_H - 1) // 2, (BLOCK_W - 1) // 2)
    step_h = BLOCK_H - 2 * pad_size
    step_w = BLOCK_W - 2 * pad_size

    img_rgb = input_img.convert("RGB")
    arr = np.array(img_rgb).astype(np.float32) / 255.0
    h_orig, w_orig, _ = arr.shape

    # 1. Pad the image so it is exactly divisible by the stride.
    pad_h_req = (step_h - (h_orig % step_h)) % step_h
    pad_w_req = (step_w - (w_orig % step_w)) % step_w
    arr_div = _pad_reflect_safe(arr, ((0, pad_h_req), (0, pad_w_req), (0, 0)))
    # 2. Add the overlap padding to all four sides.
    arr_padded = _pad_reflect_safe(arr_div, ((pad_size, pad_size), (pad_size, pad_size), (0, 0)))

    h_div, w_div, _ = arr_div.shape
    out_arr = np.zeros((h_div * SCALE, w_div * SCALE, 3), dtype=np.float32)

    tiles_y = h_div // step_h
    tiles_x = w_div // step_w
    total_tiles = tiles_y * tiles_x
    current_tile = 0
    for i in range(tiles_y):
        for j in range(tiles_x):
            current_tile += 1
            progress(current_tile / total_tiles, desc=f"Processing strict tile {current_tile}/{total_tiles}...")
            # Extract one full BLOCK_H x BLOCK_W window (overlap included).
            y_start = i * step_h
            x_start = j * step_w
            tile = arr_padded[y_start:y_start + BLOCK_H, x_start:x_start + BLOCK_W, :]
            up_tile = process_tensor(sess, tile)
            # Crop the upscaled overlap so only the tile's interior remains.
            crop = pad_size * SCALE
            valid_up_tile = up_tile[crop:up_tile.shape[0] - crop, crop:up_tile.shape[1] - crop, :]
            # Paste the interior at its exact location on the output canvas.
            out_y = i * step_h * SCALE
            out_x = j * step_w * SCALE
            out_arr[out_y:out_y + valid_up_tile.shape[0], out_x:out_x + valid_up_tile.shape[1], :] = valid_up_tile

    # Finally, shave off the divisibility padding we added in step 1.
    final_out = out_arr[0:h_orig * SCALE, 0:w_orig * SCALE, :]
    progress(0.95, desc="Finalizing output...")
    final_out = np.clip(final_out, 0.0, 1.0)
    return Image.fromarray((final_out * 255.0).round().astype(np.uint8))
# β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
# 4) Minimalist UI Setup
# Build the two-column Gradio interface: input + button on the left,
# upscaled preview on the right.
with gr.Blocks(title="Real-ESRGAN 4x Upscaler") as demo:
    gr.Markdown(
        """
# πŸ” Minimalist 4x AI Upscaler
Upload an image to process it through the high-fidelity Real-ESRGAN model. Optimized for strict CPU execution.
"""
    )
    with gr.Row():
        with gr.Column():
            source_image = gr.Image(type="pil", label="Original Image")
            run_button = gr.Button("Upscale 4x", variant="primary")
        with gr.Column():
            result_preview = gr.Image(type="pil", label="4x Output Result")
    run_button.click(fn=upscale, inputs=source_image, outputs=result_preview)

if __name__ == "__main__":
    # Bind to all interfaces on the standard Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)