Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,8 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import cv2
|
| 3 |
import numpy
|
| 4 |
import os
|
| 5 |
import random
|
|
|
|
| 6 |
from basicsr.archs.rrdbnet_arch import RRDBNet
|
| 7 |
from basicsr.utils.download_util import load_file_from_url
|
| 8 |
|
|
@@ -20,14 +39,11 @@ img_mode = "RGBA"
|
|
| 20 |
# Utilities
|
| 21 |
# ────────────────────────────────────────────────────────
|
| 22 |
def rnd_string(x: int) -> str:
|
| 23 |
-
"""Returns a string of 'x' random characters."""
|
| 24 |
characters = "abcdefghijklmnopqrstuvwxyz_0123456789"
|
| 25 |
-
|
| 26 |
-
return result
|
| 27 |
|
| 28 |
|
| 29 |
def reset():
|
| 30 |
-
"""Resets the Image components and deletes the last processed image."""
|
| 31 |
global last_file
|
| 32 |
if last_file:
|
| 33 |
try:
|
|
@@ -40,10 +56,6 @@ def reset():
|
|
| 40 |
|
| 41 |
|
| 42 |
def has_transparency(img):
|
| 43 |
-
"""
|
| 44 |
-
Check for transparency in a PIL image.
|
| 45 |
-
https://stackoverflow.com/questions/43864101/python-pil-check-if-image-is-transparent
|
| 46 |
-
"""
|
| 47 |
if img.info.get("transparency", None) is not None:
|
| 48 |
return True
|
| 49 |
if img.mode == "P":
|
|
@@ -59,19 +71,13 @@ def has_transparency(img):
|
|
| 59 |
|
| 60 |
|
| 61 |
def image_properties(img):
|
| 62 |
-
"""Return resolution & color mode of the input image; set global img_mode."""
|
| 63 |
global img_mode
|
| 64 |
if img:
|
| 65 |
-
if has_transparency(img)
|
| 66 |
-
|
| 67 |
-
else:
|
| 68 |
-
img_mode = "RGB"
|
| 69 |
-
properties = f"Resolution: Width: {img.size[0]}, Height: {img.size[1]} | Color Mode: {img_mode}"
|
| 70 |
-
return properties
|
| 71 |
|
| 72 |
|
| 73 |
def model_tip_text(model_name: str) -> str:
|
| 74 |
-
"""Return human-friendly guidance for the chosen model."""
|
| 75 |
tips = {
|
| 76 |
"RealESRGAN_x4plus": (
|
| 77 |
"**RealESRGAN_x4plus (4Γ)** β Best for photoreal images (portraits, landscapes). "
|
|
@@ -101,50 +107,40 @@ def model_tip_text(model_name: str) -> str:
|
|
| 101 |
# Core upscaling
|
| 102 |
# ────────────────────────────────────────────────────────
|
| 103 |
def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
|
| 104 |
-
"""Real-ESRGAN function to restore (and upscale) images with robust defaults."""
|
| 105 |
if img is None:
|
| 106 |
return
|
| 107 |
|
| 108 |
# ----- Select backbone + weights -----
|
| 109 |
-
if model_name == 'RealESRGAN_x4plus':
|
| 110 |
-
model = RRDBNet(
|
| 111 |
-
netscale = 4
|
| 112 |
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
|
| 113 |
|
| 114 |
-
elif model_name == 'RealESRNet_x4plus':
|
| 115 |
-
model = RRDBNet(
|
| 116 |
-
netscale = 4
|
| 117 |
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
|
| 118 |
|
| 119 |
-
elif model_name == 'RealESRGAN_x4plus_anime_6B':
|
| 120 |
-
model = RRDBNet(
|
| 121 |
-
netscale = 4
|
| 122 |
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
|
| 123 |
|
| 124 |
-
elif model_name == 'RealESRGAN_x2plus':
|
| 125 |
-
model = RRDBNet(
|
| 126 |
-
netscale = 2
|
| 127 |
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
|
| 128 |
|
| 129 |
-
elif model_name == 'realesr-general-x4v3':
|
| 130 |
-
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
| 131 |
-
netscale = 4
|
| 132 |
-
# We'll ensure BOTH base and WDN weights exist; order matters for DNI.
|
| 133 |
file_url = [
|
| 134 |
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
|
| 135 |
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth'
|
| 136 |
]
|
| 137 |
-
|
| 138 |
else:
|
| 139 |
raise ValueError(f"Unknown model: {model_name}")
|
| 140 |
|
| 141 |
-
# ----- Ensure weights
|
| 142 |
-
# For the general-x4v3 case we download both; for others single file is fine.
|
| 143 |
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 144 |
weights_dir = os.path.join(ROOT_DIR, 'weights')
|
| 145 |
os.makedirs(weights_dir, exist_ok=True)
|
| 146 |
|
| 147 |
-
# Track model paths
|
| 148 |
local_paths = []
|
| 149 |
for url in file_url:
|
| 150 |
fname = os.path.basename(url)
|
|
@@ -153,27 +149,22 @@ def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
|
|
| 153 |
local_path = load_file_from_url(url=url, model_dir=weights_dir, progress=True)
|
| 154 |
local_paths.append(local_path)
|
| 155 |
|
| 156 |
-
# Default path(s)
|
| 157 |
if model_name == 'realesr-general-x4v3':
|
| 158 |
-
# Order: [base, wdn] then set DNI weights accordingly
|
| 159 |
base_path = os.path.join(weights_dir, 'realesr-general-x4v3.pth')
|
| 160 |
-
wdn_path
|
| 161 |
model_path = [base_path, wdn_path]
|
| 162 |
denoise_strength = float(denoise_strength)
|
| 163 |
-
|
| 164 |
-
dni_weight = [1.0 - denoise_strength, denoise_strength]
|
| 165 |
else:
|
| 166 |
model_path = os.path.join(weights_dir, f"{model_name}.pth")
|
| 167 |
dni_weight = None
|
| 168 |
|
| 169 |
# ----- CUDA / precision / tiling -----
|
| 170 |
-
# Be defensive: cv2.cuda may not exist in CPU-only builds.
|
| 171 |
use_cuda = False
|
| 172 |
try:
|
| 173 |
use_cuda = hasattr(cv2, "cuda") and cv2.cuda.getCudaEnabledDeviceCount() > 0
|
| 174 |
except Exception:
|
| 175 |
use_cuda = False
|
| 176 |
-
|
| 177 |
gpu_id = 0 if use_cuda else None
|
| 178 |
|
| 179 |
upsampler = RealESRGANer(
|
|
@@ -181,10 +172,10 @@ def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
|
|
| 181 |
model_path=model_path,
|
| 182 |
dni_weight=dni_weight,
|
| 183 |
model=model,
|
| 184 |
-
tile=256,
|
| 185 |
tile_pad=10,
|
| 186 |
pre_pad=10,
|
| 187 |
-
half=bool(use_cuda),
|
| 188 |
gpu_id=gpu_id
|
| 189 |
)
|
| 190 |
|
|
@@ -200,7 +191,7 @@ def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
|
|
| 200 |
bg_upsampler=upsampler
|
| 201 |
)
|
| 202 |
|
| 203 |
-
# -----
|
| 204 |
cv_img = numpy.array(img)
|
| 205 |
if cv_img.ndim == 3 and cv_img.shape[2] == 4:
|
| 206 |
cv_img = cv2.cvtColor(cv_img, cv2.COLOR_RGBA2BGRA)
|
|
@@ -218,7 +209,7 @@ def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
|
|
| 218 |
print('Tip: If you hit CUDA OOM, try a smaller tile size (e.g., 128).')
|
| 219 |
return None
|
| 220 |
|
| 221 |
-
# ----- cv2 ->
|
| 222 |
if output.ndim == 3 and output.shape[2] == 4:
|
| 223 |
display_img = cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
|
| 224 |
extension = 'png'
|
|
@@ -234,7 +225,7 @@ def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
|
|
| 234 |
except Exception as e:
|
| 235 |
print("Save error:", e)
|
| 236 |
|
| 237 |
-
return display_img
|
| 238 |
|
| 239 |
|
| 240 |
# ────────────────────────────────────────────────────────
|
|
@@ -255,7 +246,7 @@ def main():
|
|
| 255 |
"RealESRGAN_x2plus",
|
| 256 |
"realesr-general-x4v3",
|
| 257 |
],
|
| 258 |
-
value="RealESRGAN_x4plus",
|
| 259 |
show_label=True
|
| 260 |
)
|
| 261 |
denoise_strength = gr.Slider(
|
|
@@ -268,7 +259,6 @@ def main():
|
|
| 268 |
)
|
| 269 |
face_enhance = gr.Checkbox(label="Face Enhancement (GFPGAN)", value=False)
|
| 270 |
|
| 271 |
-
# Model tips panel (auto-updates)
|
| 272 |
model_tips = gr.Markdown(model_tip_text("RealESRGAN_x4plus"))
|
| 273 |
|
| 274 |
with gr.Row():
|
|
@@ -281,7 +271,6 @@ def main():
|
|
| 281 |
reset_btn = gr.Button("Remove images")
|
| 282 |
restore_btn = gr.Button("Upscale")
|
| 283 |
|
| 284 |
-
# Event listeners:
|
| 285 |
input_image.change(fn=image_properties, inputs=input_image, outputs=input_image_properties)
|
| 286 |
model_name.change(fn=model_tip_text, inputs=model_name, outputs=model_tips)
|
| 287 |
|
|
|
|
| 1 |
+
# ────────────────────────────────────────────────────────
|
| 2 |
+
# TorchVision compat shim (MUST be before importing basicsr)
|
| 3 |
+
# Fixes: ModuleNotFoundError: torchvision.transforms.functional_tensor
|
| 4 |
+
# ────────────────────────────────────────────────────────
|
| 5 |
+
import sys, types
|
| 6 |
+
try:
|
| 7 |
+
# If old path exists, do nothing
|
| 8 |
+
import torchvision.transforms.functional_tensor as _ft # noqa: F401
|
| 9 |
+
except Exception:
|
| 10 |
+
# Map to the new API location
|
| 11 |
+
from torchvision.transforms import functional as _F
|
| 12 |
+
_mod = types.ModuleType("torchvision.transforms.functional_tensor")
|
| 13 |
+
_mod.rgb_to_grayscale = _F.rgb_to_grayscale
|
| 14 |
+
sys.modules["torchvision.transforms.functional_tensor"] = _mod
|
| 15 |
+
|
| 16 |
+
# ────────────────────────────────────────────────────────
|
| 17 |
+
# Standard imports
|
| 18 |
+
# ────────────────────────────────────────────────────────
|
| 19 |
import gradio as gr
|
| 20 |
import cv2
|
| 21 |
import numpy
|
| 22 |
import os
|
| 23 |
import random
|
| 24 |
+
|
| 25 |
from basicsr.archs.rrdbnet_arch import RRDBNet
|
| 26 |
from basicsr.utils.download_util import load_file_from_url
|
| 27 |
|
|
|
|
| 39 |
# Utilities
|
| 40 |
# ────────────────────────────────────────────────────────
|
| 41 |
def rnd_string(x: int) -> str:
    """Return a string of ``x`` random characters.

    Each character is drawn independently (with replacement) from the
    lowercase ASCII letters, the underscore and the digits 0-9.

    Args:
        x: Number of characters to generate. ``range(x)`` is empty for
            ``x <= 0``, so the result is then the empty string.

    Returns:
        The generated string, of length ``max(x, 0)``.
    """
    characters = "abcdefghijklmnopqrstuvwxyz_0123456789"
    return "".join(random.choice(characters) for _ in range(x))
|
|
|
|
| 44 |
|
| 45 |
|
| 46 |
def reset():
|
|
|
|
| 47 |
global last_file
|
| 48 |
if last_file:
|
| 49 |
try:
|
|
|
|
| 56 |
|
| 57 |
|
| 58 |
def has_transparency(img):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
if img.info.get("transparency", None) is not None:
|
| 60 |
return True
|
| 61 |
if img.mode == "P":
|
|
|
|
| 71 |
|
| 72 |
|
| 73 |
def image_properties(img):
    """Describe the resolution and color mode of the input image.

    As a side effect, the module-level ``img_mode`` is set to ``"RGBA"``
    when ``has_transparency(img)`` is true and to ``"RGB"`` otherwise.

    Args:
        img: The input image, or a falsy value (e.g. ``None``) when no
            image is loaded. Presumably a PIL image, since ``.size`` is
            indexed as (width, height) -- confirm against the caller.

    Returns:
        A one-line description of the image's width, height and color
        mode, or ``None`` when ``img`` is falsy (``img_mode`` is then left
        untouched).
    """
    global img_mode
    # Guard clause: nothing to describe when no image is loaded.
    if not img:
        return None
    img_mode = "RGBA" if has_transparency(img) else "RGB"
    return f"Resolution: Width: {img.size[0]}, Height: {img.size[1]} | Color Mode: {img_mode}"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
|
| 80 |
def model_tip_text(model_name: str) -> str:
|
|
|
|
| 81 |
tips = {
|
| 82 |
"RealESRGAN_x4plus": (
|
| 83 |
"**RealESRGAN_x4plus (4Γ)** β Best for photoreal images (portraits, landscapes). "
|
|
|
|
| 107 |
# Core upscaling
|
| 108 |
# ────────────────────────────────────────────────────────
|
| 109 |
def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
|
|
|
|
| 110 |
if img is None:
|
| 111 |
return
|
| 112 |
|
| 113 |
# ----- Select backbone + weights -----
|
| 114 |
+
if model_name == 'RealESRGAN_x4plus':
|
| 115 |
+
model = RRDBNet(3, 3, 64, 23, 32, scale=4); netscale = 4
|
|
|
|
| 116 |
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
|
| 117 |
|
| 118 |
+
elif model_name == 'RealESRNet_x4plus':
|
| 119 |
+
model = RRDBNet(3, 3, 64, 23, 32, scale=4); netscale = 4
|
|
|
|
| 120 |
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
|
| 121 |
|
| 122 |
+
elif model_name == 'RealESRGAN_x4plus_anime_6B':
|
| 123 |
+
model = RRDBNet(3, 3, 64, 6, 32, scale=4); netscale = 4
|
|
|
|
| 124 |
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
|
| 125 |
|
| 126 |
+
elif model_name == 'RealESRGAN_x2plus':
|
| 127 |
+
model = RRDBNet(3, 3, 64, 23, 32, scale=2); netscale = 2
|
|
|
|
| 128 |
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
|
| 129 |
|
| 130 |
+
elif model_name == 'realesr-general-x4v3':
|
| 131 |
+
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu'); netscale = 4
|
|
|
|
|
|
|
| 132 |
file_url = [
|
| 133 |
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
|
| 134 |
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth'
|
| 135 |
]
|
|
|
|
| 136 |
else:
|
| 137 |
raise ValueError(f"Unknown model: {model_name}")
|
| 138 |
|
| 139 |
+
# ----- Ensure weights on disk -----
|
|
|
|
| 140 |
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 141 |
weights_dir = os.path.join(ROOT_DIR, 'weights')
|
| 142 |
os.makedirs(weights_dir, exist_ok=True)
|
| 143 |
|
|
|
|
| 144 |
local_paths = []
|
| 145 |
for url in file_url:
|
| 146 |
fname = os.path.basename(url)
|
|
|
|
| 149 |
local_path = load_file_from_url(url=url, model_dir=weights_dir, progress=True)
|
| 150 |
local_paths.append(local_path)
|
| 151 |
|
|
|
|
| 152 |
if model_name == 'realesr-general-x4v3':
|
|
|
|
| 153 |
base_path = os.path.join(weights_dir, 'realesr-general-x4v3.pth')
|
| 154 |
+
wdn_path = os.path.join(weights_dir, 'realesr-general-wdn-x4v3.pth')
|
| 155 |
model_path = [base_path, wdn_path]
|
| 156 |
denoise_strength = float(denoise_strength)
|
| 157 |
+
dni_weight = [1.0 - denoise_strength, denoise_strength] # base, WDN
|
|
|
|
| 158 |
else:
|
| 159 |
model_path = os.path.join(weights_dir, f"{model_name}.pth")
|
| 160 |
dni_weight = None
|
| 161 |
|
| 162 |
# ----- CUDA / precision / tiling -----
|
|
|
|
| 163 |
use_cuda = False
|
| 164 |
try:
|
| 165 |
use_cuda = hasattr(cv2, "cuda") and cv2.cuda.getCudaEnabledDeviceCount() > 0
|
| 166 |
except Exception:
|
| 167 |
use_cuda = False
|
|
|
|
| 168 |
gpu_id = 0 if use_cuda else None
|
| 169 |
|
| 170 |
upsampler = RealESRGANer(
|
|
|
|
| 172 |
model_path=model_path,
|
| 173 |
dni_weight=dni_weight,
|
| 174 |
model=model,
|
| 175 |
+
tile=256,
|
| 176 |
tile_pad=10,
|
| 177 |
pre_pad=10,
|
| 178 |
+
half=bool(use_cuda),
|
| 179 |
gpu_id=gpu_id
|
| 180 |
)
|
| 181 |
|
|
|
|
| 191 |
bg_upsampler=upsampler
|
| 192 |
)
|
| 193 |
|
| 194 |
+
# ----- PIL -> cv2 -----
|
| 195 |
cv_img = numpy.array(img)
|
| 196 |
if cv_img.ndim == 3 and cv_img.shape[2] == 4:
|
| 197 |
cv_img = cv2.cvtColor(cv_img, cv2.COLOR_RGBA2BGRA)
|
|
|
|
| 209 |
print('Tip: If you hit CUDA OOM, try a smaller tile size (e.g., 128).')
|
| 210 |
return None
|
| 211 |
|
| 212 |
+
# ----- cv2 -> display ndarray, also save -----
|
| 213 |
if output.ndim == 3 and output.shape[2] == 4:
|
| 214 |
display_img = cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
|
| 215 |
extension = 'png'
|
|
|
|
| 225 |
except Exception as e:
|
| 226 |
print("Save error:", e)
|
| 227 |
|
| 228 |
+
return display_img
|
| 229 |
|
| 230 |
|
| 231 |
# ────────────────────────────────────────────────────────
|
|
|
|
| 246 |
"RealESRGAN_x2plus",
|
| 247 |
"realesr-general-x4v3",
|
| 248 |
],
|
| 249 |
+
value="RealESRGAN_x4plus",
|
| 250 |
show_label=True
|
| 251 |
)
|
| 252 |
denoise_strength = gr.Slider(
|
|
|
|
| 259 |
)
|
| 260 |
face_enhance = gr.Checkbox(label="Face Enhancement (GFPGAN)", value=False)
|
| 261 |
|
|
|
|
| 262 |
model_tips = gr.Markdown(model_tip_text("RealESRGAN_x4plus"))
|
| 263 |
|
| 264 |
with gr.Row():
|
|
|
|
| 271 |
reset_btn = gr.Button("Remove images")
|
| 272 |
restore_btn = gr.Button("Upscale")
|
| 273 |
|
|
|
|
| 274 |
input_image.change(fn=image_properties, inputs=input_image, outputs=input_image_properties)
|
| 275 |
model_name.change(fn=model_tip_text, inputs=model_name, outputs=model_tips)
|
| 276 |
|