Spaces:
Build error
Build error
File size: 6,459 Bytes
821a664 3f4b262 821a664 3f4b262 821a664 3f4b262 821a664 3f4b262 821a664 3f4b262 821a664 3f4b262 821a664 3f4b262 821a664 3f4b262 821a664 3f4b262 821a664 3f4b262 821a664 3f4b262 821a664 3f4b262 821a664 3f4b262 821a664 3f4b262 821a664 3f4b262 821a664 3f4b262 821a664 3f4b262 821a664 3f4b262 821a664 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 | import cv2
import numpy as np
import os
import urllib.request
# ─── Configuration ───────────────────────────────────────────────
# Directory that holds the model weights: a "weights" folder located one level
# above the directory containing this file.
MODEL_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "weights")
# Local file name the ONNX model is stored under (differs from the remote name).
MODEL_FILENAME = "realesrgan_x4plus.onnx"
# Full local path checked by _ensure_model() and loaded by _get_session().
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_FILENAME)
# Remote location the model is downloaded from when MODEL_PATH does not exist.
MODEL_URL = (
"https://huggingface.co/Qualcomm/Real-ESRGAN-x4plus/resolve/main/"
"Real-ESRGAN-x4plus.onnx"
)
# Factor used to map input pixel coordinates/sizes to output ones in the tiled
# upscaler (output canvas is padded size * SCALE_FACTOR).
SCALE_FACTOR = 4
TILE_SIZE = 256 # Process in tiles to limit memory usage
TILE_OVERLAP = 16 # Overlap between tiles for seamless stitching
# Lazy-loaded ONNX session; stays None until _get_session() creates it.
_session = None
def _ensure_model():
    """Download the Real-ESRGAN ONNX model if it doesn't exist locally.

    The download is written to a temporary ``<MODEL_PATH>.part`` file and only
    moved to ``MODEL_PATH`` once it has completed. This way an interrupted or
    failed download never leaves a truncated file at ``MODEL_PATH`` that later
    calls would treat as a valid model (the existence check below is the only
    validation performed).

    Any exception raised while creating the directory or downloading is
    propagated to the caller after the partial file has been removed.
    """
    if os.path.exists(MODEL_PATH):
        return

    os.makedirs(MODEL_DIR, exist_ok=True)
    print(f"Downloading Real-ESRGAN x4plus model to {MODEL_PATH} ...")
    print("(This is a one-time download, ~67 MB)")

    partial_path = MODEL_PATH + ".part"
    try:
        urllib.request.urlretrieve(MODEL_URL, partial_path)
        # Same-directory rename: the model appears at MODEL_PATH in one step.
        os.replace(partial_path, MODEL_PATH)
    finally:
        # After a successful replace the partial file no longer exists; this
        # only fires when the download or the rename failed part-way.
        if os.path.exists(partial_path):
            os.remove(partial_path)

    print("Download complete.")
def _get_session():
    """Return the module-wide ONNX Runtime session, building it on first call.

    On the first call this imports onnxruntime, lowers its log verbosity,
    makes sure the model file is present via ``_ensure_model()`` and creates a
    CPU-only ``InferenceSession`` that is cached in ``_session``. Later calls
    return the cached object.
    """
    global _session
    if _session is not None:
        return _session

    # Imported here rather than at module level, so onnxruntime is only loaded
    # once a session is actually requested.
    import onnxruntime as ort

    ort.set_default_logger_severity(3)  # Suppress verbose logs
    _ensure_model()

    session_options = ort.SessionOptions()
    session_options.graph_optimization_level = (
        ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    )
    _session = ort.InferenceSession(
        MODEL_PATH,
        sess_options=session_options,
        providers=["CPUExecutionProvider"],
    )
    return _session
def _run_esrgan_tile(session, tile_bgr: np.ndarray) -> np.ndarray:
    """
    Push one BGR tile through the Real-ESRGAN ONNX model.

    The tile is converted to RGB, scaled to float32 in [0, 1] and arranged as
    a 1x3xHxW batch. The first batch item of the session's first output is
    then multiplied by 255, clipped to [0, 255], cast to uint8 and converted
    back to BGR.
    """
    # BGR -> RGB, then normalise to [0, 1].
    rgb_unit = cv2.cvtColor(tile_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    # HWC -> CHW with a leading batch axis.
    batch = rgb_unit.transpose(2, 0, 1)[np.newaxis, ...]

    feed = {session.get_inputs()[0].name: batch}
    prediction = session.run(None, feed)[0][0]

    # CHW -> HWC, back to the 0..255 range, then RGB -> BGR.
    scaled = prediction.transpose(1, 2, 0) * 255.0
    out_rgb = np.clip(scaled, 0, 255).astype(np.uint8)
    return cv2.cvtColor(out_rgb, cv2.COLOR_RGB2BGR)
def _tile_origins(length: int, tile: int, step: int) -> list:
    """Return the distinct, ascending tile start offsets that cover ``length``."""
    # Offsets past ``length - tile`` are clamped so the tile stays inside the
    # image; the set removes the duplicates that clamping produces.
    return sorted({min(pos, length - tile) for pos in range(0, length, step)})


def _upscale_tiled(session, img_bgr: np.ndarray) -> np.ndarray:
    """
    Upscale a full BGR image using tiled inference with overlap averaging.

    The image is reflect-padded on the bottom/right so both dimensions are a
    multiple of ``TILE_SIZE``, processed tile by tile through
    ``_run_esrgan_tile`` and accumulated into a float canvas; pixels covered
    by several tiles are averaged. The padding is cropped from the result.

    Each tile position is processed exactly once. Clamping the last tile to
    the image border can map several loop offsets to the same position (for a
    256x256 padded image all four offsets collapse onto (0, 0)); processing
    such duplicates would repeat identical inference and give the border tile
    extra weight in the overlap average.

    Args:
        session: object passed through to ``_run_esrgan_tile``.
        img_bgr: image array whose first two axes are height and width.

    Returns:
        uint8 array of shape ``(h * SCALE_FACTOR, w * SCALE_FACTOR, 3)``.
    """
    h, w = img_bgr.shape[:2]
    sf = SCALE_FACTOR

    # Pad image so dimensions are divisible by TILE_SIZE.
    pad_h = (TILE_SIZE - h % TILE_SIZE) % TILE_SIZE
    pad_w = (TILE_SIZE - w % TILE_SIZE) % TILE_SIZE
    padded = cv2.copyMakeBorder(img_bgr, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
    ph, pw = padded.shape[:2]

    # Output canvas plus a per-pixel count of contributing tiles.
    out_h, out_w = ph * sf, pw * sf
    output = np.zeros((out_h, out_w, 3), dtype=np.float64)
    weight = np.zeros((out_h, out_w, 1), dtype=np.float64)

    # Tile origins advance by less than a full tile so neighbours overlap.
    step = TILE_SIZE - TILE_OVERLAP
    row_origins = _tile_origins(ph, TILE_SIZE, step)
    col_origins = _tile_origins(pw, TILE_SIZE, step)

    for ty in row_origins:
        for tx in col_origins:
            tile = padded[ty : ty + TILE_SIZE, tx : tx + TILE_SIZE]
            upscaled_tile = _run_esrgan_tile(session, tile)

            # Position and size of this tile on the output canvas.
            oy, ox = ty * sf, tx * sf
            th, tw = upscaled_tile.shape[:2]

            output[oy : oy + th, ox : ox + tw] += upscaled_tile.astype(np.float64)
            weight[oy : oy + th, ox : ox + tw] += 1.0

    # Average overlapping regions; the floor of 1 avoids dividing by zero.
    weight = np.maximum(weight, 1.0)
    output = (output / weight).clip(0, 255).astype(np.uint8)

    # Remove padding from output.
    return output[: h * sf, : w * sf]
def upscale_image(img: np.ndarray) -> np.ndarray:
    """
    Upscale an image with Real-ESRGAN via ONNX Runtime.

    A 4-channel input is treated as BGRA: the colour planes go through the
    model, while the alpha plane is resized with Lanczos interpolation to the
    upscaled size and binarised at a threshold of 127 before being re-attached.
    Any other input is passed to the model path as-is.

    If anything in the model path raises, the error is printed and the result
    of ``_local_fallback_upscale(img)`` is returned instead.
    """
    has_alpha = img.ndim == 3 and img.shape[2] == 4
    color = img[:, :, :3] if has_alpha else img
    alpha = img[:, :, 3] if has_alpha else None

    try:
        upscaled_bgr = _upscale_tiled(_get_session(), color)
        if alpha is None:
            return upscaled_bgr

        out_h, out_w = upscaled_bgr.shape[:2]
        resized_alpha = cv2.resize(
            alpha, (out_w, out_h), interpolation=cv2.INTER_LANCZOS4
        )
        _, alpha_mask = cv2.threshold(resized_alpha, 127, 255, cv2.THRESH_BINARY)

        color_planes = [upscaled_bgr[:, :, channel] for channel in range(3)]
        return cv2.merge((*color_planes, alpha_mask))
    except Exception as e:
        print(f"Real-ESRGAN upscale failed: {e}")
        print("Falling back to local Lanczos upscaling...")
        return _local_fallback_upscale(img)
def _local_fallback_upscale(img: np.ndarray) -> np.ndarray:
    """
    Fallback: local multi-pass Lanczos + sharpening if ONNX is unavailable.

    The image is enlarged in passes of at most 2x until it reaches
    ``SCALE_FACTOR`` times the input size, so the fallback returns the same
    output dimensions as the model path. Each pass applies a Lanczos resize,
    a bilateral filter and an unsharp mask.

    A 4-channel input is treated as BGRA: the alpha plane is resized once to
    the final size with Lanczos interpolation, binarised at a threshold of 127
    and re-attached as the fourth channel.
    """
    has_alpha = len(img.shape) == 3 and img.shape[2] == 4
    if has_alpha:
        bgr = img[:, :, :3]
        alpha = img[:, :, 3]
    else:
        bgr = img
        alpha = None

    h, w = bgr.shape[:2]
    target_h, target_w = h * SCALE_FACTOR, w * SCALE_FACTOR

    upscaled = bgr
    cur_h, cur_w = h, w
    while cur_h < target_h or cur_w < target_w:
        # Double the size, but never overshoot the target on the last pass.
        cur_h = min(cur_h * 2, target_h)
        cur_w = min(cur_w * 2, target_w)
        upscaled = cv2.resize(upscaled, (cur_w, cur_h), interpolation=cv2.INTER_LANCZOS4)
        upscaled = cv2.bilateralFilter(upscaled, d=5, sigmaColor=40, sigmaSpace=40)
        # Unsharp mask: 2 * image - blurred.
        blurred = cv2.GaussianBlur(upscaled, (0, 0), 2.0)
        upscaled = cv2.addWeighted(upscaled, 2.0, blurred, -1.0, 0)

    if alpha is not None:
        uh, uw = upscaled.shape[:2]
        upscaled_alpha = cv2.resize(alpha, (uw, uh), interpolation=cv2.INTER_LANCZOS4)
        _, upscaled_alpha = cv2.threshold(upscaled_alpha, 127, 255, cv2.THRESH_BINARY)
        return cv2.merge(
            (upscaled[:, :, 0], upscaled[:, :, 1], upscaled[:, :, 2], upscaled_alpha)
        )
    return upscaled
|