File size: 17,406 Bytes
9bf416a 0385957 9bf416a 01d8863 7477470 01d8863 7477470 01d8863 7477470 9bf416a 7477470 01d8863 9bf416a 01d8863 7477470 9bf416a 7477470 39c5020 01d8863 9bf416a 01d8863 7477470 fb8d4c6 163b038 7477470 9bf416a 01d8863 7477470 0385957 7477470 9bf416a 7477470 0385957 7477470 9bf416a 01d8863 9bf416a 01d8863 7477470 01d8863 7477470 9bf416a 01d8863 9bf416a 01d8863 7477470 9bf416a 2ee5af8 7477470 01d8863 9bf416a 01d8863 9bf416a 2ee5af8 7477470 01d8863 9bf416a 01d8863 9bf416a 2ee5af8 7477470 01d8863 9bf416a 01d8863 9bf416a 01d8863 9bf416a 0385957 9bf416a 01d8863 9bf416a 01d8863 9bf416a 7477470 01d8863 9bf416a 01d8863 9bf416a 01d8863 9bf416a 01d8863 9bf416a 01d8863 9bf416a 01d8863 9bf416a 01d8863 9bf416a 01d8863 9bf416a 01d8863 9bf416a 01d8863 9bf416a 01d8863 9bf416a 01d8863 9bf416a 01d8863 9bf416a 01d8863 9bf416a 01d8863 7477470 9bf416a 01d8863 7477470 01d8863 901db54 01d8863 7477470 18be508 01d8863 9bf416a 01d8863 9bf416a 01d8863 9bf416a 01d8863 9bf416a 01d8863 9bf416a 01d8863 9bf416a 01d8863 9bf416a 01d8863 9bf416a 01d8863 9bf416a 7477470 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 |
# app.py — Mjölnir · Upscale Images (ZeroGPU-ready, batch click, with logo)
# ---- TorchVision shim (so basicsr can import even without torchvision) ----
# If `torchvision.transforms.functional_tensor` cannot be imported, a minimal
# stand-in module exposing only `rgb_to_grayscale` is registered in
# `sys.modules` under that exact name, so later imports of it resolve to the stub.
# NOTE(review): presumably basicsr imports `rgb_to_grayscale` from this module,
# which is why this runs before the basicsr imports further down — confirm.
import sys, types
try:
    import torchvision.transforms.functional_tensor as _ft  # noqa: F401
except Exception:
    # Any import failure (module missing, broken install, ...) triggers the fallback.
    import torch
    _mod = types.ModuleType("torchvision.transforms.functional_tensor")
    def rgb_to_grayscale(img: "torch.Tensor", num_output_channels: int = 1) -> "torch.Tensor":
        """Convert an RGB tensor to grayscale with a weighted channel sum.

        `img` must be a tensor whose third-from-last dimension has size 3.
        Returns the gray plane repeated three times along that dimension when
        `num_output_channels == 3`, otherwise a single-channel tensor.

        Raises TypeError for non-tensor input and ValueError for a bad shape.
        """
        if not torch.is_tensor(img):
            raise TypeError("rgb_to_grayscale expects a torch.Tensor")
        if img.ndim < 3 or img.shape[-3] != 3:
            raise ValueError(f"expected tensor with C=3 as the third-from-last dim, got shape {tuple(img.shape)}")
        # Negative indices pick channels 0, 1, 2 of the size-3 channel dimension.
        r, g, b = img[..., -3, :, :], img[..., -2, :, :], img[..., -1, :, :]
        gray = 0.2989*r + 0.5870*g + 0.1140*b
        return torch.stack([gray, gray, gray], dim=-3) if num_output_channels == 3 else gray.unsqueeze(-3)
    _mod.rgb_to_grayscale = rgb_to_grayscale
    # Registering the stub makes `import torchvision.transforms.functional_tensor` hit the cache.
    sys.modules["torchvision.transforms.functional_tensor"] = _mod
# --------------------------------------------------------------------------
import os, time, zipfile, tempfile, shutil, base64
from pathlib import Path
from typing import List, Optional, Tuple
import gradio as gr
import numpy as np
import cv2
from PIL import Image
# ZeroGPU hook
import spaces
# Real-ESRGAN / basicsr (importing these at top is OK; just don't touch CUDA here)
from basicsr.archs.rrdbnet_arch import RRDBNet as _RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
# ────────────────────────────────────────────────────────
# Branding (logo)
# ────────────────────────────────────────────────────────
def try_load_logo_b64(logo_path: str = "bifrost_logo.png") -> str:
    """Return the logo file's contents as a base64 string, or "" if unreadable.

    Args:
        logo_path: Image file to read. Defaults to "bifrost_logo.png", resolved
            against the current working directory like any relative path.

    Returns:
        The base64-encoded file bytes as text, or an empty string when the
        file cannot be opened or read. Loading is best-effort: a missing logo
        must never prevent the app from starting.
    """
    try:
        with open(logo_path, "rb") as fh:
            raw = fh.read()
    except (OSError, ValueError):
        # OSError: missing/unreadable file; ValueError: malformed path (e.g. NUL byte).
        return ""
    return base64.b64encode(raw).decode("ascii")
# Base64 text of the logo, read once at import time; "" when the file could not be loaded.
LOGO_B64 = try_load_logo_b64()
def render_logo_html(px: int = 96) -> str:
    """Build the HTML header: optional logo image, app title, tagline and a rule.

    Args:
        px: Rendered logo height in CSS pixels (width scales automatically).

    Returns:
        An HTML fragment. The <img> tag is omitted entirely when LOGO_B64 is
        empty, so the header still renders without a logo file.
    """
    # The title and tagline previously contained mis-decoded UTF-8
    # ("MjΓΆlnir Β·", "β"); the intended characters are restored here.
    img = f'<img src="data:image/png;base64,{LOGO_B64}" style="height:{px}px;width:auto;" />' if LOGO_B64 else ""
    return f"""
<div style="display:flex;align-items:center;gap:16px;">
{img}
<div>
<div style="font-size:1.6rem;font-weight:800;">Mjölnir · Image Upscaler</div>
<div style="opacity:0.8;">The Hammer of Clarity — upscale images into sharper, powerful detail.</div>
</div>
</div>
<hr>
"""
# ────────────────────────────────────────────────────────
# Small helpers
# ────────────────────────────────────────────────────────
import re
_num = re.compile(r'(\d+)')
def _natural_key(p: Path | str):
s = str(p)
return [int(t) if t.isdigit() else t.lower() for t in _num.split(s)]
def sample_paths(paths: List[Path] | List[str], n: int = 30) -> List[str]:
    """Pick up to `n` evenly spaced paths from `paths`, in natural sort order.

    Args:
        paths: Paths to sample from; may be empty or None.
        n: Maximum number of samples. Values below 1 are treated as 1.

    Returns:
        The selected paths as strings. All paths are returned when there are
        no more than `n`; otherwise the first and last are always included
        (for n >= 2) and duplicates from index rounding are dropped.
    """
    if not paths:
        return []
    ordered = sorted(paths, key=_natural_key)
    total = len(ordered)
    n = max(1, min(n, total))
    if n == total:
        return [str(p) for p in ordered]
    if n == 1:
        # One sample cannot span both ends; the spacing formula below would
        # divide by zero (n - 1 == 0), so just return the first path.
        return [str(ordered[0])]
    step = (total - 1) / (n - 1)
    # dict.fromkeys keeps first-seen order while dropping repeated indices.
    picked = dict.fromkeys(round(i * step) for i in range(n))
    return [str(ordered[i]) for i in picked]
def render_progress(pct: float, label: str = "") -> str:
    """Render an HTML progress bar with a caption underneath.

    `pct` is limited to the 0-100 range and shown with one decimal place,
    both as the bar's fill width and in the caption after `label`.
    """
    clamped = max(0.0, min(100.0, pct))
    shown = f"{clamped:.1f}"
    track = '<div style="width:100%;border:1px solid #ddd;border-radius:8px;overflow:hidden;height:18px;">'
    fill = f'<div style="height:100%;width:{shown}%;background:#3b82f6;"></div></div>'
    caption = f'<div style="font-size:12px;opacity:.8;margin-top:4px;">{label} {shown}%</div>'
    return "\n".join((track, fill, caption))
def _ensure_dir(p: Path) -> Path:
p.mkdir(parents=True, exist_ok=True); return p
def _save_zip_of_dir(dir_path: Path, zip_path: Path) -> str:
    """Write every JPG/JPEG/PNG directly inside `dir_path` into a new zip.

    The archive at `zip_path` is (re)created in write mode with DEFLATE
    compression; entries are stored under their bare file names, in natural
    sort order. Returns `zip_path` as a string.
    """
    image_suffixes = [".jpg", ".jpeg", ".png"]
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as archive:
        for entry in sorted(dir_path.glob("*.*"), key=_natural_key):
            if entry.suffix.lower() not in image_suffixes:
                continue
            archive.write(entry, entry.name)
    return str(zip_path)
def _list_image_paths_from_upload(files: List[gr.File] | None) -> List[str]:
if not files: return []
return [str(Path(f.name)) for f in files if Path(f.name).suffix.lower() in [".jpg",".jpeg",".png"]]
def _build_gallery_from_dir(dir_path: Path, n: int = 30) -> List[str]:
    """Return up to `n` evenly sampled *.jpg / *.png paths found in `dir_path`."""
    found = [*dir_path.glob("*.jpg"), *dir_path.glob("*.png")]
    found.sort(key=_natural_key)
    return sample_paths(found, n)
# ────────────────────────────────────────────────────────
# Models & weight management (CPU-safe; no CUDA used here)
# ────────────────────────────────────────────────────────
def build_rrdb(scale: int, num_block: int):
    """Construct an RRDBNet with the fixed 3-in/3-out, 64-feature, 32-growth layout.

    Only the upscale factor and the number of blocks vary between callers.
    """
    fixed_arch = dict(num_in_ch=3, num_out_ch=3, num_feat=64, num_grow_ch=32)
    return _RRDBNet(num_block=num_block, scale=scale, **fixed_arch)
def _weights_dir() -> str:
wdir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "weights")
os.makedirs(wdir, exist_ok=True)
return wdir
def ensure_weights(model_id: str) -> Tuple[object, int, str, Optional[list]]:
    """
    Build the network for `model_id` and make sure its weight file is on disk.

    Returns: (model, netscale, model_path, None)
      - model:      freshly constructed network (no weights loaded here)
      - netscale:   the model's native upscale factor
      - model_path: local path of the weight file inside the weights folder
      - the 4th slot is always None here (reserved for DNI weights)

    Raises ValueError for an unrecognised `model_id`.
    """
    # (internal id, network factory, native scale, weight download URL).
    # Factories are lambdas so that only the requested network is instantiated.
    known_models = (
        ("x4plus", lambda: build_rrdb(scale=4, num_block=23), 4,
         "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"),
        ("x4plus-anime", lambda: build_rrdb(scale=4, num_block=6), 4,
         "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"),
        ("x2plus", lambda: build_rrdb(scale=2, num_block=23), 2,
         "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"),
        # general-x4v3 uses the compact SRVGG architecture rather than RRDB.
        ("general-x4v3",
         lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu'),
         4,
         "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"),
    )
    weights_root = _weights_dir()
    for known_id, make_net, native_scale, url in known_models:
        if model_id != known_id:
            continue
        net = make_net()
        # The local file name is the last path component of the download URL.
        local_file = os.path.join(weights_root, url.rsplit("/", 1)[-1])
        if not os.path.isfile(local_file):
            load_file_from_url(url=url, model_dir=weights_root, progress=True)
        return net, native_scale, local_file, None
    raise ValueError(f"Unknown model_id: {model_id}")
def map_ui_model_to_internal(ui_name: str) -> str:
    """Translate a model name shown in the UI into the internal model id.

    Unknown names fall back to "x4plus"; RealESRNet_x4plus is deliberately
    served by the same "x4plus" backend.
    """
    aliases = {
        "RealESRGAN_x4plus": "x4plus",
        "RealESRGAN_x4plus_anime_6B": "x4plus-anime",
        "RealESRGAN_x2plus": "x2plus",
        "RealESRNet_x4plus": "x4plus",  # fallback
        "realesr-general-x4v3": "general-x4v3",  # SRVGG
    }
    try:
        return aliases[ui_name]
    except KeyError:
        return "x4plus"
def clamp_scale_for_model(outscale: int, model_id: str) -> int:
    """Clamp the requested output scale to what the chosen model supports.

    Previously `outscale` was ignored and the model's native factor was always
    returned, which made the UI's scale control a no-op.

    Args:
        outscale: Requested upscale factor. None or a non-numeric value means
            "use the model's native scale".
        model_id: Internal model id; "x2plus" is a 2x model, all others 4x.

    Returns:
        An int in [1, native scale].
    """
    native = 2 if model_id == "x2plus" else 4
    try:
        requested = int(outscale)
    except (TypeError, ValueError):
        return native
    return max(1, min(requested, native))
# ────────────────────────────────────────────────────────
# Step 2 · Prepare sources (CPU)
# ────────────────────────────────────────────────────────
def step2_prepare_sources(frames_list, uploaded_imgs, max_images):
    """Collect the source images and set up a fresh output folder.

    Uploaded images take priority; `frames_list` is only used when no valid
    upload is present. `max_images` (> 0) truncates the list; anything that
    cannot be read as an int counts as 0, i.e. no limit.

    Returns a 6-tuple matching the UI outputs:
    (source paths, output dir, done index, total, status text, progress HTML).
    """
    sources = _list_image_paths_from_upload(uploaded_imgs)
    if not sources:
        sources = frames_list or []
    if not sources:
        # Nothing to do: return empty state without creating a temp folder.
        return [], "", 0, 0, "No images found. Upload files first.", render_progress(0.0, "Idle")

    try:
        limit = int(max_images or 0)
    except Exception:
        limit = 0
    if limit > 0:
        sources = sources[:limit]

    # Each load/reset gets its own temp working dir with an "upscaled" subfolder.
    workdir = Path(tempfile.mkdtemp(prefix="up_manual_"))
    out_dir = _ensure_dir(workdir / "upscaled")
    count = len(sources)
    status = f"Sources loaded: {count} image(s). Click 'Process Next Batch'."
    return sources, str(out_dir), 0, count, status, render_progress(0.0, "Ready")
# ────────────────────────────────────────────────────────
# Step 2 · Process next batch (GPU) — ZeroGPU entry point
# ────────────────────────────────────────────────────────
@spaces.GPU
def step2_process_next_batch(
    up_src_paths, up_out_dir, up_done_idx, up_total,
    ui_model_name, outscale, tile, precision, denoise_strength, face_enhance, batch_size,
):
    """
    Upscale the next `batch_size` source images (generator; runs on ZeroGPU).

    Heavy work (model construction + inference) happens inside this function.
    After every image, and once more when the batch is finished, it yields a
    6-tuple matching the UI outputs:
        (gallery paths, zip path, status text, progress HTML, done index, out dir)

    Args:
        up_src_paths: Source image paths prepared by the "load sources" step.
        up_out_dir: Folder that receives the upscaled JPEGs.
        up_done_idx: Index of the first image not yet processed.
        up_total: Total number of source images.
        ui_model_name: Model name as shown in the UI.
        outscale: Requested output scale (clamped per model).
        tile: Tile size passed to the upsampler; None falls back to 256.
        precision: "half"/"auto" request fp16 (only honoured when CUDA is
            available); anything else uses fp32.
        denoise_strength: 0..1, only meaningful for the general-x4v3 model.
        face_enhance: When true, try to route images through GFPGAN.
        batch_size: Number of images to process in this call (min 1).
    """
    import torch  # local import: only needed to probe CUDA inside the GPU context

    # Validate inputs
    if not up_src_paths or not up_out_dir:
        yield None, None, "Load sources first.", render_progress(0.0, "Idle"), up_done_idx, up_out_dir
        return

    out_dir = Path(up_out_dir)
    zip_path = out_dir.parent / "upscaled.zip"
    # Normalise state once; the raw values may be None.
    total = int(up_total or 0)
    start = int(up_done_idx or 0)
    batch_size = max(1, int(batch_size or 8))
    end = min(start + batch_size, total)

    # Nothing left to do: answer before paying for weight download / model load.
    if start >= total:
        gallery = _build_gallery_from_dir(out_dir, 30)
        zip_file = _save_zip_of_dir(out_dir, zip_path)
        yield gallery, zip_file, "All images processed.", render_progress(100.0, "Done"), start, up_out_dir
        return

    # Resolve model & options. `is None` checks keep a legitimate 0 from
    # being replaced by the default, which `value or default` would do.
    model_id = map_ui_model_to_internal(ui_model_name)
    scale = clamp_scale_for_model(int(outscale or 4), model_id)
    tile = 256 if tile is None else max(0, int(tile))
    denoise = 0.5 if denoise_strength is None else min(1.0, max(0.0, float(denoise_strength)))
    # fp16 only when CUDA is actually available.
    use_half = precision in ("half", "auto") and torch.cuda.is_available()

    model, netscale, model_path, dni_weight = ensure_weights(model_id)

    # RealESRGANer.enhance() takes no denoise argument. Denoise strength is
    # applied by interpolating the general-x4v3 weights with their "wdn"
    # counterpart (DNI) at construction time, as the upstream inference
    # script does.
    if model_id == "general-x4v3" and denoise != 1.0 and isinstance(model_path, str):
        wdn_path = os.path.join(os.path.dirname(model_path), "realesr-general-wdn-x4v3.pth")
        if not os.path.isfile(wdn_path):
            load_file_from_url(
                url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
                model_dir=os.path.dirname(wdn_path), progress=True,
            )
        model_path = [model_path, wdn_path]
        dni_weight = [denoise, 1.0 - denoise]

    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=tile,
        tile_pad=10,
        pre_pad=10,
        half=use_half,
        gpu_id=0,  # request the single available GPU
    )

    # Optional face enhancer (kept off by default as it adds a weight download).
    face_enhancer = None
    if face_enhance:
        try:
            from gfpgan import GFPGANer
            face_enhancer = GFPGANer(
                model_path="https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
                upscale=scale, arch="clean", channel_multiplier=2, bg_upsampler=upsampler
            )
        except Exception as e:
            # Best-effort: fall back to plain upscaling if GFPGAN is unavailable.
            print("GFPGAN load failed:", e)
            face_enhancer = None

    batch_paths = up_src_paths[start:end]
    total_in_batch = len(batch_paths)
    failures = 0
    t0 = time.time()
    for idx, fp in enumerate(batch_paths, start=1):
        try:
            with Image.open(fp) as im:
                rgb = np.array(im.convert("RGB"))
            # Real-ESRGAN and GFPGAN follow the OpenCV convention (BGR in,
            # BGR out); convert on the way in and back before saving with PIL.
            bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
            if face_enhancer:
                _, _, output = face_enhancer.enhance(
                    bgr, has_aligned=False, only_center_face=False, paste_back=True
                )
            else:
                output, _ = upsampler.enhance(bgr, outscale=scale)
            result = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
            # NOTE: outputs are named by input stem, so inputs sharing a stem overwrite each other.
            Image.fromarray(result).save(out_dir / (Path(fp).stem + ".jpg"), quality=95)
        except Exception as e:
            # Keep going on per-image failures, but count them so the final
            # status does not silently claim success.
            failures += 1
            print("Upscale error:", fp, e)

        # Progress for THIS batch
        elapsed = time.time() - t0
        pct_batch = (idx / total_in_batch) * 100.0
        eta = (total_in_batch - idx) * (elapsed / max(1, idx))
        label = (f"Batch: {idx}/{total_in_batch} · ~{eta:.1f}s ETA · "
                 f"global {start+idx}/{total} (x{scale}, model={ui_model_name}, tile={tile}, half={use_half})")
        gallery = _build_gallery_from_dir(out_dir, 30)
        zip_file = _save_zip_of_dir(out_dir, zip_path)
        yield gallery, zip_file, label, render_progress(pct_batch, f"Upscaling {pct_batch:.0f}% (batch)"), start+idx, up_out_dir

    # Batch complete
    next_idx = end
    pct_global = (next_idx / total) * 100.0 if total else 100.0
    summary = f"Processed batch of {total_in_batch}. {next_idx}/{total} done."
    if failures:
        summary += f" {failures} image(s) failed; see logs."
    gallery = _build_gallery_from_dir(out_dir, 30)
    zip_file = _save_zip_of_dir(out_dir, zip_path)
    yield gallery, zip_file, summary, render_progress(pct_global, "Upscaling… (global)"), next_idx, up_out_dir
# ────────────────────────────────────────────────────────
# UI
# ────────────────────────────────────────────────────────
def build_ui():
    """Assemble the Gradio Blocks interface and wire both buttons.

    Layout: header, hidden state, upload widget, options accordion, the two
    action buttons, then progress / gallery / zip download / status text.
    "Load / Reset Sources" runs step2_prepare_sources; "Process Next Batch"
    runs step2_process_next_batch and advances the done-index state.

    Returns:
        The constructed gr.Blocks app (not yet queued or launched).
    """
    # NOTE(review): the accordion/row nesting below is reconstructed from
    # context — confirm it matches the intended layout.
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.HTML(render_logo_html(88))
        gr.Markdown("Upload images and upscale with Real-ESRGAN.")

        # States carried between clicks
        frames_state = gr.State([])  # present for parity; not used here
        up_src_paths_state = gr.State([])
        up_out_dir_state = gr.State("")
        up_done_idx_state = gr.State(0)
        up_total_state = gr.State(0)

        imgs_override = gr.Files(label="Upload images (JPG/PNG)", file_types=[".jpg",".jpeg",".png"], type="filepath")

        with gr.Accordion("Upscaling options", open=True):
            with gr.Row():
                ui_model_name = gr.Dropdown(
                    label="Upscaler model",
                    choices=["RealESRGAN_x4plus", "RealESRNet_x4plus", "RealESRGAN_x4plus_anime_6B", "RealESRGAN_x2plus", "realesr-general-x4v3"],
                    value="RealESRGAN_x4plus"
                )
                denoise_strength = gr.Slider(0, 1, value=0.5, step=0.1, label="Denoise (only general-x4v3)")
                outscale = gr.Slider(1, 6, value=4, step=1, label="Resolution upscale")
                face_enhance = gr.Checkbox(value=False, label="Face Enhancement (GFPGAN)")
            with gr.Row():
                tile = gr.Number(value=256, label="Tile size (try 128 if OOM; 0=auto)")
                precision = gr.Dropdown(["auto", "half", "full"], value="half", label="Precision (GPU: half, CPU ignored)")
            with gr.Row():
                batch_size = gr.Number(value=12, precision=0, label="Batch size per click")
                max_images = gr.Number(value=0, precision=0, label="Max images to process (0 = all)")

        with gr.Row():
            btn_prepare = gr.Button("Load / Reset Sources", variant="secondary")
            btn_next = gr.Button("Process Next Batch (uses GPU)", variant="primary")

        prog = gr.HTML(render_progress(0.0, "Idle"))
        gallery_up = gr.Gallery(label="Upscaled preview (sampled 30)", columns=6, height=480)
        zip_up = gr.File(label="Download upscaled ZIP")
        details = gr.Markdown("")

        # 1) load/reset sources (CPU)
        btn_prepare.click(
            step2_prepare_sources,
            inputs=[frames_state, imgs_override, max_images],
            outputs=[up_src_paths_state, up_out_dir_state, up_done_idx_state, up_total_state, details, prog]
        )

        # 2) process one batch per click (GPU); the handler is a generator, so
        #    the outputs update after every image.
        btn_next.click(
            step2_process_next_batch,
            inputs=[up_src_paths_state, up_out_dir_state, up_done_idx_state, up_total_state,
                    ui_model_name, outscale, tile, precision, denoise_strength, face_enhance, batch_size],
            outputs=[gallery_up, zip_up, details, prog, up_done_idx_state, up_out_dir_state]
        )

        # The tip text previously contained mis-decoded UTF-8 (info emoji and
        # en dash); the intended characters are restored.
        gr.Markdown(
            "> ℹ️ **ZeroGPU tips**: Larger tiles are faster but use more VRAM. If you hit OOM, try `tile=128`, "
            "`batch size=4–8`, and keep `Precision=half`."
        )
    return demo
# Script entry point: build the UI, then call .queue() and .launch() on it.
# Nothing is launched when this module is merely imported.
if __name__ == "__main__":
    build_ui().queue().launch()
|