Mjolnir / app.py
JS6969's picture
Update app.py
62a1eec verified
raw
history blame
13.3 kB
# app.py Upscale Images (Real-ESRGAN)
# ---- TorchVision shim (keeps basicsr happy if torchvision isn't installed) ----
import sys, types
try:
    import torchvision.transforms.functional_tensor as _ft  # noqa: F401
except Exception:
    # The real module could not be imported: register a stand-in module under
    # the same dotted name that exposes only `rgb_to_grayscale`.
    import torch

    def rgb_to_grayscale(img: "torch.Tensor", num_output_channels: int = 1) -> "torch.Tensor":
        """Convert an RGB tensor to grayscale with a weighted channel sum.

        The channel axis is the third-from-last dimension and must have size 3.
        Returns one gray channel, or the gray channel repeated three times when
        ``num_output_channels == 3``.

        Raises:
            TypeError: if ``img`` is not a tensor.
            ValueError: if ``img`` has fewer than 3 dims or its channel axis is not 3.
        """
        if not torch.is_tensor(img):
            raise TypeError("rgb_to_grayscale expects a torch.Tensor")
        if img.ndim < 3 or img.shape[-3] != 3:
            raise ValueError(f"expected tensor with C=3 as the third-from-last dim, got shape {tuple(img.shape)}")
        red = img[..., -3, :, :]
        green = img[..., -2, :, :]
        blue = img[..., -1, :, :]
        luma = 0.2989*red + 0.5870*green + 0.1140*blue
        if num_output_channels == 3:
            return torch.stack([luma, luma, luma], dim=-3)
        return luma.unsqueeze(-3)

    _mod = types.ModuleType("torchvision.transforms.functional_tensor")
    _mod.rgb_to_grayscale = rgb_to_grayscale
    sys.modules["torchvision.transforms.functional_tensor"] = _mod
# ------------------------------------------------------------------------------
import base64
import os
import re
import shutil
import tempfile
import time
import zipfile
from pathlib import Path
from typing import List, Optional, Tuple

import gradio as gr
import numpy as np
import cv2
from PIL import Image
from basicsr.archs.rrdbnet_arch import RRDBNet as _RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
import torch
def have_gpu() -> bool:
    """Report whether PyTorch can see a CUDA device."""
    cuda_ready = torch.cuda.is_available()
    return cuda_ready
# Startup notice: tell the operator which device upscaling will run on.
if have_gpu():
    print(f"✅ GPU detected: {torch.cuda.get_device_name(0)}")
else:
    print("⚠️ No GPU detected. Upscaling will run on CPU, which may be extremely slow.")
def try_load_logo_b64() -> str:
    """Return ``bifrost_logo.png`` as base64 text, or ``""`` if it cannot be read.

    The path is relative, so the file is looked up in the current working
    directory. Loading the logo is best-effort: a missing or unreadable file
    yields an empty string instead of raising.
    """
    try:
        with open("bifrost_logo.png", "rb") as f:
            raw = f.read()
    except OSError:
        # Missing file, permission problem, etc. -- the logo is optional.
        return ""
    return base64.b64encode(raw).decode("utf-8")
# Base64 text of the logo, computed once at import time; "" when the file could not be read.
LOGO_B64 = try_load_logo_b64()
def render_logo_html(px: int = 96) -> str:
    """Build the page header HTML: optional logo image, title, subtitle and a rule.

    Args:
        px: Height of the logo image in CSS pixels.

    Returns:
        An HTML fragment. The ``<img>`` tag is omitted when ``LOGO_B64`` is empty.
    """
    if LOGO_B64:
        logo_tag = f'<img src="data:image/png;base64,{LOGO_B64}" style="height:{px}px;width:auto;" />'
    else:
        logo_tag = ""
    # The leading and trailing empty entries reproduce the newline that opens
    # and closes the fragment.
    rows = [
        "",
        '<div style="display:flex;align-items:center;gap:16px;">',
        logo_tag,
        "<div>",
        '<div style="font-size:1.6rem;font-weight:800;">Bifröst · Upscale Images</div>',
        '<div style="opacity:0.8;">Real-ESRGAN (batch click with progress)</div>',
        "</div>",
        "</div>",
        "<hr>",
        "",
    ]
    return "\n".join(rows)
_num = __import__("re").compile(r'(\d+)')
def _natural_key(p: Path | str):
s = str(p)
return [int(t) if t.isdigit() else t.lower() for t in _num.split(s)]
def sample_paths(paths: List[Path] | List[str], n: int = 30) -> List[str]:
if not paths: return []
paths = sorted(paths, key=_natural_key)
total = len(paths); n = max(1, min(n, total))
if n == total: return [str(p) for p in paths]
step = (total - 1) / (n - 1); idxs = [round(i * step) for i in range(n)]
out, seen = [], set()
for i in idxs:
if i not in seen:
out.append(str(paths[int(i)])); seen.add(int(i))
return out
def render_progress(pct: float, label: str = "") -> str:
    """Render a progress bar plus caption as an HTML fragment.

    Args:
        pct: Percentage to display; limited to the range 0..100.
        label: Text placed before the percentage in the caption.
    """
    capped = min(100.0, pct)
    pct = max(0.0, capped)
    outer = '<div style="width:100%;border:1px solid #ddd;border-radius:8px;overflow:hidden;height:18px;">\n'
    fill = f'<div style="height:100%;width:{pct:.1f}%;"></div></div>\n'
    caption = f'<div style="font-size:12px;opacity:.8;margin-top:4px;">{label} {pct:.1f}%</div>'
    return outer + fill + caption
def build_rrdb(scale: int, num_block: int):
    """Create an RRDBNet with 3 input/output channels, 64 features and 32 growth channels.

    Args:
        scale: Value forwarded as the network's ``scale`` argument.
        num_block: Value forwarded as the network's ``num_block`` argument.
    """
    arch_kwargs = dict(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=num_block,
        num_grow_ch=32,
        scale=scale,
    )
    return _RRDBNet(**arch_kwargs)
def _weights_dir() -> str:
wdir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "weights")
os.makedirs(wdir, exist_ok=True)
return wdir
def get_realesrganer(model_id: str, scale: int, tile: int, half: bool, device: str = "cpu") -> RealESRGANer:
    """Build a ``RealESRGANer`` for one of the supported models, downloading weights if needed.

    Args:
        model_id: One of ``"x4plus"``, ``"x4plus-anime"`` or ``"x2plus"``.
        scale: Accepted for interface compatibility. The network is always
            built at the model's native scale, which is fixed by ``model_id``.
        tile: Tile size handed to the upsampler; a falsy value becomes 0.
        half: Whether to request half precision from the upsampler.
        device: Torch device string the upsampler should run on.

    Returns:
        A ready-to-use ``RealESRGANer`` instance.

    Raises:
        ValueError: if ``model_id`` is not a known model.
    """
    # model_id -> (number of RRDB blocks, native scale, weights URL)
    specs = {
        "x4plus": (
            23, 4,
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
        ),
        "x4plus-anime": (
            6, 4,
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
        ),
        "x2plus": (
            23, 2,
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
        ),
    }
    if model_id not in specs:
        raise ValueError(f"Unknown model_id: {model_id}")
    num_block, netscale, url = specs[model_id]

    wdir = _weights_dir()
    model_path = os.path.join(wdir, os.path.basename(url))
    if not os.path.isfile(model_path):
        load_file_from_url(url=url, model_dir=wdir, progress=True)

    model = build_rrdb(scale=netscale, num_block=num_block)
    # The previous tail of this function referenced an undefined `precision`
    # name and called itself recursively without ever returning; construct
    # and return the upsampler instead.
    return RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=None,
        model=model,
        tile=int(tile or 0),
        tile_pad=10,
        pre_pad=0,
        half=bool(half),
        device=torch.device(device),
    )
def _ensure_dir(p: Path) -> Path:
p.mkdir(parents=True, exist_ok=True); return p
def _save_zip_of_dir(dir_path: Path, zip_path: Path) -> str:
    """Write the JPG/PNG files directly inside ``dir_path`` into a new zip at ``zip_path``.

    Entries are stored by bare file name, in natural sort order. Any existing
    file at ``zip_path`` is overwritten.

    Returns:
        ``zip_path`` as a string.
    """
    image_suffixes = (".jpg", ".jpeg", ".png")
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as archive:
        candidates = sorted(dir_path.glob("*.*"), key=_natural_key)
        for member in candidates:
            if member.suffix.lower() not in image_suffixes:
                continue
            archive.write(member, member.name)
    return str(zip_path)
def _list_image_paths_from_upload(files: List[gr.File] | None) -> List[str]:
if not files: return []
return [str(Path(f.name)) for f in files if Path(f.name).suffix.lower() in [".jpg",".jpeg",".png"]]
def _build_gallery_from_dir(dir_path: Path, n: int = 30) -> List[str]:
    """Return up to ``n`` sampled ``*.jpg`` / ``*.png`` paths found directly in ``dir_path``."""
    candidates = [*dir_path.glob("*.jpg"), *dir_path.glob("*.png")]
    candidates.sort(key=_natural_key)
    return sample_paths(candidates, n)
def map_ui_model_to_internal(ui_name: str) -> str:
    """Translate a model name from the UI dropdown into an internal model id.

    Names without a dedicated entry, and unknown names, map to ``"x4plus"``.
    """
    table = {
        "RealESRGAN_x4plus": "x4plus",
        "RealESRGAN_x4plus_anime_6B": "x4plus-anime",
        "RealESRGAN_x2plus": "x2plus",
        "RealESRNet_x4plus": "x4plus",
        "realesr-general-x4v3": "x4plus",
    }
    try:
        return table[ui_name]
    except KeyError:
        return "x4plus"
def clamp_scale_for_model(outscale: int, model_id: str) -> int:
    """Return the scale to use for ``model_id``: 2 for ``"x2plus"``, otherwise 4.

    ``outscale`` is accepted but not consulted.
    """
    if model_id == "x2plus":
        return 2
    return 4
def step2_prepare_sources(frames_list, uploaded_imgs, max_images):
    """Collect the source images and set up a fresh output directory.

    Uploaded images take precedence; ``frames_list`` is used only when no
    uploaded image qualifies. A positive ``max_images`` truncates the list.

    Returns:
        A 6-tuple ``(source paths, output dir, done index, total, status text,
        progress HTML)``. When nothing is found the paths are empty, the
        output dir is ``""`` and both counters are 0.
    """
    sources = _list_image_paths_from_upload(uploaded_imgs)
    if not sources:
        sources = frames_list or []
    if not sources:
        return [], "", 0, 0, "No images found. Upload files first.", render_progress(0.0, "Idle")

    # An unparsable limit means "no limit".
    try:
        limit = int(max_images or 0)
    except Exception:
        limit = 0
    if limit > 0:
        sources = sources[:limit]

    workspace = Path(tempfile.mkdtemp(prefix="up_manual_"))
    out_dir = _ensure_dir(workspace / "upscaled")
    count = len(sources)
    status = f"Sources loaded: {count} image(s). Click 'Process Next Batch'."
    return sources, str(out_dir), 0, count, status, render_progress(0.0, "Ready")
def step2_process_next_batch(
    up_src_paths, up_out_dir, up_done_idx, up_total,
    ui_model_name, outscale, tile, precision, denoise_strength, face_enhance, batch_size,
):
    """Upscale the next batch of source images, yielding UI updates as it goes.

    This is a generator. Every yielded value is a 6-tuple
    ``(gallery paths, zip path, status text, progress HTML, done index, output dir)``.
    One tuple is yielded after each image and one more when the batch ends.

    Args:
        up_src_paths: All source image paths, in processing order.
        up_out_dir: Directory that receives the upscaled ``.jpg`` files.
        up_done_idx: Number of sources already handled by earlier batches.
        up_total: Total number of sources.
        ui_model_name: Model name as shown in the UI dropdown.
        outscale: Requested scale; the effective scale is chosen by
            ``clamp_scale_for_model``.
        tile: Tile size; ``None`` means 256 and 0 is passed through unchanged.
        precision: ``"auto"``, ``"half"`` or ``"full"``. Half precision is
            used only on CUDA, and only for ``"auto"`` or ``"half"``.
        denoise_strength: Accepted for interface compatibility; not used,
            because ``RealESRGANer.enhance`` has no such argument.
        face_enhance: When truthy, try to run the images through GFPGAN.
        batch_size: Number of images to process in this call (at least 1).
    """
    if not up_src_paths or not up_out_dir:
        yield None, None, "Load sources first.", render_progress(0.0, "Idle"), up_done_idx, up_out_dir
        return

    start = int(up_done_idx or 0)
    total = int(up_total or 0)
    out_dir = Path(up_out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    zip_path = out_dir.parent / "upscaled.zip"

    # Nothing left to do: report completion before paying for a model load.
    if start >= total:
        gallery = _build_gallery_from_dir(out_dir, 30)
        zip_file = _save_zip_of_dir(out_dir, zip_path)
        yield gallery, zip_file, "All images processed.", render_progress(100.0, "Done"), start, up_out_dir
        return

    model_id = map_ui_model_to_internal(ui_model_name)
    scale = clamp_scale_for_model(int(outscale or 4), model_id)
    # Ask torch directly: CUDA_VISIBLE_DEVICES is frequently unset on machines
    # that do have a usable GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    half = (device == "cuda") and (precision in ("auto", "half"))
    # Keep an explicit 0 instead of silently turning it into 256.
    tile = 256 if tile is None else max(0, int(tile))
    batch_size = max(1, int(batch_size or 8))
    upsampler = get_realesrganer(model_id, scale, tile, half, device=device)

    face_enhancer = None
    if face_enhance:
        try:
            from gfpgan import GFPGANer
            face_enhancer = GFPGANer(
                model_path="https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
                upscale=scale, arch="clean", channel_multiplier=2, bg_upsampler=upsampler
            )
        except Exception as e:
            # Face enhancement is optional; fall back to plain upscaling.
            print("GFPGAN load failed:", e)

    end = min(start + batch_size, total)
    batch_paths = up_src_paths[start:end]
    total_in_batch = len(batch_paths)
    t0 = time.time()

    for idx, fp in enumerate(batch_paths, start=1):
        try:
            with Image.open(fp) as im:
                rgb = np.array(im.convert("RGB"))
            # Both enhancers follow the OpenCV convention (BGR in, BGR out).
            bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
            if face_enhancer:
                _, _, out_bgr = face_enhancer.enhance(
                    bgr, has_aligned=False, only_center_face=False, paste_back=True
                )
            else:
                out_bgr, _ = upsampler.enhance(bgr, outscale=scale)
            out_rgb = cv2.cvtColor(out_bgr, cv2.COLOR_BGR2RGB)
            Image.fromarray(out_rgb).save(out_dir / (Path(fp).stem + ".jpg"), quality=95)
        except Exception as e:
            # Best effort per image: log and continue with the rest of the batch.
            print("Upscale error:", e)

        elapsed = time.time() - t0
        pct_batch = (idx / total_in_batch) * 100.0
        eta = (total_in_batch - idx) * (elapsed / max(1, idx))
        label = (f"Batch: {idx}/{total_in_batch} · ~{eta:.1f}s ETA · "
                 f"global {start+idx}/{total} (x{scale}, model={ui_model_name})")
        # NOTE: the gallery and zip are rebuilt after every image so the UI
        # always shows a downloadable archive of what is finished so far.
        gallery = _build_gallery_from_dir(out_dir, 30)
        zip_file = _save_zip_of_dir(out_dir, zip_path)
        yield gallery, zip_file, label, render_progress(pct_batch, f"Upscaling {pct_batch:.0f}% (batch)"), start+idx, up_out_dir

    next_idx = end
    pct_global = (next_idx / total) * 100.0 if total else 100.0
    gallery = _build_gallery_from_dir(out_dir, 30)
    zip_file = _save_zip_of_dir(out_dir, zip_path)
    yield gallery, zip_file, f"Processed batch of {total_in_batch}. {next_idx}/{total} done.", render_progress(pct_global, "Upscaling… (global)"), next_idx, up_out_dir
def build_ui():
    """Assemble the Gradio Blocks app and wire the two buttons to their handlers.

    Returns:
        The ``gr.Blocks`` instance (not yet launched).
    """
    model_choices = [
        "RealESRGAN_x4plus",
        "RealESRNet_x4plus",
        "RealESRGAN_x4plus_anime_6B",
        "RealESRGAN_x2plus",
        "realesr-general-x4v3",
    ]

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.HTML(render_logo_html(88))
        gr.Markdown("Upload images and upscale with Real-ESRGAN. Process in batches with live progress.")

        # State carried between button clicks.
        frames_state = gr.State([])  # Not used here but kept for simple wiring
        up_src_paths_state = gr.State([])
        up_out_dir_state = gr.State("")
        up_done_idx_state = gr.State(0)
        up_total_state = gr.State(0)

        imgs_override = gr.Files(
            label="Upload images (JPG/PNG)",
            file_types=[".jpg", ".jpeg", ".png"],
            type="filepath",
        )

        with gr.Accordion("Upscaling options", open=True):
            with gr.Row():
                ui_model_name = gr.Dropdown(
                    label="Upscaler model",
                    choices=model_choices,
                    value="RealESRGAN_x4plus",
                )
                denoise_strength = gr.Slider(0, 1, value=0.5, step=0.1, label="Denoise (only general-x4v3)")
                outscale = gr.Slider(1, 6, value=4, step=1, label="Resolution upscale")
                face_enhance = gr.Checkbox(value=False, label="Face Enhancement (GFPGAN)")
            with gr.Row():
                tile = gr.Number(value=256, label="Tile size (try 128 if OOM; 0=auto)")
                precision = gr.Dropdown(
                    ["auto", "half", "full"],
                    value="auto",
                    label="Precision (GPU=half, CPU=full)",
                )
            with gr.Row():
                batch_size = gr.Number(value=12, precision=0, label="Batch size per click")
                max_images = gr.Number(value=0, precision=0, label="Max images to process (0 = all)")

        with gr.Row():
            btn_prepare = gr.Button("Load / Reset Sources", variant="secondary")
            btn_next = gr.Button("Process Next Batch", variant="primary")

        prog = gr.HTML(render_progress(0.0, "Idle"))
        gallery_up = gr.Gallery(label="Upscaled preview (30 sampled)", columns=6, height=480)
        zip_up = gr.File(label="Download upscaled ZIP")
        details = gr.Markdown("")

        # "Load / Reset" fills the state; "Process Next Batch" consumes it.
        prepare_inputs = [frames_state, imgs_override, max_images]
        prepare_outputs = [
            up_src_paths_state, up_out_dir_state, up_done_idx_state, up_total_state, details, prog,
        ]
        btn_prepare.click(step2_prepare_sources, inputs=prepare_inputs, outputs=prepare_outputs)

        batch_inputs = [
            up_src_paths_state, up_out_dir_state, up_done_idx_state, up_total_state,
            ui_model_name, outscale, tile, precision, denoise_strength, face_enhance, batch_size,
        ]
        batch_outputs = [gallery_up, zip_up, details, prog, up_done_idx_state, up_out_dir_state]
        btn_next.click(step2_process_next_batch, inputs=batch_inputs, outputs=batch_outputs)

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, enable the queue, then start serving.
    app = build_ui()
    app.queue().launch()