File size: 7,143 Bytes
8f62d7c
e10ba58
ffc6340
8f62d7c
38fcfdd
 
 
 
e10ba58
38fcfdd
e10ba58
 
 
 
6d69bf6
e10ba58
 
 
 
38fcfdd
e10ba58
 
38fcfdd
 
 
 
8f62d7c
ec9864f
38fcfdd
ec9864f
38fcfdd
 
 
 
 
 
 
 
 
 
8f62d7c
ae67fb9
 
ddb8475
 
99e790f
ae67fb9
295b755
99e790f
747eabc
 
1a91866
747eabc
 
 
 
f36c7b4
1a91866
747eabc
295b755
99e790f
 
 
 
ddb8475
ae67fb9
99e790f
 
 
ddb8475
99e790f
 
 
 
 
 
 
 
 
 
ae67fb9
ddb8475
 
ae67fb9
99e790f
 
ddb8475
ae67fb9
99e790f
ae67fb9
295b755
 
ae67fb9
 
 
 
 
99e790f
747eabc
24555d8
747eabc
 
ddb8475
747eabc
ddb8475
295b755
747eabc
 
ddb8475
747eabc
 
 
ddb8475
747eabc
 
 
 
c939fa5
747eabc
 
c939fa5
295b755
747eabc
 
 
 
 
295b755
747eabc
 
 
 
 
 
 
 
295b755
747eabc
 
 
 
 
 
 
 
 
 
 
 
ae67fb9
747eabc
ae67fb9
 
ddb8475
24555d8
 
 
 
ae67fb9
 
ddb8475
 
295b755
ae67fb9
 
 
 
99e790f
 
747eabc
99e790f
 
747eabc
ae67fb9
 
 
 
 
 
99e790f
ae67fb9
 
 
 
806c2c6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import importlib
import os
import sys
import traceback

# =================================================================
# CRITICAL PYTHON 3.13 / GRADIO SDK CONFLICT MONKEY-PATCHES
# =================================================================

# 1. Provide a minimal HfFolder on the real huggingface_hub module when the
#    installed release lacks it. Presumably a later import (Gradio, per the
#    header above) still expects it — NOTE(review): confirm which dependency.
#    Failures are reported and otherwise ignored so startup continues.
try:
    real_hf_hub = importlib.import_module('huggingface_hub')
    if not hasattr(real_hf_hub, 'HfFolder'):
        class MockHfFolder:
            """Stand-in for huggingface_hub.HfFolder.

            Reads the token only from the HF_TOKEN environment variable;
            saving and deleting tokens are deliberate no-ops.
            """

            @classmethod
            def get_token(cls):
                return os.environ.get("HF_TOKEN")

            @classmethod
            def save_token(cls, token):
                return None

            @classmethod
            def delete_token(cls):
                return None

        # import_module hands back the sys.modules entry, so one attribute
        # assignment patches the module every later importer will receive.
        setattr(sys.modules['huggingface_hub'], 'HfFolder', MockHfFolder)
except Exception as patch_err_1:
    print(f"Pre-import HfFolder patch skipped or failed: {patch_err_1}")

# 2. Wrap gradio_client.utils.get_type so that bare boolean JSON-schema nodes
#    (JSON Schema allows `true`/`false` as a whole schema) resolve to
#    "boolean" instead of reaching the original resolver, which presumably
#    fails on them (the "'bool' type" error) — NOTE(review): confirm against
#    the installed gradio_client version. Any failure just logs and continues.
try:
    import gradio_client.utils as client_utils
    orig_get_type = getattr(client_utils, 'get_type', None)

    if orig_get_type:
        def safe_get_type(schema):
            """Return the schema type, short-circuiting booleans to "boolean"."""
            return "boolean" if isinstance(schema, bool) else orig_get_type(schema)

        # Rebinding the module attribute means internal callers inside
        # gradio_client.utils that look up get_type globally see the wrapper.
        client_utils.get_type = safe_get_type
        print("Successfully patched gradio_client schema serialization engine.")
except Exception as patch_err_2:
    print(f"Schema serialization engine patch deferred: {patch_err_2}")
# =================================================================

import cv2
import gradio as gr
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download

# Registry of selectable upscaling engines, keyed by the label shown in the UI
# dropdown. Each entry gives the Hugging Face Hub location of the ONNX weights
# ("repo_id"/"filename"), the integer upscaling factor ("scale"), and the square
# tile edge in input pixels ("tile_size") matching the ONNX compilation target;
# inference is run on tiles of exactly this size.
MODELS = {
    "RealESRGAN_x2plus (Faster 2x)": {
        "repo_id": "tidus2102/Real-ESRGAN",
        "filename": "Real-ESRGAN_x2plus.onnx",
        "scale": 2,
        "tile_size": 64,
    },
    "RealESRGAN_x4plus (High Quality 4x)": {
        "repo_id": "KingPro100/real-esrgan-onxx",
        "filename": "Real-ESRGAN-x4plus.onnx",
        "scale": 4,
        "tile_size": 128,
    },
}

# Single-entry model cache maintained by load_model(): the UI label of the
# currently loaded model and its ONNX Runtime session (both None until first use).
current_model_name = None
ort_session = None

def load_model(model_choice):
    """Return a CPU ONNX Runtime session for the MODELS entry *model_choice*.

    Only the most recently requested model is kept: a repeat request returns
    the cached module-level ``ort_session``; a different choice downloads the
    weights with ``hf_hub_download`` (using the HF_TOKEN env var if set),
    builds a new session with all graph optimizations enabled, and replaces
    the cached ``ort_session``/``current_model_name`` pair.

    Raises KeyError for an unknown choice; download and session-creation
    errors propagate unchanged.
    """
    global current_model_name, ort_session

    already_loaded = current_model_name == model_choice and ort_session is not None
    if already_loaded:
        return ort_session

    spec = MODELS[model_choice]
    print(f"Loading weights for {model_choice}...")
    weights_file = hf_hub_download(
        repo_id=spec["repo_id"],
        filename=spec["filename"],
        token=os.environ.get("HF_TOKEN"),
    )

    opts = ort.SessionOptions()
    opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

    ort_session = ort.InferenceSession(weights_file, opts, providers=['CPUExecutionProvider'])
    current_model_name = model_choice
    return ort_session

def _split_alpha(img):
    """Return ``(rgb, alpha)`` for an H x W or H x W x C image array.

    ``rgb`` is always a contiguous H x W x 3 array. One-channel (or 2-D)
    input is replicated into three channels. The last channel of a 2-channel
    image, or the 4th channel of a 4+-channel image, is returned as ``alpha``;
    otherwise ``alpha`` is None. Channels beyond the fourth are dropped.
    """
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    channels = img.shape[2]
    if channels <= 2:
        # Grayscale (optionally with alpha): replicate luminance into R, G, B.
        rgb = np.repeat(img[:, :, :1], 3, axis=2)
        alpha = img[:, :, 1] if channels == 2 else None
    else:
        rgb = img[:, :, :3]
        alpha = img[:, :, 3] if channels >= 4 else None
    # OpenCV works on contiguous buffers; slices above may be strided views.
    rgb = np.ascontiguousarray(rgb)
    if alpha is not None:
        alpha = np.ascontiguousarray(alpha)
    return rgb, alpha


def _tiled_inference(session, rgb, scale, tile):
    """Upscale an H x W x 3 image by running *session* on tile x tile blocks.

    The image is reflect-padded on the bottom/right to a multiple of *tile*.
    Each block is fed to the model as a float32 [1, 3, tile, tile] tensor
    scaled to [0, 1]. The (tile*scale)-sized outputs are stitched into a
    uint8 canvas that is cropped back to (H*scale, W*scale).
    """
    h, w = rgb.shape[:2]
    pad_h = (tile - (h % tile)) % tile
    pad_w = (tile - (w % tile)) % tile
    # Mirror the border so padded strips resemble image content rather than
    # black, which keeps the edge tiles' model output clean.
    padded = cv2.copyMakeBorder(rgb, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT)
    ph, pw = padded.shape[:2]

    canvas = np.zeros((ph * scale, pw * scale, 3), dtype=np.uint8)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    # NOTE: tiles do not overlap, so faint seams may be visible at tile borders.
    for y in range(0, ph, tile):
        for x in range(0, pw, tile):
            block = padded[y:y + tile, x:x + tile]
            # HWC uint8 -> NCHW float32 in [0, 1].
            batch = np.transpose(block.astype(np.float32) / 255.0, (2, 0, 1))[np.newaxis]
            result = session.run([output_name], {input_name: batch})[0][0]
            # CHW float -> HWC uint8. Round instead of truncating so the
            # output is not biased half a level darker.
            hwc = np.transpose(np.clip(result, 0.0, 1.0), (1, 2, 0))
            canvas[y * scale:(y + tile) * scale, x * scale:(x + tile) * scale] = (
                np.rint(hwc * 255.0).astype(np.uint8)
            )

    return canvas[:h * scale, :w * scale]


def upscale_image(input_img, model_choice):
    """
    Tiled inference pipeline. Upscale *input_img* with the ONNX model selected
    by *model_choice* (a MODELS key).

    The image is split into model-specific blocks (64x64 or 128x128), each
    block is run through the model, and the results are stitched together.
    Accepts H x W grayscale or H x W x C arrays. A 2nd (gray+alpha) or 4th
    (RGBA) channel is treated as alpha and resized bilinearly rather than
    passed to the 3-channel model. Pixel values are assumed to be 8-bit
    (0-255) — NOTE(review): confirm for non-uint8 callers.

    Returns None when no image is supplied, the upscaled uint8 array
    (H*scale x W*scale x 3, or x 4 with alpha) on success, or a 300x500 RGB
    error banner if anything fails (the traceback is printed to stderr).
    """
    if input_img is None:
        return None

    try:
        session = load_model(model_choice)
        cfg = MODELS[model_choice]
        scale = cfg["scale"]
        tile = cfg["tile_size"]

        rgb, alpha = _split_alpha(input_img)
        h, w = rgb.shape[:2]
        upscaled = _tiled_inference(session, rgb, scale, tile)

        if alpha is not None:
            # cv2.resize takes dsize as (width, height).
            alpha_up = cv2.resize(alpha, (w * scale, h * scale), interpolation=cv2.INTER_LINEAR)
            upscaled = np.dstack([upscaled, alpha_up])
        return upscaled

    except Exception as e:
        print(f"Error executing ONNX runtime tensor transformation graph: {str(e)}")
        traceback.print_exc()
        blank_err_img = np.zeros((300, 500, 3), dtype=np.uint8)
        # Gradio displays numpy images as RGB, so (255, 0, 0) renders red.
        cv2.putText(blank_err_img, f"Execution Error: {str(e)[:40]}", (20, 150),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1, cv2.LINE_AA)
        return blank_err_img

# Gradio UI: source image, engine picker and trigger on the left; the
# upscaled result on the right. Clicking the button calls upscale_image.
with gr.Blocks(title="AI Lightweight Image Upscaler (ONNX)") as demo:
    gr.Markdown("# 🖼️ AI Image Resizer & Upscaler (ONNX Engine)")
    gr.Markdown("Running locally on Hugging Face Free CPU hardware using dynamic tiling maps.")

    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            source_view = gr.Image(label="Source Image", type="numpy")
            engine_picker = gr.Dropdown(
                label="Select AI Upscaling Engine",
                choices=list(MODELS),
                value="RealESRGAN_x2plus (Faster 2x)",
            )
            run_button = gr.Button("Upscale Image", variant="primary")

        # Right column: output.
        with gr.Column():
            result_view = gr.Image(label="Enhanced Super-Resolution Result", type="numpy")

    run_button.click(
        fn=upscale_image,
        inputs=[source_view, engine_picker],
        outputs=result_view,
    )

if __name__ == "__main__":
    # Listen on every network interface (0.0.0.0), port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")