Spaces: Running
ParamAhuja committed · Commit 3262d11
Parent(s): f376a33
initial
Browse files:
- README.md +175 -5
- app.py +363 -0
- requirements.txt +9 -0
README.md
CHANGED
---
title: SpectraGAN
emoji: 🖼️
colorFrom: blue
colorTo: green
sdk: gradio
sdk_version: 5.31.0
app_file: app.py
pinned: false
license: apache-2.0
---

# 🖼️ SpectraGAN – Multi-Model Upscaler Comparison

A Gradio web app that lets you upscale an image with **multiple SR models simultaneously** and compare the results side by side.

Supported models:

| Model | Architecture | Scale |
|-------|--------------|-------|
| Real-ESRGAN ×2 | GAN (residual-in-residual dense blocks) | ×2 |
| Real-ESRGAN ×4 | GAN (residual-in-residual dense blocks) | ×4 |
| SRCNN ×4 | Shallow 3-layer CNN | ×4 |
| HResNet ×4 | Deep residual network (EDSR-style) | ×4 |
| SR3 *(stub)* | Diffusion model | ×4 – see note below |

---

## Table of Contents

1. [Features](#features)
2. [Project Structure](#project-structure)
3. [Prerequisites](#prerequisites)
4. [Installation](#installation)
5. [Adding Your ONNX Models](#adding-your-onnx-models)
6. [Running Locally](#running-locally)
7. [SR3 Integration Guide](#sr3-integration-guide)
8. [Contributing](#contributing)
9. [License](#license)

---

## Features

- **Side-by-side comparison** – run up to 4 models at once, with results displayed in a 4-panel grid.
- **Selective execution** – toggle any model on or off before running; unchecked models are skipped.
- **×8 post-resize** – optionally apply a Lanczos ×2 pass on top of any ×4 result.
- **Tile-based inference** – large images are split into tiles matching each model's fixed input size, then stitched back together seamlessly.
- **Per-result download** – each panel has its own PNG download button.
- **Graceful degradation** – if a model file is missing (e.g. a Drive ID is not yet set), that panel is skipped without crashing the others.
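
The tiling scheme behind this can be sketched in plain NumPy (an illustration of the idea only; `nn4` below is a stand-in nearest-neighbour 'model', not one of the ONNX networks):

```python
import math
import numpy as np

def tile_upscale(img: np.ndarray, tile: int, scale: int, upscale_tile) -> np.ndarray:
    """Split an HWC image into fixed-size tiles, upscale each, stitch back."""
    h, w, c = img.shape
    th, tw = math.ceil(h / tile), math.ceil(w / tile)
    # Reflect-pad so the image divides evenly into tiles
    padded = np.pad(img, ((0, th * tile - h), (0, tw * tile - w), (0, 0)), mode="reflect")
    out = np.zeros((th * tile * scale, tw * tile * scale, c), dtype=img.dtype)
    for i in range(th):
        for j in range(tw):
            y, x = i * tile, j * tile
            up = upscale_tile(padded[y:y + tile, x:x + tile])
            out[y * scale:(y + tile) * scale, x * scale:(x + tile) * scale] = up
    return out[:h * scale, :w * scale]  # crop the padding back off

# Nearest-neighbour stand-in for a x4 model
nn4 = lambda t: np.kron(t, np.ones((4, 4, 1), dtype=t.dtype))
result = tile_upscale(np.zeros((200, 300, 3), dtype=np.float32), tile=128, scale=4, upscale_tile=nn4)
print(result.shape)  # (800, 1200, 3)
```

`tile_upscale_model` in `app.py` follows the same pattern: reflect padding out to a whole number of tiles, then a final crop back to ×scale of the original size.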

---

## Project Structure

```
spectragan/
├── model/
│   ├── Real-ESRGAN_x2plus.onnx   # auto-downloaded
│   ├── Real-ESRGAN-x4plus.onnx   # auto-downloaded
│   ├── SRCNN_x4.onnx             # you provide – see below
│   └── HResNet_x4.onnx           # you provide – see below
├── app.py
├── requirements.txt
└── README.md
```

---

## Prerequisites

- Python 3.10+
- `git`
- A terminal / command prompt

---

## Installation

```bash
git clone https://github.com/ParamAhuja/SpectraGAN.git
cd SpectraGAN
python -m venv .venv
source .venv/bin/activate   # Linux/macOS
# .venv\Scripts\activate    # Windows
pip install -r requirements.txt
```

---

## Adding Your ONNX Models

The Real-ESRGAN weights are downloaded automatically from Google Drive on first run.

For **SRCNN** and **HResNet** you need to:

1. Export your trained PyTorch model to ONNX:

   ```python
   import torch

   # SRCNN example
   from srcnn import SRCNN
   model = SRCNN()
   model.load_state_dict(torch.load("srcnn.pth"))
   model.eval()

   dummy = torch.randn(1, 3, 128, 128)
   torch.onnx.export(
       model, dummy, "SRCNN_x4.onnx",
       input_names=["input"], output_names=["output"],
       dynamic_axes={"input": {2: "H", 3: "W"}, "output": {2: "H", 3: "W"}},
   )
   ```
|
| 115 |
+
|
| 116 |
+
2. Upload the `.onnx` file to Google Drive and set **"Anyone with the link can view"**.
|
| 117 |
+
|
| 118 |
+
3. Copy the file ID from the share URL and update `DRIVE_IDS` in `app.py`:
|
| 119 |
+
|
| 120 |
+
```python
|
| 121 |
+
DRIVE_IDS = {
|
| 122 |
+
...
|
| 123 |
+
"srcnn_x4": "YOUR_SRCNN_DRIVE_FILE_ID_HERE",
|
| 124 |
+
"hresnet_x4": "YOUR_HRESNET_DRIVE_FILE_ID_HERE",
|
| 125 |
+
}
|
| 126 |
+
```
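
For reference, the file ID is the long token between `/d/` and `/view` in a share URL (or after `id=` in older-style links). A tiny helper (illustrative only, not part of `app.py`) can pull it out:

```python
import re

def drive_file_id(share_url: str) -> str:
    """Extract the file ID from a Google Drive share URL."""
    m = re.search(r"/d/([\w-]+)|[?&]id=([\w-]+)", share_url)
    if not m:
        raise ValueError(f"No file ID found in {share_url!r}")
    return m.group(1) or m.group(2)

fid = drive_file_id("https://drive.google.com/file/d/15xmXXZNH2wMyeQv4ie5hagT7eWK9MgP6/view?usp=sharing")
print(fid)  # 15xmXXZNH2wMyeQv4ie5hagT7eWK9MgP6
```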

---

## Running Locally

```bash
python app.py
```

Open `http://127.0.0.1:7860` in your browser.

---

## SR3 Integration Guide

SR3 (Super-Resolution via Iterative Refinement) is a **diffusion model** – it cannot be exported to a static ONNX graph, because its inference involves a variable-length denoising loop.
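
To see why a static graph does not fit: sampling is a loop of sequential denoising steps, each consuming the previous estimate. Schematically, a generic DDPM-style ancestral sampler looks like this (a stand-in noise predictor, not the actual SR3 network):

```python
import numpy as np

rng = np.random.default_rng(0)
T = 50                                   # number of denoising steps (variable in practice)
betas = np.linspace(1e-4, 0.02, T)       # noise schedule
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)

def predict_noise(x, lr_cond, t):
    # Stand-in for the SR3 U-Net, which is conditioned on the low-res image
    return np.zeros_like(x)

x = rng.standard_normal((1, 3, 64, 64))  # start from pure noise at the target resolution
lr_cond = np.zeros((1, 3, 16, 16))       # low-res conditioning image
for t in reversed(range(T)):
    eps = predict_noise(x, lr_cond, t)
    # Posterior mean: x_{t-1} = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
    x = (x - betas[t] / np.sqrt(1.0 - alpha_bars[t]) * eps) / np.sqrt(alphas[t])
    if t > 0:
        x += np.sqrt(betas[t]) * rng.standard_normal(x.shape)  # add sampling noise
```

The loop length and the data-dependent control flow are exactly what a fixed ONNX graph cannot express, which is why the steps below run SR3 through PyTorch instead.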

To add SR3:

1. Clone the reference implementation:

   ```bash
   git clone https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement
   ```

2. Place your trained checkpoint at `model/sr3_x4.pth`.

3. Add `torch` and `torchvision` to `requirements.txt`.

4. Write a wrapper in `app.py`:

   ```python
   def run_sr3(input_img: Image.Image) -> Image.Image:
       # load config + model, run the denoising loop, return result
       ...
   ```

5. Add `"sr3_x4"` to the `PANEL_KEYS` list and wire `run_sr3` into `compare_models`.

---

## Contributing

Pull requests welcome. Please open an issue first to discuss significant changes.

---

## License

Apache 2.0 – see `LICENSE`.

---

## Author & Credits

- Real-ESRGAN by [xinntao](https://github.com/xinntao/Real-ESRGAN)
- SRCNN by Dong et al. (2014)
- HResNet / EDSR by Lim et al. (2017)
- SR3 by Saharia et al. (2021) – [paper](https://arxiv.org/abs/2104.07636)
app.py
ADDED
import os
import math
import numpy as np
import onnxruntime as ort
from PIL import Image
import gradio as gr
import tempfile
import requests

# ---------------------------------------------------------------------------
# Directory & model paths
# ---------------------------------------------------------------------------
MODEL_DIR = "model"
os.makedirs(MODEL_DIR, exist_ok=True)

MODEL_PATHS = {
    "esrgan_x2": os.path.join(MODEL_DIR, "Real-ESRGAN_x2plus.onnx"),
    "esrgan_x4": os.path.join(MODEL_DIR, "Real-ESRGAN-x4plus.onnx"),
    "srcnn_x4": os.path.join(MODEL_DIR, "SRCNN_x4.onnx"),
    "hresnet_x4": os.path.join(MODEL_DIR, "HResNet_x4.onnx"),
}

# ---------------------------------------------------------------------------
# Google Drive file IDs
# TODO: Replace the SRCNN / HResNet IDs with your own uploaded ONNX exports.
# Steps to export SRCNN to ONNX:
#   import torch; from srcnn import SRCNN
#   model = SRCNN(); model.load_state_dict(torch.load("srcnn.pth"))
#   dummy = torch.randn(1, 3, 128, 128)
#   torch.onnx.export(model, dummy, "SRCNN_x4.onnx",
#                     input_names=["input"], output_names=["output"],
#                     dynamic_axes={"input": {2: "H", 3: "W"}, "output": {2: "H", 3: "W"}})
# The same pattern applies for HResNet / EDSR.
# ---------------------------------------------------------------------------
DRIVE_IDS = {
    "esrgan_x2": "15xmXXZNH2wMyeQv4ie5hagT7eWK9MgP6",
    "esrgan_x4": "1wDBHad9RCJgJDGsPdapLYl3cr8j-PMJ6",
    "srcnn_x4": "YOUR_SRCNN_DRIVE_FILE_ID_HERE",      # <-- replace
    "hresnet_x4": "YOUR_HRESNET_DRIVE_FILE_ID_HERE",  # <-- replace
}

# Scale factor each model produces
MODEL_SCALES = {
    "esrgan_x2": 2,
    "esrgan_x4": 4,
    "srcnn_x4": 4,
    "hresnet_x4": 4,
}

# Human-readable labels shown in the UI
MODEL_LABELS = {
    "esrgan_x2": "Real-ESRGAN ×2",
    "esrgan_x4": "Real-ESRGAN ×4",
    "srcnn_x4": "SRCNN ×4",
    "hresnet_x4": "HResNet ×4",
}

# ---------------------------------------------------------------------------
# SR3 NOTE
# ---------------------------------------------------------------------------
# SR3 (Super-Resolution via Iterative Refinement, Saharia et al. 2021) is a
# *diffusion model* and cannot be exported to ONNX in the same way as
# feed-forward CNNs. To integrate SR3:
#   1. Clone the official repo: https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement
#   2. Place your checkpoint in model/sr3_x4.pth
#   3. Add a `run_sr3(image: Image) -> Image` function that runs the
#      denoising loop using PyTorch directly (add `torch` to requirements).
#   4. Wire `run_sr3` into `compare_models` for key "sr3_x4".
# SR3 is intentionally omitted from the ONNX pipeline to avoid misleading
# model shape assumptions.
# ---------------------------------------------------------------------------


# ---------------------------------------------------------------------------
# Google Drive downloader
# ---------------------------------------------------------------------------
def download_from_drive(file_id: str, dest_path: str):
    URL = "https://drive.google.com/uc?export=download"
    session = requests.Session()
    response = session.get(URL, params={"id": file_id}, stream=True)
    token = None
    for key, value in response.cookies.items():
        if key.startswith("download_warning"):
            token = value
            break
    if token:
        response = session.get(URL, params={"id": file_id, "confirm": token}, stream=True)
    os.makedirs(os.path.dirname(dest_path), exist_ok=True)
    with open(dest_path, "wb") as f:
        for chunk in response.iter_content(chunk_size=32768):
            if chunk:
                f.write(chunk)
    print(f"Downloaded → {dest_path}")


# ---------------------------------------------------------------------------
# Download models if missing
# ---------------------------------------------------------------------------
for key, path in MODEL_PATHS.items():
    if not os.path.isfile(path):
        file_id = DRIVE_IDS[key]
        if file_id.startswith("YOUR_"):
            print(f"[WARN] Skipping {key}: Google Drive ID not set. "
                  "Update DRIVE_IDS in app.py with your ONNX export.")
        else:
            print(f"Downloading {MODEL_LABELS[key]} …")
            download_from_drive(file_id, path)


# ---------------------------------------------------------------------------
# Load ONNX sessions (only for models that have been downloaded)
# ---------------------------------------------------------------------------
sess_opts = ort.SessionOptions()
sess_opts.intra_op_num_threads = 2
sess_opts.inter_op_num_threads = 2

SESSIONS = {}      # key → (ort.InferenceSession, input metadata)
INPUT_SHAPES = {}  # key → (H_in, W_in)

for key, path in MODEL_PATHS.items():
    if os.path.isfile(path):
        try:
            sess = ort.InferenceSession(
                path,
                sess_options=sess_opts,
                providers=["CPUExecutionProvider"],
            )
            meta = sess.get_inputs()[0]
            shape = tuple(meta.shape)
            # shape is (1, 3, H, W) for fixed-size models, or symbolic/None for dynamic ones
            h = int(shape[2]) if shape[2] is not None and str(shape[2]).isdigit() else 128
            w = int(shape[3]) if shape[3] is not None and str(shape[3]).isdigit() else 128
            SESSIONS[key] = (sess, meta)
            INPUT_SHAPES[key] = (h, w)
            print(f"Loaded {MODEL_LABELS[key]} tile={h}×{w}")
        except Exception as e:
            print(f"[ERROR] Could not load {key}: {e}")
    else:
        print(f"[INFO] {key} not available – will be skipped in comparisons.")


# ---------------------------------------------------------------------------
# Tile-based upscale for a single ONNX model
# ---------------------------------------------------------------------------
def run_onnx_tile(sess, meta, tile_np: np.ndarray) -> np.ndarray:
    """Run one tile through any ONNX session. tile_np is HWC float32 in [0, 1]."""
    patch = np.transpose(tile_np, (2, 0, 1))[None, ...]  # HWC → NCHW
    out = sess.run(None, {meta.name: patch})[0]
    out = np.squeeze(out, axis=0)
    return np.transpose(out, (1, 2, 0))  # back to HWC


def tile_upscale_model(input_img: Image.Image, key: str, max_dim: int = 1024) -> Image.Image:
    """
    Upscale *input_img* using the ONNX model identified by *key*.
    Returns a PIL Image at the upscaled resolution.
    """
    if key not in SESSIONS:
        raise ValueError(f"Model '{key}' is not loaded. Check the Drive ID / path.")

    sess, meta = SESSIONS[key]
    H_in, W_in = INPUT_SHAPES[key]
    scale = MODEL_SCALES[key]

    # Optionally cap input size to avoid OOM on large images
    w, h = input_img.size
    if w > max_dim or h > max_dim:
        factor = max_dim / float(max(w, h))
        input_img = input_img.resize((int(w * factor), int(h * factor)), Image.LANCZOS)

    arr = np.array(input_img.convert("RGB")).astype(np.float32) / 255.0
    h_orig, w_orig, _ = arr.shape

    tiles_h = math.ceil(h_orig / H_in)
    tiles_w = math.ceil(w_orig / W_in)
    pad_h = tiles_h * H_in - h_orig
    pad_w = tiles_w * W_in - w_orig

    arr_padded = np.pad(arr, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect")
    out_arr = np.zeros((tiles_h * H_in * scale, tiles_w * W_in * scale, 3), dtype=np.float32)

    for i in range(tiles_h):
        for j in range(tiles_w):
            y0, x0 = i * H_in, j * W_in
            tile = arr_padded[y0:y0 + H_in, x0:x0 + W_in, :]
            up_tile = run_onnx_tile(sess, meta, tile)
            oy0, ox0 = i * H_in * scale, j * W_in * scale
            out_arr[oy0:oy0 + H_in * scale, ox0:ox0 + W_in * scale, :] = up_tile

    final = np.clip(out_arr[0:h_orig * scale, 0:w_orig * scale, :], 0.0, 1.0)
    return Image.fromarray((final * 255.0).round().astype(np.uint8))


def upscale_8x_from_4x(input_img: Image.Image, key: str) -> Image.Image:
    """Run the ×4 model, then a Lanczos ×2 resize to reach ×8."""
    img_4x = tile_upscale_model(input_img, key)
    # Resize from the ×4 result itself: tile_upscale_model may have capped the
    # input size, so the original input dimensions cannot be trusted here.
    return img_4x.resize((img_4x.width * 2, img_4x.height * 2), Image.LANCZOS)


# ---------------------------------------------------------------------------
# Core comparison function (called by the Gradio button)
# ---------------------------------------------------------------------------
def compare_models(
    input_img: Image.Image,
    use_esrgan_x2: bool,
    use_esrgan_x4: bool,
    use_srcnn: bool,
    use_hresnet: bool,
    include_8x: bool,
):
    if input_img is None:
        return [None] * 8  # 4 preview + 4 download slots

    selection = []
    if use_esrgan_x2: selection.append("esrgan_x2")
    if use_esrgan_x4: selection.append("esrgan_x4")
    if use_srcnn: selection.append("srcnn_x4")
    if use_hresnet: selection.append("hresnet_x4")

    previews = []
    downloads = []

    for key in selection:
        if key not in SESSIONS:
            previews.append(None)
            downloads.append(gr.DownloadButton(label=f"{MODEL_LABELS[key]} – not loaded",
                                               visible=True, value=None))
            continue

        try:
            if include_8x and MODEL_SCALES[key] == 4:
                result = upscale_8x_from_4x(input_img, key)
            else:
                result = tile_upscale_model(input_img, key)

            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
            result.save(tmp.name, format="PNG")
            tmp.close()

            previews.append(result)
            downloads.append(tmp.name)
        except Exception as e:
            print(f"[ERROR] {key}: {e}")
            previews.append(None)
            downloads.append(None)

    # Pad to always return exactly 4 preview + 4 download values
    while len(previews) < 4: previews.append(None)
    while len(downloads) < 4: downloads.append(None)

    return previews + downloads  # 8-element list


# ---------------------------------------------------------------------------
# Gradio UI – side-by-side comparison layout
# ---------------------------------------------------------------------------
css = """
body { font-family: 'Segoe UI', sans-serif; }

.panel-title {
    text-align: center;
    font-weight: 700;
    font-size: 0.85rem;
    letter-spacing: 0.08em;
    text-transform: uppercase;
    margin-bottom: 4px;
    color: #555;
}

#run-btn {
    background: linear-gradient(135deg, #1a1a2e, #16213e) !important;
    color: #e2e2e2 !important;
    font-size: 1rem !important;
    font-weight: 600 !important;
    border-radius: 8px !important;
    padding: 12px 28px !important;
}
#run-btn:hover {
    background: linear-gradient(135deg, #0f3460, #533483) !important;
}

.dl-btn button {
    background: #f0f4ff !important;
    border: 1px solid #c5d0f5 !important;
    color: #333 !important;
    font-size: 0.78rem !important;
    border-radius: 6px !important;
    width: 100%;
}

.model-toggle label { font-size: 0.9rem; }
"""

ALL_KEYS = ["esrgan_x2", "esrgan_x4", "srcnn_x4", "hresnet_x4"]
PANEL_KEYS = ALL_KEYS  # order for the 4 comparison panels

with gr.Blocks(css=css, title="SpectraGAN – Multi-Model Comparison") as demo:

    gr.Markdown("""
    # 🖼️ SpectraGAN – Multi-Model Upscaler Comparison
    Upload an image, select models, and compare results side by side.
    """)

    # ── Input row ──────────────────────────────────────────────────────────
    with gr.Row():
        inp_image = gr.Image(type="pil", label="Source Image", scale=2)

        with gr.Column(scale=1):
            gr.Markdown("### Models to compare")
            chk_esrgan_x2 = gr.Checkbox(label="Real-ESRGAN ×2", value=True, elem_classes="model-toggle")
            chk_esrgan_x4 = gr.Checkbox(label="Real-ESRGAN ×4", value=True, elem_classes="model-toggle")
            chk_srcnn = gr.Checkbox(label="SRCNN ×4", value=True, elem_classes="model-toggle")
            chk_hresnet = gr.Checkbox(label="HResNet ×4", value=True, elem_classes="model-toggle")

            gr.Markdown("### Options")
            chk_8x = gr.Checkbox(label="Also apply ×8 post-resize on ×4 models", value=False)

            run_btn = gr.Button("⚡ Run Comparison", elem_id="run-btn")

    # ── Comparison grid ────────────────────────────────────────────────────
    gr.Markdown("---")
    gr.Markdown("## Results")

    previews = []
    dl_btns = []

    with gr.Row():
        for key in PANEL_KEYS:
            with gr.Column():
                gr.HTML(f'<div class="panel-title">{MODEL_LABELS[key]}</div>')
                img_out = gr.Image(
                    type="pil",
                    label=MODEL_LABELS[key],
                    show_label=False,
                    height=320,
                )
                dl_out = gr.DownloadButton(
                    label="⬇ Download PNG",
                    elem_classes="dl-btn",
                    visible=True,
                )
                previews.append(img_out)
                dl_btns.append(dl_out)

    # ── Wire up ─────────────────────────────────────────────────────────────
    run_btn.click(
        fn=compare_models,
        inputs=[
            inp_image,
            chk_esrgan_x2,
            chk_esrgan_x4,
            chk_srcnn,
            chk_hresnet,
            chk_8x,
        ],
        outputs=previews + dl_btns,  # 8 outputs: 4 images + 4 download buttons
    )

demo.launch(server_name="0.0.0.0", server_port=7860)
requirements.txt
ADDED
onnxruntime    # ONNX inference engine (CPU)
numpy          # Array manipulation
Pillow         # Image I/O
gradio>=4.0    # Web UI (4.x needed for DownloadButton stability)
requests       # Google Drive model downloader

# --- Optional: needed only if you integrate SR3 (PyTorch diffusion model) ---
# torch        # PyTorch inference for SR3
# torchvision  # Required by SR3 repo