Spaces:
Running
Running
| import os | |
| import math | |
| import numpy as np | |
| import onnxruntime as ort | |
| from PIL import Image | |
| import gradio as gr | |
| import tempfile | |
| import gdown | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
# ---------------------------------------------------------------------------
# Paths & constants
# ---------------------------------------------------------------------------
# Writable scratch dir for downloaded ONNX weights (Spaces containers allow /tmp).
CACHE_DIR = "/tmp/spectragan"
os.makedirs(CACHE_DIR, exist_ok=True)

# Google Drive IDs for ESRGAN ONNX files
DRIVE_IDS = {
    "esrgan_x4": "1wDBHad9RCJgJDGsPdapLYl3cr8j-PMJ6",
    "hresnet_x4": "15xmXXZNH2wMyeQv4ie5hagT7eWK9MgP6",  # placeholder = ESRGAN x2
}

# srcnn_x4.pth must be in the Space repo root (same folder as app.py)
SRCNN_PTH = os.path.join(os.path.dirname(__file__), "srcnn_x4.pth")

# Human-readable names shown in the UI dropdown (run_upscale maps label → key).
MODEL_LABELS = {
    "esrgan_x4": "Real-ESRGAN ×4",
    "srcnn_x4": "SRCNN ×4",
    "hresnet_x4": "HResNet ×4",
}

# Per-model upscaling factor used by the tiler to size the output canvas.
# NOTE(review): "hresnet_x4" is labelled ×4 above but scales ×2 here — the
# comment says the underlying file is an ESRGAN x2 placeholder; confirm the
# label is intentional before shipping.
MODEL_SCALES = {
    "esrgan_x4": 4,
    "srcnn_x4": 4,
    "hresnet_x4": 2,  # underlying model is ESRGAN x2 (placeholder)
}
# ===========================================================================
# SRCNN architecture — 3 conv layers, 1-channel (Y / grayscale) input
# The checkpoint was trained on grayscale, so num_channels=1 by default.
# ===========================================================================
class SRCNN(nn.Module):
    """Classic three-layer SRCNN operating on an already-upsampled image.

    Layer roles: conv1 extracts patches, conv2 performs the non-linear
    mapping, conv3 reconstructs the output channel(s). Attribute names
    (conv1/conv2/conv3/relu) are load-bearing: they must match the keys
    in the .pth state_dict.
    """

    def __init__(self, num_channels: int = 1):
        super().__init__()
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same-padding convs keep spatial size unchanged throughout.
        features = self.relu(self.conv1(x))
        mapped = self.relu(self.conv2(features))
        return self.conv3(mapped)
# ===========================================================================
# Model loading
# ===========================================================================
# Keep ONNX Runtime thread usage modest — free Spaces CPUs are small.
sess_opts = ort.SessionOptions()
sess_opts.intra_op_num_threads = 2
sess_opts.inter_op_num_threads = 2

# Populated at boot time by the loaders below.
ONNX_SESSIONS = {}  # key → (ort.InferenceSession, input_meta)
SRCNN_MODEL = None  # set by _load_srcnn_pth(); None means SRCNN unavailable
def _load_esrgan_onnx(key: str):
    """Fetch an ESRGAN ONNX file from Google Drive (if not cached) and
    register an inference session for it in ONNX_SESSIONS.

    gdown handles Drive's large-file confirmation pages. On failure the
    function only logs — callers treat a missing key in ONNX_SESSIONS as
    "model unavailable".
    """
    dest = os.path.join(CACHE_DIR, f"{key}.onnx")

    if not os.path.exists(dest):
        print(f"Downloading {MODEL_LABELS[key]} from Drive …")
        gdown.download(id=DRIVE_IDS[key], output=dest, quiet=False, fuzzy=True)

    # Guard clause: bail out (log only) if the download did not materialize.
    if not os.path.exists(dest):
        print(f"[ERROR] {key} — file missing after download attempt.")
        return

    sess = ort.InferenceSession(
        dest,
        sess_options=sess_opts,
        providers=["CPUExecutionProvider"],
    )
    ONNX_SESSIONS[key] = (sess, sess.get_inputs()[0])
    print(f"Loaded {MODEL_LABELS[key]} ✓")
def _load_srcnn_pth():
    """
    Load SRCNN weights from srcnn_x4.pth in the Space repo root into the
    module-level SRCNN_MODEL.

    The checkpoint is a 1-channel (grayscale / Y) model — conv1.weight has
    shape (64, 1, 9, 9). Inference therefore works in YCbCr space: SRCNN on
    Y, bicubic on CbCr (see _srcnn_tile). Missing file is logged, not fatal.
    """
    global SRCNN_MODEL

    if not os.path.exists(SRCNN_PTH):
        print(f"[WARN] srcnn_x4.pth not found at {SRCNN_PTH} — SRCNN skipped.")
        return

    # NOTE(review): torch.load unpickles arbitrary objects; fine for a file
    # shipped in this repo, but consider weights_only=True on newer torch.
    checkpoint = torch.load(SRCNN_PTH, map_location="cpu")

    # Peel off common checkpoint wrappers, in a fixed order.
    for wrapper in ("model", "state_dict", "params"):
        if isinstance(checkpoint, dict) and wrapper in checkpoint:
            checkpoint = checkpoint[wrapper]

    # Strip DataParallel's "module." prefix if present.
    weights = {name.replace("module.", ""): tensor for name, tensor in checkpoint.items()}

    net = SRCNN(num_channels=1)
    net.load_state_dict(weights, strict=True)
    net.eval()

    SRCNN_MODEL = net
    print("Loaded SRCNN ×4 from .pth ✓ (grayscale/Y-channel model)")
# Boot-time loading — run at import so the UI never blocks on a first-use
# download. Each loader is wrapped so one failure doesn't kill the app;
# upscale() raises a clear error later if a requested model is absent.
for _k in ("esrgan_x4", "hresnet_x4"):
    try:
        _load_esrgan_onnx(_k)
    except Exception as _e:
        print(f"[ERROR] {_k}: {_e}")
try:
    _load_srcnn_pth()
except Exception as _e:
    print(f"[ERROR] SRCNN: {_e}")
| # =========================================================================== | |
| # Inference helpers | |
| # =========================================================================== | |
| def _onnx_tile(sess, meta, tile: np.ndarray) -> np.ndarray: | |
| """HWC float32 [0,1] in → HWC float32 out.""" | |
| patch = tile.transpose(2, 0, 1)[None, ...] | |
| out = sess.run(None, {meta.name: patch})[0] | |
| return out.squeeze(0).transpose(1, 2, 0) | |
def _srcnn_tile(tile: np.ndarray, scale: int = 4) -> np.ndarray:
    """
    Super-resolve one RGB tile with the grayscale SRCNN model.

    Pipeline: RGB → YCbCr, bicubic-upsample all channels, refine only the
    luma (Y) channel with SRCNN, then merge and convert back to RGB.

    tile: HWC float32 in [0, 1]
    returns: HWC float32 in [0, 1] at scale× resolution
    """
    as_uint8 = (np.clip(tile, 0, 1) * 255).round().astype(np.uint8)
    rgb_img = Image.fromarray(as_uint8)
    luma, cb, cr = rgb_img.convert("YCbCr").split()

    src_w, src_h = rgb_img.size
    target = (src_w * scale, src_h * scale)

    # Chroma carries little perceptual detail — plain bicubic suffices there.
    cb_big = cb.resize(target, Image.BICUBIC)
    cr_big = cr.resize(target, Image.BICUBIC)

    # Luma: bicubic upsample first, then let SRCNN sharpen the result.
    luma_arr = np.array(luma).astype(np.float32) / 255.0        # (H, W)
    luma_t = torch.from_numpy(luma_arr)[None, None]             # (1, 1, H, W)
    luma_big = F.interpolate(
        luma_t, size=(src_h * scale, src_w * scale),
        mode="bicubic", align_corners=False,
    )
    with torch.no_grad():
        refined = SRCNN_MODEL(luma_big)                         # (1, 1, H*s, W*s)
    luma_u8 = (refined.squeeze().numpy() * 255.0).clip(0, 255).round().astype(np.uint8)

    # Recompose YCbCr → RGB and normalise back to [0, 1] float.
    merged = Image.merge("YCbCr", [Image.fromarray(luma_u8, mode="L"), cb_big, cr_big])
    return np.array(merged.convert("RGB")).astype(np.float32) / 255.0
def upscale(input_img: Image.Image, model_key: str, max_dim: int = 1024) -> Image.Image:
    """Tile-based upscale dispatcher for ONNX (ESRGAN) and torch (SRCNN).

    Downscales oversized inputs to max_dim on the long side, splits the
    image into fixed-size LR tiles (reflect-padded to a whole grid),
    upscales each tile, and stitches/crops the result.

    Raises RuntimeError if the requested model failed to load at startup.
    """
    if model_key == "srcnn_x4" and SRCNN_MODEL is None:
        raise RuntimeError("SRCNN model not loaded — check that srcnn_x4.pth is in the repo root.")
    if model_key in ("esrgan_x4", "hresnet_x4") and model_key not in ONNX_SESSIONS:
        raise RuntimeError(f"{MODEL_LABELS[model_key]} failed to load at startup.")

    scale = MODEL_SCALES[model_key]
    tile_px = 128  # LR tile size (consistent across all models)

    # Cap input size to avoid OOM on CPU Spaces.
    width, height = input_img.size
    longest = max(width, height)
    if longest > max_dim:
        shrink = max_dim / float(longest)
        input_img = input_img.resize((int(width * shrink), int(height * shrink)), Image.LANCZOS)

    lr = np.array(input_img.convert("RGB")).astype(np.float32) / 255.0
    lr_h, lr_w = lr.shape[0], lr.shape[1]
    n_rows = math.ceil(lr_h / tile_px)
    n_cols = math.ceil(lr_w / tile_px)

    # Reflect-pad so the image divides evenly into tiles.
    padded = np.pad(
        lr,
        ((0, n_rows * tile_px - lr_h), (0, n_cols * tile_px - lr_w), (0, 0)),
        mode="reflect",
    )

    step = tile_px * scale
    canvas = np.zeros((n_rows * step, n_cols * step, 3), dtype=np.float32)

    for row in range(n_rows):
        for col in range(n_cols):
            patch = padded[row * tile_px:(row + 1) * tile_px,
                           col * tile_px:(col + 1) * tile_px]
            if model_key == "srcnn_x4":
                grown = _srcnn_tile(patch, scale=scale)
            else:
                sess, meta = ONNX_SESSIONS[model_key]
                grown = _onnx_tile(sess, meta, patch)
            canvas[row * step:(row + 1) * step,
                   col * step:(col + 1) * step] = grown

    # Crop away the padding region and clamp to valid pixel range.
    final = np.clip(canvas[:lr_h * scale, :lr_w * scale], 0.0, 1.0)
    return Image.fromarray((final * 255.0).round().astype(np.uint8))
# ===========================================================================
# Gradio callback
# ===========================================================================
def run_upscale(input_img: Image.Image, model_name: str):
    """Gradio click handler: returns (slider pair, preview image, download path).

    Returns (None, None, None) when no image is supplied. NOTE(review):
    temp PNGs are never deleted — acceptable for an ephemeral Space, but
    they accumulate in /tmp across requests.
    """
    if input_img is None:
        return None, None, None

    # Map the dropdown's display label back to the internal model key.
    key = next(k for k, v in MODEL_LABELS.items() if v == model_name)
    result = upscale(input_img, key)

    def _save_png(img: Image.Image) -> str:
        # Persist to a named temp file so Gradio can serve it by path.
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
        img.save(tmp.name, format="PNG")
        tmp.close()
        return tmp.name

    # Resize original to the output dimensions so the slider halves align.
    before = input_img.resize(result.size, Image.LANCZOS).convert("RGB")
    before_path = _save_png(before)
    after_path = _save_png(result)
    download_path = _save_png(result)  # separate copy for the download button

    return (before_path, after_path), result, download_path
# ===========================================================================
# Gradio UI
# ===========================================================================
# Custom styling: DM Sans font, gradient run button, muted uppercase labels.
css = """
@import url('https://fonts.googleapis.com/css2?family=DM+Sans:wght@400;600;700&display=swap');
body, .gradio-container { font-family: 'DM Sans', sans-serif !important; }
#title { text-align: center; padding: 24px 0 8px; }
#title h1 { font-size: 2rem; font-weight: 700; letter-spacing: -0.5px; margin: 0; }
#title p { color: #666; margin: 4px 0 0; }
#run-btn {
    background: linear-gradient(135deg, #0f0c29, #302b63, #24243e) !important;
    color: #fff !important; font-weight: 700 !important;
    font-size: 1rem !important; border-radius: 10px !important;
    padding: 14px 0 !important; width: 100%; letter-spacing: 0.03em;
}
#run-btn:hover { opacity: 0.85; }
#dl-btn button {
    background: #f4f4f4 !important; border: 1px solid #ddd !important;
    color: #333 !important; border-radius: 8px !important;
    width: 100%; font-size: 0.85rem !important;
}
.section-label {
    font-size: 0.75rem; font-weight: 700; letter-spacing: 0.1em;
    text-transform: uppercase; color: #999; margin-bottom: 6px;
}
"""

# Dropdown shows display labels; run_upscale maps a label back to its key.
dropdown_choices = list(MODEL_LABELS.values())
# Two-column layout: controls on the left, before/after comparison on the right.
with gr.Blocks(css=css, title="SpectraGAN Upscaler") as demo:
    # Header banner (styled via the #title CSS rules above).
    gr.HTML("""
    <div id="title">
        <h1>🖼️ SpectraGAN Upscaler</h1>
        <p>Choose a model, upscale your image, and drag the slider to compare.</p>
    </div>
    """)
    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=260):
            gr.HTML('<div class="section-label">Source Image</div>')
            inp_image = gr.Image(type="pil", show_label=False, height=260)
            gr.HTML('<div class="section-label" style="margin-top:16px">Model</div>')
            model_dropdown = gr.Dropdown(
                choices=dropdown_choices,
                value=dropdown_choices[0],
                show_label=False,
            )
            run_btn = gr.Button("⚡ Upscale", elem_id="run-btn")
            # Click handler fills this button's file via its output slot.
            dl_btn = gr.DownloadButton(
                label="⬇ Download upscaled PNG",
                elem_id="dl-btn",
                visible=True,
            )
        with gr.Column(scale=2):
            gr.HTML('<div class="section-label">Before / After — drag to compare</div>')
            # Receives a (before_path, after_path) tuple from run_upscale.
            slider = gr.ImageSlider(
                show_label=False,
                height=420,
                type="filepath",
            )
            gr.HTML('<div class="section-label" style="margin-top:16px">Upscaled Preview</div>')
            out_preview = gr.Image(type="pil", show_label=False, height=200)
    run_btn.click(
        fn=run_upscale,
        inputs=[inp_image, model_dropdown],
        outputs=[slider, out_preview, dl_btn],
    )

# 0.0.0.0:7860 is the address/port Hugging Face Spaces expects.
demo.launch(server_name="0.0.0.0", server_port=7860)