ParamAhuja commited on
Commit
3262d11
Β·
1 Parent(s): f376a33
Files changed (3) hide show
  1. README.md +175 -5
  2. app.py +363 -0
  3. requirements.txt +9 -0
README.md CHANGED
@@ -1,13 +1,183 @@
1
  ---
2
  title: SpectraGAN
3
- emoji: πŸ”₯
4
  colorFrom: blue
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 6.9.0
8
  app_file: app.py
9
  pinned: false
10
- license: mit
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: SpectraGAN
3
+ emoji: πŸ–ΌοΈ
4
  colorFrom: blue
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: 5.31.0
8
  app_file: app.py
9
  pinned: false
10
+ license: apache-2.0
11
  ---
12
 
13
+ # πŸ–ΌοΈ SpectraGAN β€” Multi-Model Upscaler Comparison
14
+
15
+ A Gradio web app that lets you upscale an image with **multiple SR models simultaneously** and compare results side by side.
16
+
17
+ Supported models:
18
+
19
+ | Model | Architecture | Scale |
20
+ |-------|-------------|-------|
21
+ | Real-ESRGAN Γ—2 | GAN (residual-in-residual dense block) | Γ—2 |
22
+ | Real-ESRGAN Γ—4 | GAN (residual-in-residual dense block) | Γ—4 |
23
+ | SRCNN Γ—4 | Shallow 3-layer CNN | Γ—4 |
24
+ | HResNet Γ—4 | Deep residual network (EDSR-style) | Γ—4 |
25
+ | SR3 *(stub)* | Diffusion model | Γ—4 β€” see note below |
26
+
27
+ ---
28
+
29
+ ## πŸ“‘ Table of Contents
30
+
31
+ 1. [Features](#features)
32
+ 2. [Project Structure](#project-structure)
33
+ 3. [Prerequisites](#prerequisites)
34
+ 4. [Installation](#installation)
35
+ 5. [Adding Your ONNX Models](#adding-your-onnx-models)
36
+ 6. [Running Locally](#running-locally)
37
+ 7. [SR3 Integration Guide](#sr3-integration-guide)
38
+ 8. [Contributing](#contributing)
39
+ 9. [License](#license)
40
+
41
+ ---
42
+
43
+ ## ✨ Features
44
+
45
+ - **Side-by-side comparison** β€” run up to 4 models at once, results displayed in a 4-panel grid.
46
+ - **Selective execution** β€” toggle any model on/off before running; unchecked models are skipped.
47
+ - **Γ—8 post-resize** β€” optionally apply a bicubic Γ—2 pass on top of any Γ—4 result.
48
+ - **Tile-based inference** β€” large images are split into tiles matching each model's fixed input size, then stitched back together seamlessly.
49
+ - **Per-result download** β€” each panel has its own PNG download button.
50
+ - **Graceful degradation** β€” if a model file is missing (e.g. Drive ID not yet set), that panel is skipped without crashing the others.
51
+
52
+ ---
53
+
54
+ ## πŸ“ Project Structure
55
+
56
+ ```
57
+ spectragan/
58
+ β”œβ”€β”€ model/
59
+ β”‚ β”œβ”€β”€ Real-ESRGAN_x2plus.onnx # auto-downloaded
60
+ β”‚ β”œβ”€β”€ Real-ESRGAN-x4plus.onnx # auto-downloaded
61
+ β”‚ β”œβ”€β”€ SRCNN_x4.onnx # you provide β€” see below
62
+ β”‚ └── HResNet_x4.onnx # you provide β€” see below
63
+ β”œβ”€β”€ app.py
64
+ β”œβ”€β”€ requirements.txt
65
+ └── README.md
66
+ ```
67
+
68
+ ---
69
+
70
+ ## βš™οΈ Prerequisites
71
+
72
+ - Python 3.10+
73
+ - `git`
74
+ - A terminal / command prompt
75
+
76
+ ---
77
+
78
+ ## πŸ”§ Installation
79
+
80
+ ```bash
81
+ git clone https://github.com/ParamAhuja/SpectraGAN.git
82
+ cd SpectraGAN
83
+ python -m venv .venv
84
+ source .venv/bin/activate # Linux/macOS
85
+ # .venv\Scripts\activate # Windows
86
+ pip install -r requirements.txt
87
+ ```
88
+
89
+ ---
90
+
91
+ ## πŸ—‚οΈ Adding Your ONNX Models
92
+
93
+ The Real-ESRGAN weights are downloaded automatically from Google Drive on first run.
94
+
95
+ For **SRCNN** and **HResNet** you need to:
96
+
97
+ 1. Export your trained PyTorch model to ONNX:
98
+
99
+ ```python
100
+ import torch
101
+
102
+ # SRCNN example
103
+ from srcnn import SRCNN
104
+ model = SRCNN()
105
+ model.load_state_dict(torch.load("srcnn.pth"))
106
+ model.eval()
107
+
108
+ dummy = torch.randn(1, 3, 128, 128)
109
+ torch.onnx.export(
110
+ model, dummy, "SRCNN_x4.onnx",
111
+ input_names=["input"], output_names=["output"],
112
+ dynamic_axes={"input": {2: "H", 3: "W"}, "output": {2: "H", 3: "W"}}
113
+ )
114
+ ```
115
+
116
+ 2. Upload the `.onnx` file to Google Drive and set **"Anyone with the link can view"**.
117
+
118
+ 3. Copy the file ID from the share URL and update `DRIVE_IDS` in `app.py`:
119
+
120
+ ```python
121
+ DRIVE_IDS = {
122
+ ...
123
+ "srcnn_x4": "YOUR_SRCNN_DRIVE_FILE_ID_HERE",
124
+ "hresnet_x4": "YOUR_HRESNET_DRIVE_FILE_ID_HERE",
125
+ }
126
+ ```
127
+
128
+ ---
129
+
130
+ ## πŸš€ Running Locally
131
+
132
+ ```bash
133
+ python app.py
134
+ ```
135
+
136
+ Open `http://127.0.0.1:7860` in your browser.
137
+
138
+ ---
139
+
140
+ ## πŸŒ€ SR3 Integration Guide
141
+
142
+ SR3 (Super-Resolution via Repeated Refinement) is a **diffusion model** β€” it cannot be exported to a static ONNX graph because its inference involves a variable-length denoising loop.
143
+
144
+ To add SR3:
145
+
146
+ 1. Clone the reference implementation:
147
+ ```bash
148
+ git clone https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement
149
+ ```
150
+
151
+ 2. Place your trained checkpoint at `model/sr3_x4.pth`.
152
+
153
+ 3. Add `torch` and `torchvision` to `requirements.txt`.
154
+
155
+ 4. Write a wrapper in `app.py`:
156
+ ```python
157
+ def run_sr3(input_img: Image.Image) -> Image.Image:
158
+ # load config + model, run the denoising loop, return result
159
+ ...
160
+ ```
161
+
162
+ 5. Add `"sr3_x4"` to the `PANEL_KEYS` list and wire `run_sr3` into `compare_models`.
163
+
164
+ ---
165
+
166
+ ## 🀝 Contributing
167
+
168
+ Pull requests welcome. Please open an issue first to discuss significant changes.
169
+
170
+ ---
171
+
172
+ ## πŸ“„ License
173
+
174
+ Apache 2.0 β€” see `LICENSE`.
175
+
176
+ ---
177
+
178
+ ## πŸ‘€ Author & Credits
179
+
180
+ - Real-ESRGAN by [xinntao](https://github.com/xinntao/Real-ESRGAN)
181
+ - SRCNN by Dong et al. (2014)
182
+ - HResNet / EDSR by Lim et al. (2017)
183
+ - SR3 by Ho et al. (2022) β€” [paper](https://arxiv.org/abs/2104.07636)
app.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import uuid
4
+ import numpy as np
5
+ import onnxruntime as ort
6
+ from PIL import Image
7
+ import gradio as gr
8
+ import tempfile
9
+ import requests
10
+
11
+ # ---------------------------------------------------------------------------
12
+ # Directory & model paths
13
+ # ---------------------------------------------------------------------------
14
+ MODEL_DIR = "model"
15
+ os.makedirs(MODEL_DIR, exist_ok=True)
16
+
17
+ MODEL_PATHS = {
18
+ "esrgan_x2": os.path.join(MODEL_DIR, "Real-ESRGAN_x2plus.onnx"),
19
+ "esrgan_x4": os.path.join(MODEL_DIR, "Real-ESRGAN-x4plus.onnx"),
20
+ "srcnn_x4": os.path.join(MODEL_DIR, "SRCNN_x4.onnx"),
21
+ "hresnet_x4": os.path.join(MODEL_DIR, "HResNet_x4.onnx"),
22
+ }
23
+
24
+ # ---------------------------------------------------------------------------
25
+ # Google Drive file IDs
26
+ # TODO: Replace the SRCNN / HResNet IDs with your own uploaded ONNX exports.
27
+ # Steps to export SRCNN to ONNX:
28
+ # import torch; from srcnn import SRCNN
29
+ # model = SRCNN(); model.load_state_dict(torch.load("srcnn.pth"))
30
+ # dummy = torch.randn(1, 3, 128, 128)
31
+ # torch.onnx.export(model, dummy, "SRCNN_x4.onnx",
32
+ # input_names=["input"], output_names=["output"],
33
+ # dynamic_axes={"input":{2:"H",3:"W"},"output":{2:"H",3:"W"}})
34
+ # Same pattern applies for HResNet / EDSR.
35
+ # ---------------------------------------------------------------------------
36
+ DRIVE_IDS = {
37
+ "esrgan_x2": "15xmXXZNH2wMyeQv4ie5hagT7eWK9MgP6",
38
+ "esrgan_x4": "1wDBHad9RCJgJDGsPdapLYl3cr8j-PMJ6",
39
+ "srcnn_x4": "YOUR_SRCNN_DRIVE_FILE_ID_HERE", # <-- replace
40
+ "hresnet_x4": "YOUR_HRESNET_DRIVE_FILE_ID_HERE", # <-- replace
41
+ }
42
+
43
+ # Scale factor each model produces
44
+ MODEL_SCALES = {
45
+ "esrgan_x2": 2,
46
+ "esrgan_x4": 4,
47
+ "srcnn_x4": 4,
48
+ "hresnet_x4": 4,
49
+ }
50
+
51
+ # Human-readable labels shown in the UI
52
+ MODEL_LABELS = {
53
+ "esrgan_x2": "Real-ESRGAN Γ—2",
54
+ "esrgan_x4": "Real-ESRGAN Γ—4",
55
+ "srcnn_x4": "SRCNN Γ—4",
56
+ "hresnet_x4": "HResNet Γ—4",
57
+ }
58
+
59
+ # ---------------------------------------------------------------------------
60
+ # SR3 NOTE
61
+ # ---------------------------------------------------------------------------
62
+ # SR3 (Super-Resolution via Repeated Refinement, Ho et al. 2022) is a
63
+ # *diffusion model* and cannot be exported to ONNX in the same way as
64
+ # feed-forward CNNs. To integrate SR3:
65
+ # 1. Clone the official repo: https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement
66
+ # 2. Place your checkpoint in model/sr3_x4.pth
67
+ # 3. Add a `run_sr3(image: Image) -> Image` function that runs the
68
+ # denoising loop using PyTorch directly (add `torch` to requirements).
69
+ # 4. Wire `run_sr3` into `upscale_one_model` for key "sr3_x4".
70
+ # SR3 is intentionally omitted from the ONNX pipeline to avoid misleading
71
+ # model shape assumptions.
72
+ # ---------------------------------------------------------------------------
73
+
74
+
75
+ # ---------------------------------------------------------------------------
76
+ # Google Drive downloader
77
+ # ---------------------------------------------------------------------------
78
+ def download_from_drive(file_id: str, dest_path: str):
79
+ URL = "https://drive.google.com/uc?export=download"
80
+ session = requests.Session()
81
+ response = session.get(URL, params={"id": file_id}, stream=True)
82
+ token = None
83
+ for key, value in response.cookies.items():
84
+ if key.startswith("download_warning"):
85
+ token = value
86
+ break
87
+ if token:
88
+ response = session.get(URL, params={"id": file_id, "confirm": token}, stream=True)
89
+ os.makedirs(os.path.dirname(dest_path), exist_ok=True)
90
+ with open(dest_path, "wb") as f:
91
+ for chunk in response.iter_content(chunk_size=32768):
92
+ if chunk:
93
+ f.write(chunk)
94
+ print(f"Downloaded β†’ {dest_path}")
95
+
96
+
97
+ # ---------------------------------------------------------------------------
98
+ # Download models if missing
99
+ # ---------------------------------------------------------------------------
100
+ for key, path in MODEL_PATHS.items():
101
+ if not os.path.isfile(path):
102
+ file_id = DRIVE_IDS[key]
103
+ if file_id.startswith("YOUR_"):
104
+ print(f"[WARN] Skipping {key}: Google Drive ID not set. "
105
+ "Update DRIVE_IDS in app.py with your ONNX export.")
106
+ else:
107
+ print(f"Downloading {MODEL_LABELS[key]} …")
108
+ download_from_drive(file_id, path)
109
+
110
+
111
+ # ---------------------------------------------------------------------------
112
+ # Load ONNX sessions (only for models that have been downloaded)
113
+ # ---------------------------------------------------------------------------
114
+ sess_opts = ort.SessionOptions()
115
+ sess_opts.intra_op_num_threads = 2
116
+ sess_opts.inter_op_num_threads = 2
117
+
118
+ SESSIONS = {} # key β†’ ort.InferenceSession
119
+ INPUT_SHAPES = {} # key β†’ (H_in, W_in)
120
+
121
+ for key, path in MODEL_PATHS.items():
122
+ if os.path.isfile(path):
123
+ try:
124
+ sess = ort.InferenceSession(
125
+ path,
126
+ sess_options=sess_opts,
127
+ providers=["CPUExecutionProvider"]
128
+ )
129
+ meta = sess.get_inputs()[0]
130
+ shape = tuple(meta.shape)
131
+ # shape is (1, 3, H, W) for fixed-size models, or (1,3,None,None) for dynamic
132
+ h = int(shape[2]) if shape[2] is not None and str(shape[2]).isdigit() else 128
133
+ w = int(shape[3]) if shape[3] is not None and str(shape[3]).isdigit() else 128
134
+ SESSIONS[key] = (sess, meta)
135
+ INPUT_SHAPES[key] = (h, w)
136
+ print(f"Loaded {MODEL_LABELS[key]} tile={h}Γ—{w}")
137
+ except Exception as e:
138
+ print(f"[ERROR] Could not load {key}: {e}")
139
+ else:
140
+ print(f"[INFO] {key} not available β€” will be skipped in comparisons.")
141
+
142
+
143
+ # ---------------------------------------------------------------------------
144
+ # Tile-based upscale for a single ONNX model
145
+ # ---------------------------------------------------------------------------
146
+ def run_onnx_tile(sess, meta, tile_np: np.ndarray) -> np.ndarray:
147
+ """Run one tile through any ONNX session. tile_np is HWC float32 [0,1]."""
148
+ patch = np.transpose(tile_np, (2, 0, 1))[None, ...] # NCHW
149
+ out = sess.run(None, {meta.name: patch})[0]
150
+ out = np.squeeze(out, axis=0)
151
+ return np.transpose(out, (1, 2, 0)) # back to HWC
152
+
153
+
154
+ def tile_upscale_model(input_img: Image.Image, key: str, max_dim: int = 1024) -> Image.Image:
155
+ """
156
+ Upscale *input_img* using the ONNX model identified by *key*.
157
+ Returns a PIL Image at the upscaled resolution.
158
+ """
159
+ if key not in SESSIONS:
160
+ raise ValueError(f"Model '{key}' is not loaded. Check the Drive ID / path.")
161
+
162
+ sess, meta = SESSIONS[key]
163
+ H_in, W_in = INPUT_SHAPES[key]
164
+ scale = MODEL_SCALES[key]
165
+
166
+ # Optionally cap input size to avoid OOM on large images
167
+ w, h = input_img.size
168
+ if w > max_dim or h > max_dim:
169
+ factor = max_dim / float(max(w, h))
170
+ input_img = input_img.resize((int(w * factor), int(h * factor)), Image.LANCZOS)
171
+
172
+ arr = np.array(input_img.convert("RGB")).astype(np.float32) / 255.0
173
+ h_orig, w_orig, _ = arr.shape
174
+
175
+ tiles_h = math.ceil(h_orig / H_in)
176
+ tiles_w = math.ceil(w_orig / W_in)
177
+ pad_h = tiles_h * H_in - h_orig
178
+ pad_w = tiles_w * W_in - w_orig
179
+
180
+ arr_padded = np.pad(arr, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect")
181
+ out_arr = np.zeros((tiles_h * H_in * scale, tiles_w * W_in * scale, 3), dtype=np.float32)
182
+
183
+ for i in range(tiles_h):
184
+ for j in range(tiles_w):
185
+ y0, x0 = i * H_in, j * W_in
186
+ tile = arr_padded[y0:y0 + H_in, x0:x0 + W_in, :]
187
+ up_tile = run_onnx_tile(sess, meta, tile)
188
+ oy0, ox0 = i * H_in * scale, j * W_in * scale
189
+ out_arr[oy0:oy0 + H_in * scale, ox0:ox0 + W_in * scale, :] = up_tile
190
+
191
+ final = np.clip(out_arr[0:h_orig * scale, 0:w_orig * scale, :], 0.0, 1.0)
192
+ return Image.fromarray((final * 255.0).round().astype(np.uint8))
193
+
194
+
195
+ def upscale_8x_from_4x(input_img: Image.Image, key: str) -> Image.Image:
196
+ """Run Γ—4 model, then bicubic Γ—2 to reach Γ—8."""
197
+ img_4x = tile_upscale_model(input_img, key)
198
+ w, h = input_img.size
199
+ return img_4x.resize((w * 8, h * 8), Image.LANCZOS)
200
+
201
+
202
+ # ---------------------------------------------------------------------------
203
+ # Core comparison function (called by the Gradio button)
204
+ # ---------------------------------------------------------------------------
205
+ def compare_models(
206
+ input_img: Image.Image,
207
+ use_esrgan_x2: bool,
208
+ use_esrgan_x4: bool,
209
+ use_srcnn: bool,
210
+ use_hresnet: bool,
211
+ include_8x: bool,
212
+ ):
213
+ if input_img is None:
214
+ return [None] * 8 # 4 preview + 4 download slots
215
+
216
+ selection = []
217
+ if use_esrgan_x2: selection.append("esrgan_x2")
218
+ if use_esrgan_x4: selection.append("esrgan_x4")
219
+ if use_srcnn: selection.append("srcnn_x4")
220
+ if use_hresnet: selection.append("hresnet_x4")
221
+
222
+ previews = []
223
+ downloads = []
224
+
225
+ for key in selection:
226
+ if key not in SESSIONS:
227
+ previews.append(None)
228
+ downloads.append(gr.DownloadButton(label=f"{MODEL_LABELS[key]} – not loaded",
229
+ visible=True, value=None))
230
+ continue
231
+
232
+ try:
233
+ if include_8x and MODEL_SCALES[key] == 4:
234
+ result = upscale_8x_from_4x(input_img, key)
235
+ suffix = "Γ—8"
236
+ else:
237
+ result = tile_upscale_model(input_img, key)
238
+ suffix = f"Γ—{MODEL_SCALES[key]}"
239
+
240
+ tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
241
+ result.save(tmp.name, format="PNG")
242
+ tmp.close()
243
+
244
+ previews.append(result)
245
+ downloads.append(tmp.name)
246
+ except Exception as e:
247
+ print(f"[ERROR] {key}: {e}")
248
+ previews.append(None)
249
+ downloads.append(None)
250
+
251
+ # Pad to always return exactly 4 preview + 4 download values
252
+ while len(previews) < 4: previews.append(None)
253
+ while len(downloads) < 4: downloads.append(None)
254
+
255
+ return previews + downloads # 8-element list
256
+
257
+
258
+ # ---------------------------------------------------------------------------
259
+ # Gradio UI – side-by-side comparison layout
260
+ # ---------------------------------------------------------------------------
261
+ css = """
262
+ body { font-family: 'Segoe UI', sans-serif; }
263
+
264
+ .panel-title {
265
+ text-align: center;
266
+ font-weight: 700;
267
+ font-size: 0.85rem;
268
+ letter-spacing: 0.08em;
269
+ text-transform: uppercase;
270
+ margin-bottom: 4px;
271
+ color: #555;
272
+ }
273
+
274
+ #run-btn {
275
+ background: linear-gradient(135deg, #1a1a2e, #16213e) !important;
276
+ color: #e2e2e2 !important;
277
+ font-size: 1rem !important;
278
+ font-weight: 600 !important;
279
+ border-radius: 8px !important;
280
+ padding: 12px 28px !important;
281
+ }
282
+ #run-btn:hover {
283
+ background: linear-gradient(135deg, #0f3460, #533483) !important;
284
+ }
285
+
286
+ .dl-btn button {
287
+ background: #f0f4ff !important;
288
+ border: 1px solid #c5d0f5 !important;
289
+ color: #333 !important;
290
+ font-size: 0.78rem !important;
291
+ border-radius: 6px !important;
292
+ width: 100%;
293
+ }
294
+
295
+ .model-toggle label { font-size: 0.9rem; }
296
+ """
297
+
298
+ ALL_KEYS = ["esrgan_x2", "esrgan_x4", "srcnn_x4", "hresnet_x4"]
299
+ PANEL_KEYS = ALL_KEYS # order for the 4 comparison panels
300
+
301
+ with gr.Blocks(css=css, title="SpectraGAN β€” Multi-Model Comparison") as demo:
302
+
303
+ gr.Markdown("""
304
+ # πŸ–ΌοΈ SpectraGAN β€” Multi-Model Upscaler Comparison
305
+ Upload an image, select models, and compare results side by side.
306
+ """)
307
+
308
+ # ── Input row ──────────────────────────────────────────────────────────
309
+ with gr.Row():
310
+ inp_image = gr.Image(type="pil", label="Source Image", scale=2)
311
+
312
+ with gr.Column(scale=1):
313
+ gr.Markdown("### Models to compare")
314
+ chk_esrgan_x2 = gr.Checkbox(label="Real-ESRGAN Γ—2", value=True, elem_classes="model-toggle")
315
+ chk_esrgan_x4 = gr.Checkbox(label="Real-ESRGAN Γ—4", value=True, elem_classes="model-toggle")
316
+ chk_srcnn = gr.Checkbox(label="SRCNN Γ—4", value=True, elem_classes="model-toggle")
317
+ chk_hresnet = gr.Checkbox(label="HResNet Γ—4", value=True, elem_classes="model-toggle")
318
+
319
+ gr.Markdown("### Options")
320
+ chk_8x = gr.Checkbox(label="Also apply Γ—8 post-resize on Γ—4 models", value=False)
321
+
322
+ run_btn = gr.Button("⚑ Run Comparison", elem_id="run-btn")
323
+
324
+ # ── Comparison grid ────────────────────────────────────────────────────
325
+ gr.Markdown("---")
326
+ gr.Markdown("## Results")
327
+
328
+ previews = []
329
+ dl_btns = []
330
+
331
+ with gr.Row():
332
+ for key in PANEL_KEYS:
333
+ with gr.Column():
334
+ gr.HTML(f'<div class="panel-title">{MODEL_LABELS[key]}</div>')
335
+ img_out = gr.Image(
336
+ type="pil",
337
+ label=MODEL_LABELS[key],
338
+ show_label=False,
339
+ height=320,
340
+ )
341
+ dl_out = gr.DownloadButton(
342
+ label="⬇ Download PNG",
343
+ elem_classes="dl-btn",
344
+ visible=True,
345
+ )
346
+ previews.append(img_out)
347
+ dl_btns.append(dl_out)
348
+
349
+ # ── Wire up ─────────────────────────────────────────────────────────────
350
+ run_btn.click(
351
+ fn=compare_models,
352
+ inputs=[
353
+ inp_image,
354
+ chk_esrgan_x2,
355
+ chk_esrgan_x4,
356
+ chk_srcnn,
357
+ chk_hresnet,
358
+ chk_8x,
359
+ ],
360
+ outputs=previews + dl_btns, # 8 outputs: 4 images + 4 download buttons
361
+ )
362
+
363
+ demo.launch(server_name="0.0.0.0", server_port=7860)
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ onnxruntime # ONNX inference engine (CPU)
2
+ numpy # Array manipulation
3
+ Pillow # Image I/O
4
+ gradio>=4.0 # Web UI (4.x needed for DownloadButton stability)
5
+ requests # Google Drive model downloader
6
+
7
+ # --- Optional: needed only if you integrate SR3 (PyTorch diffusion model) ---
8
+ # torch # PyTorch inference for SR3
9
+ # torchvision # Required by SR3 repo