# app.py — Mjölnir · Upscale Images (ZeroGPU-ready, batch click, with logo)
# ---- TorchVision shim (lets basicsr import when torchvision.transforms.functional_tensor cannot be imported) ----
import sys, types
try:
    import torchvision.transforms.functional_tensor as _ft  # noqa: F401
except Exception:
    # The module could not be imported: register a minimal stand-in under the
    # same dotted name that exposes only rgb_to_grayscale.
    import torch

    def rgb_to_grayscale(img: "torch.Tensor", num_output_channels: int = 1) -> "torch.Tensor":
        """Convert an RGB tensor to grayscale with fixed per-channel weights.

        Args:
            img: Tensor whose third-from-last dimension holds exactly 3 channels.
            num_output_channels: 3 repeats the gray plane three times along the
                channel dimension; any other value yields a single channel.

        Returns:
            A tensor whose channel dimension has size 1 or 3.

        Raises:
            TypeError: If ``img`` is not a tensor.
            ValueError: If ``img`` has fewer than 3 dims or its channel dim is not 3.
        """
        if not torch.is_tensor(img):
            raise TypeError("rgb_to_grayscale expects a torch.Tensor")
        has_rgb_channels = img.ndim >= 3 and img.shape[-3] == 3
        if not has_rgb_channels:
            raise ValueError(f"expected tensor with C=3 as the third-from-last dim, got shape {tuple(img.shape)}")
        red, green, blue = img.unbind(dim=-3)
        luma = 0.2989*red + 0.5870*green + 0.1140*blue
        if num_output_channels == 3:
            return torch.stack([luma, luma, luma], dim=-3)
        return luma.unsqueeze(-3)

    _mod = types.ModuleType("torchvision.transforms.functional_tensor")
    setattr(_mod, "rgb_to_grayscale", rgb_to_grayscale)
    sys.modules["torchvision.transforms.functional_tensor"] = _mod
# --------------------------------------------------------------------------
import os, time, zipfile, tempfile, shutil, base64
from pathlib import Path
from typing import List, Optional, Tuple
import gradio as gr
import numpy as np
import cv2
from PIL import Image
# ZeroGPU hook
import spaces
# Real-ESRGAN / basicsr (importing these at top is OK; just don't touch CUDA here)
from basicsr.archs.rrdbnet_arch import RRDBNet as _RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
# ────────────────────────────────────────────────────────
# Branding (logo)
# ────────────────────────────────────────────────────────
def try_load_logo_b64() -> str:
    """Return the contents of ``bifrost_logo.png`` as base64 text.

    The file is looked up relative to the current working directory. Any
    failure (missing file, unreadable file, ...) yields an empty string so the
    UI can simply omit the logo.
    """
    try:
        raw_bytes = Path("bifrost_logo.png").read_bytes()
        encoded = base64.b64encode(raw_bytes)
        return encoded.decode("utf-8")
    except Exception:
        return ""


# Loaded once at import time; empty when no logo file could be read.
LOGO_B64 = try_load_logo_b64()
def render_logo_html(px: int = 96) -> str:
    """Build the HTML header shown at the top of the app.

    Args:
        px: Width and height, in pixels, of the logo image.

    Returns:
        An HTML fragment with the title and tagline, preceded by the logo as an
        inline base64 ``<img>`` when ``LOGO_B64`` is non-empty.
    """
    # NOTE(review): the markup in this function was lost (the logo f-string was
    # left unterminated, which is a syntax error, and ``px`` was unused). The
    # tags and inline styles below are a reconstruction; the visible title and
    # tagline text are unchanged. Confirm the styling against the intended design.
    size = int(px)
    logo_tag = (
        f'<img src="data:image/png;base64,{LOGO_B64}" alt="Mjölnir logo" '
        f'width="{size}" height="{size}" style="object-fit:contain;" />'
        if LOGO_B64
        else ""
    )
    return f"""
<div style="display:flex;align-items:center;gap:16px;">
{logo_tag}
<div>
<h1 style="margin:0;">Mjölnir · Image Upscaler</h1>
<p style="margin:0;">The Hammer of Clarity — upscale images into sharper, powerful detail.</p>
</div>
</div>
"""
# ────────────────────────────────────────────────────────
# Small helpers
# ────────────────────────────────────────────────────────
import re
_num = re.compile(r'(\d+)')
def _natural_key(p: Path | str):
s = str(p)
return [int(t) if t.isdigit() else t.lower() for t in _num.split(s)]
def sample_paths(paths: List[Path] | List[str], n: int = 30) -> List[str]:
    """Pick up to ``n`` evenly spaced entries from ``paths`` in natural order.

    Args:
        paths: Paths to sample from; sorted with ``_natural_key`` first.
        n: Desired number of samples. Values below 1 are treated as 1 and
            values above ``len(paths)`` as ``len(paths)``.

    Returns:
        The selected paths as strings, always including the first entry and,
        when more than one sample is taken, the last entry.
    """
    if not paths:
        return []
    ordered = sorted(paths, key=_natural_key)
    total = len(ordered)
    n = max(1, min(n, total))
    if n == total:
        return [str(p) for p in ordered]
    if n == 1:
        # A single sample has no spacing to compute; without this guard the
        # step formula below divides by zero.
        return [str(ordered[0])]
    step = (total - 1) / (n - 1)
    # dict.fromkeys drops indices that rounding made equal while keeping order.
    picked = dict.fromkeys(round(i * step) for i in range(n))
    return [str(ordered[i]) for i in picked]
def render_progress(pct: float, label: str = "") -> str:
    """Format a progress line: the label followed by the clamped percentage.

    Args:
        pct: Progress percentage; limited to the range 0..100 before formatting.
        label: Text placed before the percentage.

    Returns:
        The label and the percentage (one decimal place) on a line of its own.
    """
    capped = min(100.0, pct)
    shown = max(0.0, capped)
    return f"\n{label} {shown:.1f}%\n"
def _ensure_dir(p: Path) -> Path:
    """Create directory ``p`` (with any missing parents) and return ``p``."""
    os.makedirs(p, exist_ok=True)
    return p
def _save_zip_of_dir(dir_path: Path, zip_path: Path) -> str:
    """Write the JPG/PNG files directly inside ``dir_path`` into a ZIP archive.

    Args:
        dir_path: Directory whose top-level image files are archived.
        zip_path: Destination archive; overwritten if it already exists.

    Returns:
        ``zip_path`` as a string.
    """
    image_suffixes = [".jpg", ".jpeg", ".png"]
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as archive:
        candidates = sorted(dir_path.glob("*.*"), key=_natural_key)
        for entry in candidates:
            if entry.suffix.lower() not in image_suffixes:
                continue
            # Store under the bare file name so the archive has no folders.
            archive.write(entry, entry.name)
    return str(zip_path)
def _list_image_paths_from_upload(files: "List[gr.File | str] | None") -> List[str]:
    """Extract the JPG/PNG file paths from an upload component's value.

    Args:
        files: The uploaded items, or ``None``/empty when nothing was uploaded.
            Each item may be a path (``str`` or ``os.PathLike``) or an object
            that exposes its path through a ``.name`` attribute.

    Returns:
        Paths (as strings) of the items whose suffix is .jpg, .jpeg or .png,
        compared case-insensitively, in upload order.
    """
    if not files:
        return []
    image_suffixes = {".jpg", ".jpeg", ".png"}
    selected: List[str] = []
    for item in files:
        # Path-like values are used as-is: a plain str has no ``.name`` and a
        # pathlib.Path's ``.name`` is only the basename. Anything else is
        # treated as a file wrapper carrying its path in ``.name``.
        # NOTE(review): which of these shapes arrives depends on the Gradio
        # version and the component's ``type`` setting — confirm.
        raw = item if isinstance(item, (str, os.PathLike)) else item.name
        path = Path(raw)
        if path.suffix.lower() in image_suffixes:
            selected.append(str(path))
    return selected
def _build_gallery_from_dir(dir_path: Path, n: int = 30) -> List[str]:
    """Return up to ``n`` evenly sampled .jpg/.png paths found in ``dir_path``."""
    found = [*dir_path.glob("*.jpg"), *dir_path.glob("*.png")]
    found.sort(key=_natural_key)
    return sample_paths(found, n)
# ────────────────────────────────────────────────────────
# Models & weight management (CPU-safe; no CUDA used here)
# ────────────────────────────────────────────────────────
def build_rrdb(scale: int, num_block: int):
    """Construct an RRDBNet with 3 input/output channels.

    Args:
        scale: Value passed as the network's ``scale`` argument.
        num_block: Value passed as the network's ``num_block`` argument.

    Returns:
        The newly constructed network instance (no weights are loaded here).
    """
    arch_kwargs = dict(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=num_block,
        num_grow_ch=32,
        scale=scale,
    )
    return _RRDBNet(**arch_kwargs)
def _weights_dir() -> str:
    """Return the ``weights`` directory next to this file, creating it if needed."""
    here = os.path.dirname(os.path.abspath(__file__))
    target = os.path.join(here, "weights")
    os.makedirs(target, exist_ok=True)
    return target
def ensure_weights(model_id: str) -> Tuple[object, int, str, Optional[list]]:
    """Build the network for ``model_id`` and make sure its weights are on disk.

    Args:
        model_id: One of "x4plus", "x4plus-anime", "x2plus" or "general-x4v3".

    Returns:
        ``(model, netscale, model_path, None)``: the freshly constructed
        network, its native upscale factor, the local weight file path, and a
        placeholder for DNI weights (always ``None`` here).

    Raises:
        ValueError: If ``model_id`` is not one of the known ids.
    """
    # (model id, network factory, native scale, download URL, local file name)
    # The first three use the RRDB architecture; general-x4v3 is an SRVGG model.
    specs = (
        (
            "x4plus",
            lambda: build_rrdb(scale=4, num_block=23),
            4,
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
            "RealESRGAN_x4plus.pth",
        ),
        (
            "x4plus-anime",
            lambda: build_rrdb(scale=4, num_block=6),
            4,
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
            "RealESRGAN_x4plus_anime_6B.pth",
        ),
        (
            "x2plus",
            lambda: build_rrdb(scale=2, num_block=23),
            2,
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
            "RealESRGAN_x2plus.pth",
        ),
        (
            "general-x4v3",
            lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu'),
            4,
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
            "realesr-general-x4v3.pth",
        ),
    )
    wdir = _weights_dir()
    for known_id, make_model, netscale, url, file_name in specs:
        if model_id == known_id:
            model = make_model()
            model_path = os.path.join(wdir, file_name)
            # Download only when the weight file is not already present.
            if not os.path.isfile(model_path):
                load_file_from_url(url=url, model_dir=wdir, progress=True)
            return model, netscale, model_path, None
    raise ValueError(f"Unknown model_id: {model_id}")
def map_ui_model_to_internal(ui_name: str) -> str:
    """Translate a model name shown in the UI into an internal model id.

    Unknown names fall back to "x4plus".
    """
    aliases = {
        "RealESRGAN_x4plus": "x4plus",
        "RealESRGAN_x4plus_anime_6B": "x4plus-anime",
        "RealESRGAN_x2plus": "x2plus",
        # No dedicated backend for this entry; it shares the x4plus one.
        "RealESRNet_x4plus": "x4plus",
        "realesr-general-x4v3": "general-x4v3",
    }
    if ui_name in aliases:
        return aliases[ui_name]
    return "x4plus"
def clamp_scale_for_model(outscale: int, model_id: str) -> int:
    """Limit the requested output scale to what the chosen model supports.

    The requested scale was previously ignored, so the UI's scale slider had no
    effect. It is now honoured, but never exceeds the model's native factor
    (2 for "x2plus", 4 for every other id) and never drops below 1.

    Args:
        outscale: Requested upscale factor.
        model_id: Internal model id.

    Returns:
        The clamped scale; the native factor when ``outscale`` is not a number.
    """
    native = 2 if model_id == "x2plus" else 4
    try:
        requested = int(outscale)
    except (TypeError, ValueError):
        return native
    return max(1, min(requested, native))
# ────────────────────────────────────────────────────────
# Step 2 · Prepare sources (CPU)
# ────────────────────────────────────────────────────────
def step2_prepare_sources(frames_list, uploaded_imgs, max_images):
    """Collect the source images and create a fresh output directory.

    Args:
        frames_list: Fallback list of image paths, used when no usable upload exists.
        uploaded_imgs: Value of the upload component.
        max_images: Optional cap on the number of images; 0, negative or
            non-numeric values mean "no cap".

    Returns:
        ``(source_paths, output_dir, done_index, total, status_text, progress_html)``.
        When no images are found the first four are ``[]``, ``""``, ``0``, ``0``.
    """
    uploaded = _list_image_paths_from_upload(uploaded_imgs)
    sources = uploaded if uploaded else (frames_list or [])
    if not sources:
        return [], "", 0, 0, "No images found. Upload files first.", render_progress(0.0, "Idle")

    try:
        limit = int(max_images or 0)
    except Exception:
        limit = 0
    if limit > 0:
        sources = sources[:limit]

    # Every call gets its own temporary workspace with an "upscaled" subfolder.
    workspace = Path(tempfile.mkdtemp(prefix="up_manual_"))
    out_dir = _ensure_dir(workspace / "upscaled")
    count = len(sources)
    status = f"Sources loaded: {count} image(s). Click 'Process Next Batch'."
    return sources, str(out_dir), 0, count, status, render_progress(0.0, "Ready")
# ────────────────────────────────────────────────────────
# Step 2 · Process next batch (GPU) — ZeroGPU entry point
# ────────────────────────────────────────────────────────
def _resolve_denoise_weights(model_id, model_path, denoise_strength):
    """Return ``(model_path, dni_weight)`` for building the upsampler.

    Only "general-x4v3" with a strength other than 1 blends two checkpoints:
    the base weights and the "wdn" weights, mixed as
    ``[strength, 1 - strength]``. Every other case returns the inputs with no
    blending. If the extra weights cannot be fetched, blending is skipped.
    """
    if model_id != "general-x4v3" or denoise_strength == 1:
        return model_path, None
    wdir = _weights_dir()
    wdn_path = os.path.join(wdir, "realesr-general-wdn-x4v3.pth")
    try:
        if not os.path.isfile(wdn_path):
            load_file_from_url(
                url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
                model_dir=wdir, progress=True,
            )
    except Exception as e:
        # Best effort: fall back to the plain model rather than abort the batch.
        print("Denoise weights unavailable, continuing without DNI:", e)
        return model_path, None
    return [model_path, wdn_path], [denoise_strength, 1 - denoise_strength]


@spaces.GPU
def step2_process_next_batch(
    up_src_paths, up_out_dir, up_done_idx, up_total,
    ui_model_name, outscale, tile, precision, denoise_strength, face_enhance, batch_size,
):
    """Upscale the next batch of source images, yielding progress after each one.

    Runs as the ZeroGPU entry point: model construction and inference both
    happen inside this function.

    Args:
        up_src_paths: All source image paths prepared earlier.
        up_out_dir: Directory that receives the upscaled JPEGs.
        up_done_idx: Number of sources already processed.
        up_total: Total number of sources.
        ui_model_name: Model name as shown in the UI.
        outscale: Requested output scale (clamped per model).
        tile: Tile size; 0 is passed through unchanged, ``None`` means 256.
        precision: "half" requests fp16, which is used only when CUDA is available.
        denoise_strength: Denoise blend in [0, 1]; applied only for general-x4v3.
        face_enhance: Whether to run GFPGAN on top of the upsampler.
        batch_size: Number of images to process in this call.

    Yields:
        ``(gallery_paths, zip_value, status_text, progress_html, done_index, out_dir)``.
        While a batch is running, ``zip_value`` is a no-op update; the archive
        is rebuilt once when the batch ends.
    """
    if not up_src_paths or not up_out_dir:
        yield None, None, "Load sources first.", render_progress(0.0, "Idle"), up_done_idx, up_out_dir
        return

    start = int(up_done_idx or 0)
    total = int(up_total or 0)  # tolerate None coming from an unset state
    out_dir = Path(up_out_dir)
    zip_path = Path(out_dir.parent) / "upscaled.zip"

    # Nothing left to do: answer before any weights are downloaded or loaded.
    if start >= total:
        gallery = _build_gallery_from_dir(out_dir, 30)
        zip_file = _save_zip_of_dir(out_dir, zip_path)
        yield gallery, zip_file, "All images processed.", render_progress(100.0, "Done"), start, up_out_dir
        return

    # Imported here so that merely importing this file does not require torch.
    import torch

    # Resolve settings. ``is None`` checks keep legitimate zeros (tile=0,
    # denoise=0) from being replaced by the defaults.
    model_id = map_ui_model_to_internal(ui_model_name)
    scale = clamp_scale_for_model(int(outscale or 4), model_id)
    tile = 256 if tile is None else max(0, int(tile))
    batch_size = max(1, int(batch_size or 8))
    strength = 0.5 if denoise_strength is None else float(denoise_strength)
    use_half = (precision == "half") and torch.cuda.is_available()

    model, netscale, model_path, dni_weight = ensure_weights(model_id)
    # Denoise strength is applied by blending checkpoints at load time; the
    # upsampler's enhance() call takes no denoise argument.
    model_path, dni_weight = _resolve_denoise_weights(model_id, model_path, strength)
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=tile,
        tile_pad=10,
        pre_pad=10,
        half=use_half,
        gpu_id=0,
    )

    # Optional face enhancer; a load failure falls back to the plain upsampler.
    face_enhancer = None
    if face_enhance:
        try:
            from gfpgan import GFPGANer
            face_enhancer = GFPGANer(
                model_path="https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
                upscale=scale, arch="clean", channel_multiplier=2, bg_upsampler=upsampler
            )
        except Exception as e:
            print("GFPGAN load failed:", e)
            face_enhancer = None

    end = min(start + batch_size, total)
    batch_paths = up_src_paths[start:end]
    total_in_batch = len(batch_paths)
    failures = 0
    t0 = time.time()
    for idx, fp in enumerate(batch_paths, start=1):
        try:
            with Image.open(fp) as im:
                rgb = np.array(im.convert("RGB"))
            # Both enhancers follow the OpenCV convention (BGR in, BGR out).
            bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
            if face_enhancer:
                _, _, output = face_enhancer.enhance(
                    bgr, has_aligned=False, only_center_face=False, paste_back=True
                )
            else:
                output, _ = upsampler.enhance(bgr, outscale=scale)
            output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
            Image.fromarray(output).save(out_dir / (Path(fp).stem + ".jpg"), quality=95)
        except Exception as e:
            # Keep going with the rest of the batch, but surface the failure.
            failures += 1
            print("Upscale error:", e)

        # Progress for THIS batch
        elapsed = time.time() - t0
        pct_batch = (idx / total_in_batch) * 100.0
        eta = (total_in_batch - idx) * (elapsed / max(1, idx))
        label = (f"Batch: {idx}/{total_in_batch} · ~{eta:.1f}s ETA · "
                 f"global {start+idx}/{total} (x{scale}, model={ui_model_name}, tile={tile}, half={use_half})")
        if failures:
            label += f" · {failures} failed"
        gallery = _build_gallery_from_dir(out_dir, 30)
        # Re-zipping the whole output after every image is quadratic work, so
        # the download slot is left untouched until the batch is finished.
        yield gallery, gr.update(), label, render_progress(pct_batch, f"Upscaling {pct_batch:.0f}% (batch)"), start+idx, up_out_dir

    # Batch complete
    pct_global = (end / total) * 100.0 if total else 100.0
    gallery = _build_gallery_from_dir(out_dir, 30)
    zip_file = _save_zip_of_dir(out_dir, zip_path)
    summary = f"Processed batch of {total_in_batch}. {end}/{total} done."
    if failures:
        summary += f" {failures} image(s) failed; see logs."
    yield gallery, zip_file, summary, render_progress(pct_global, "Upscaling… (global)"), end, up_out_dir
# ────────────────────────────────────────────────────────
# UI
# ────────────────────────────────────────────────────────
def build_ui():
    """Assemble the Gradio Blocks interface and wire up its two buttons.

    Returns:
        The ``gr.Blocks`` instance; queueing and launching are left to the caller.
    """
    # NOTE(review): which components sit inside the Accordion and each Row is
    # inferred from statement order — confirm against the intended layout.
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.HTML(render_logo_html(88))
        gr.Markdown("Upload images and upscale with Real-ESRGAN.")
        # Per-session state shared between the two button handlers.
        frames_state = gr.State([]) # present for parity; not used here
        up_src_paths_state = gr.State([])
        up_out_dir_state = gr.State("")
        up_done_idx_state = gr.State(0)
        up_total_state = gr.State(0)
        imgs_override = gr.Files(label="Upload images (JPG/PNG)", file_types=[".jpg",".jpeg",".png"], type="filepath")
        with gr.Accordion("Upscaling options", open=True):
            with gr.Row():
                ui_model_name = gr.Dropdown(
                    label="Upscaler model",
                    choices=["RealESRGAN_x4plus", "RealESRNet_x4plus", "RealESRGAN_x4plus_anime_6B", "RealESRGAN_x2plus", "realesr-general-x4v3"],
                    value="RealESRGAN_x4plus"
                )
                denoise_strength = gr.Slider(0, 1, value=0.5, step=0.1, label="Denoise (only general-x4v3)")
                outscale = gr.Slider(1, 6, value=4, step=1, label="Resolution upscale")
                face_enhance = gr.Checkbox(value=False, label="Face Enhancement (GFPGAN)")
            with gr.Row():
                tile = gr.Number(value=256, label="Tile size (try 128 if OOM; 0=auto)")
                precision = gr.Dropdown(["auto", "half", "full"], value="half", label="Precision (GPU: half, CPU ignored)")
            with gr.Row():
                batch_size = gr.Number(value=12, precision=0, label="Batch size per click")
                max_images = gr.Number(value=0, precision=0, label="Max images to process (0 = all)")
        with gr.Row():
            btn_prepare = gr.Button("Load / Reset Sources", variant="secondary")
            btn_next = gr.Button("Process Next Batch (uses GPU)", variant="primary")
        # Output widgets updated by the handlers below.
        prog = gr.HTML(render_progress(0.0, "Idle"))
        gallery_up = gr.Gallery(label="Upscaled preview (sampled 30)", columns=6, height=480)
        zip_up = gr.File(label="Download upscaled ZIP")
        details = gr.Markdown("")
        # 1) load/reset sources (CPU): fills the state values and resets the progress display.
        btn_prepare.click(
            step2_prepare_sources,
            inputs=[frames_state, imgs_override, max_images],
            outputs=[up_src_paths_state, up_out_dir_state, up_done_idx_state, up_total_state, details, prog]
        )
        # 2) process one batch per click (GPU): the handler is a generator, so
        #    every yield refreshes the gallery, status, progress and done index.
        btn_next.click(
            step2_process_next_batch,
            inputs=[up_src_paths_state, up_out_dir_state, up_done_idx_state, up_total_state,
                    ui_model_name, outscale, tile, precision, denoise_strength, face_enhance, batch_size],
            outputs=[gallery_up, zip_up, details, prog, up_done_idx_state, up_out_dir_state]
        )
        gr.Markdown(
            "> ℹ️ **ZeroGPU tips**: Larger tiles are faster but use more VRAM. If you hit OOM, try `tile=128`, "
            "`batch size=4–8`, and keep `Precision=half`."
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the interface, enable the request queue, then serve.
    app = build_ui()
    app.queue().launch()