import os
import math
import numpy as np
import onnxruntime as ort
from PIL import Image
import gradio as gr
import tempfile
import gdown
import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------------------------------------------------------------------------
# Paths & constants
# ---------------------------------------------------------------------------
# Writable scratch directory for the downloaded ONNX weights.
CACHE_DIR = "/tmp/spectragan"
os.makedirs(CACHE_DIR, exist_ok=True)

# Google Drive IDs for ESRGAN ONNX files
DRIVE_IDS = {
    "esrgan_x4": "1wDBHad9RCJgJDGsPdapLYl3cr8j-PMJ6",
    "hresnet_x4": "15xmXXZNH2wMyeQv4ie5hagT7eWK9MgP6",  # placeholder = ESRGAN x2
}

# srcnn_x4.pth must be in the Space repo root (same folder as app.py)
SRCNN_PTH = os.path.join(os.path.dirname(__file__), "srcnn_x4.pth")

# Model key → human-readable label shown in the UI dropdown.
MODEL_LABELS = {
    "esrgan_x4": "Real-ESRGAN ×4",
    "srcnn_x4": "SRCNN ×4",
    "hresnet_x4": "HResNet ×4",
}

# Model key → upscale factor actually produced by the backing network.
# NOTE(review): "hresnet_x4" is labelled ×4 in MODEL_LABELS but only yields
# ×2 output because its weights are an ESRGAN x2 placeholder — confirm the
# label/scale pairing before shipping.
MODEL_SCALES = {
    "esrgan_x4": 4,
    "srcnn_x4": 4,
    "hresnet_x4": 2,  # underlying model is ESRGAN x2 (placeholder)
}

# ===========================================================================
# SRCNN architecture — 3 conv layers, 1-channel (Y / grayscale) input
# Your .pth was trained on grayscale, so num_channels=1 here.
# =========================================================================== class SRCNN(nn.Module): def __init__(self, num_channels: int = 1): super().__init__() self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=4) self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=2) self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=2) self.relu = nn.ReLU(inplace=True) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.conv3(self.relu(self.conv2(self.relu(self.conv1(x))))) # =========================================================================== # Model loading # =========================================================================== sess_opts = ort.SessionOptions() sess_opts.intra_op_num_threads = 2 sess_opts.inter_op_num_threads = 2 ONNX_SESSIONS = {} # key → (ort.InferenceSession, input_meta) SRCNN_MODEL = None def _load_esrgan_onnx(key: str): """Download ESRGAN ONNX from Drive via gdown (handles confirmation pages).""" dest = os.path.join(CACHE_DIR, f"{key}.onnx") if not os.path.exists(dest): print(f"Downloading {MODEL_LABELS[key]} from Drive …") gdown.download(id=DRIVE_IDS[key], output=dest, quiet=False, fuzzy=True) if os.path.exists(dest): sess = ort.InferenceSession(dest, sess_options=sess_opts, providers=["CPUExecutionProvider"]) ONNX_SESSIONS[key] = (sess, sess.get_inputs()[0]) print(f"Loaded {MODEL_LABELS[key]} ✓") else: print(f"[ERROR] {key} — file missing after download attempt.") def _load_srcnn_pth(): """ Load SRCNN from .pth in the Space repo root. The weights use 1-channel (grayscale / Y) input — confirmed by the conv1.weight shape torch.Size([64, 1, 9, 9]) in the checkpoint. Inference will convert RGB → YCbCr, enhance Y with SRCNN, bicubic-upsample CbCr, then recompose back to RGB. 
""" global SRCNN_MODEL if not os.path.exists(SRCNN_PTH): print(f"[WARN] srcnn_x4.pth not found at {SRCNN_PTH} — SRCNN skipped.") return model = SRCNN(num_channels=1) state = torch.load(SRCNN_PTH, map_location="cpu") # Unwrap common checkpoint wrappers for wrap_key in ("model", "state_dict", "params"): if isinstance(state, dict) and wrap_key in state: state = state[wrap_key] state = {k.replace("module.", ""): v for k, v in state.items()} model.load_state_dict(state, strict=True) model.eval() SRCNN_MODEL = model print("Loaded SRCNN ×4 from .pth ✓ (grayscale/Y-channel model)") # Boot-time loading for _k in ("esrgan_x4", "hresnet_x4"): try: _load_esrgan_onnx(_k) except Exception as _e: print(f"[ERROR] {_k}: {_e}") try: _load_srcnn_pth() except Exception as _e: print(f"[ERROR] SRCNN: {_e}") # =========================================================================== # Inference helpers # =========================================================================== def _onnx_tile(sess, meta, tile: np.ndarray) -> np.ndarray: """HWC float32 [0,1] in → HWC float32 out.""" patch = tile.transpose(2, 0, 1)[None, ...] out = sess.run(None, {meta.name: patch})[0] return out.squeeze(0).transpose(1, 2, 0) def _srcnn_tile(tile: np.ndarray, scale: int = 4) -> np.ndarray: """ Enhance a single RGB tile using the grayscale SRCNN model. Strategy: split into YCbCr → SRCNN on Y → bicubic CbCr → recompose RGB. 
tile: HWC float32 [0, 1] returns: HWC float32 [0, 1] at scale× resolution """ tile_uint8 = (np.clip(tile, 0, 1) * 255).round().astype(np.uint8) tile_pil = Image.fromarray(tile_uint8) tile_ycbcr = tile_pil.convert("YCbCr") y_pil, cb_pil, cr_pil = tile_ycbcr.split() orig_w, orig_h = tile_pil.size up_w, up_h = orig_w * scale, orig_h * scale # Upsample CbCr channels with bicubic (no SRCNN needed there) cb_up = cb_pil.resize((up_w, up_h), Image.BICUBIC) cr_up = cr_pil.resize((up_w, up_h), Image.BICUBIC) # Bicubic upsample Y, then refine with SRCNN y_arr = np.array(y_pil).astype(np.float32) / 255.0 # (H, W) y_t = torch.from_numpy(y_arr).unsqueeze(0).unsqueeze(0) # (1, 1, H, W) y_up = F.interpolate(y_t, size=(up_h, up_w), mode="bicubic", align_corners=False) with torch.no_grad(): y_refined = SRCNN_MODEL(y_up) # (1, 1, H*s, W*s) y_out = (y_refined.squeeze().numpy() * 255.0).clip(0, 255).round().astype(np.uint8) y_up_pil = Image.fromarray(y_out, mode="L") # Recompose YCbCr → RGB out_rgb = Image.merge("YCbCr", [y_up_pil, cb_up, cr_up]).convert("RGB") return np.array(out_rgb).astype(np.float32) / 255.0 def upscale(input_img: Image.Image, model_key: str, max_dim: int = 1024) -> Image.Image: """Tile-based upscale dispatcher for ONNX (ESRGAN) and torch (SRCNN).""" if model_key == "srcnn_x4" and SRCNN_MODEL is None: raise RuntimeError("SRCNN model not loaded — check that srcnn_x4.pth is in the repo root.") if model_key in ("esrgan_x4", "hresnet_x4") and model_key not in ONNX_SESSIONS: raise RuntimeError(f"{MODEL_LABELS[model_key]} failed to load at startup.") scale = MODEL_SCALES[model_key] TILE = 128 # LR tile size (consistent across all models) # Cap input size to avoid OOM w, h = input_img.size if w > max_dim or h > max_dim: factor = max_dim / float(max(w, h)) input_img = input_img.resize((int(w * factor), int(h * factor)), Image.LANCZOS) arr = np.array(input_img.convert("RGB")).astype(np.float32) / 255.0 h_orig, w_orig, _ = arr.shape tiles_h = math.ceil(h_orig / TILE) tiles_w 
= math.ceil(w_orig / TILE) arr_pad = np.pad( arr, ((0, tiles_h * TILE - h_orig), (0, tiles_w * TILE - w_orig), (0, 0)), mode="reflect", ) out = np.zeros((tiles_h * TILE * scale, tiles_w * TILE * scale, 3), dtype=np.float32) for i in range(tiles_h): for j in range(tiles_w): y0, x0 = i * TILE, j * TILE tile = arr_pad[y0:y0 + TILE, x0:x0 + TILE] if model_key == "srcnn_x4": up_tile = _srcnn_tile(tile, scale=scale) else: sess, meta = ONNX_SESSIONS[model_key] up_tile = _onnx_tile(sess, meta, tile) oy0, ox0 = i * TILE * scale, j * TILE * scale out[oy0:oy0 + TILE * scale, ox0:ox0 + TILE * scale] = up_tile final = np.clip(out[:h_orig * scale, :w_orig * scale], 0.0, 1.0) return Image.fromarray((final * 255.0).round().astype(np.uint8)) # =========================================================================== # Gradio callback # =========================================================================== def run_upscale(input_img: Image.Image, model_name: str): if input_img is None: return None, None, None key = next(k for k, v in MODEL_LABELS.items() if v == model_name) result = upscale(input_img, key) # Resize original to same dimensions as output for the slider up_w, up_h = result.size orig_resized = input_img.resize((up_w, up_h), Image.LANCZOS).convert("RGB") # Save both as temp files for ImageSlider tmp_orig = tempfile.NamedTemporaryFile(delete=False, suffix=".png") orig_resized.save(tmp_orig.name) tmp_orig.close() tmp_up = tempfile.NamedTemporaryFile(delete=False, suffix=".png") result.save(tmp_up.name) tmp_up.close() # Separate download copy tmp_dl = tempfile.NamedTemporaryFile(delete=False, suffix=".png") result.save(tmp_dl.name, format="PNG") tmp_dl.close() return (tmp_orig.name, tmp_up.name), result, tmp_dl.name # =========================================================================== # Gradio UI # =========================================================================== css = """ @import 
url('https://fonts.googleapis.com/css2?family=DM+Sans:wght@400;600;700&display=swap'); body, .gradio-container { font-family: 'DM Sans', sans-serif !important; } #title { text-align: center; padding: 24px 0 8px; } #title h1 { font-size: 2rem; font-weight: 700; letter-spacing: -0.5px; margin: 0; } #title p { color: #666; margin: 4px 0 0; } #run-btn { background: linear-gradient(135deg, #0f0c29, #302b63, #24243e) !important; color: #fff !important; font-weight: 700 !important; font-size: 1rem !important; border-radius: 10px !important; padding: 14px 0 !important; width: 100%; letter-spacing: 0.03em; } #run-btn:hover { opacity: 0.85; } #dl-btn button { background: #f4f4f4 !important; border: 1px solid #ddd !important; color: #333 !important; border-radius: 8px !important; width: 100%; font-size: 0.85rem !important; } .section-label { font-size: 0.75rem; font-weight: 700; letter-spacing: 0.1em; text-transform: uppercase; color: #999; margin-bottom: 6px; } """ dropdown_choices = list(MODEL_LABELS.values()) with gr.Blocks(css=css, title="SpectraGAN Upscaler") as demo: gr.HTML("""
Choose a model, upscale your image, and drag the slider to compare.