ParamDev committed on
Commit
8ab66f8
·
verified ·
1 Parent(s): b0549dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +263 -376
app.py CHANGED
@@ -5,309 +5,156 @@ import onnxruntime as ort
5
  from PIL import Image
6
  import gradio as gr
7
  import tempfile
8
- from huggingface_hub import hf_hub_download, HfApi
 
 
 
9
 
10
  # ---------------------------------------------------------------------------
11
- # CONFIG β€” only these two lines ever need changing
12
  # ---------------------------------------------------------------------------
13
- HF_REPO_ID = "ParamDev/spectraGAN" # your HF model repo
14
- HF_TOKEN = os.environ.get("HF_TOKEN") # set this in Space β†’ Settings β†’ Secrets
15
 
16
- # ---------------------------------------------------------------------------
17
- # Filenames used in the HF repo
18
- # ---------------------------------------------------------------------------
19
- HF_FILENAMES = {
20
- "esrgan_x2": "Real-ESRGAN_x2plus.onnx",
21
- "esrgan_x4": "Real-ESRGAN-x4plus.onnx",
22
- "srcnn_x4": "SRCNN_x4.onnx",
23
- "hresnet_x4": "HResNet_x4.onnx",
24
  }
25
 
26
- MODEL_SCALES = {
27
- "esrgan_x2": 2,
28
- "esrgan_x4": 4,
29
- "srcnn_x4": 4,
30
- "hresnet_x4": 4,
31
- }
32
 
33
  MODEL_LABELS = {
34
- "esrgan_x2": "Real-ESRGAN ×2",
35
  "esrgan_x4": "Real-ESRGAN ×4",
36
  "srcnn_x4": "SRCNN ×4",
37
- "hresnet_x4": "HResNet ×4",
38
  }
39
 
40
- # Local scratch directory (writable in HF Spaces)
41
- CACHE_DIR = "/tmp/spectragan_models"
42
- os.makedirs(CACHE_DIR, exist_ok=True)
43
-
44
-
45
- # ===========================================================================
46
- # STEP 1 β€” ONNX export for SRCNN and HResNet
47
- # Only runs if the file is missing from both local cache and the HF repo.
48
- # Uses torch which is available in HF Spaces by default.
49
- # ===========================================================================
50
-
51
- def _export_srcnn_onnx(out_path: str):
52
- import torch
53
- import torch.nn as nn
54
-
55
- class SRCNN(nn.Module):
56
- def __init__(self):
57
- super().__init__()
58
- self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
59
- self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=2)
60
- self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)
61
- self.relu = nn.ReLU(inplace=True)
62
-
63
- def forward(self, x):
64
- return self.conv3(self.relu(self.conv2(self.relu(self.conv1(x)))))
65
-
66
- class SRCNNx4(nn.Module):
67
- """Bicubic Γ—4 upsample β†’ SRCNN refinement."""
68
- def __init__(self):
69
- super().__init__()
70
- self.srcnn = SRCNN()
71
-
72
- def forward(self, x):
73
- up = torch.nn.functional.interpolate(
74
- x, scale_factor=4, mode="bicubic", align_corners=False
75
- )
76
- return self.srcnn(up)
77
-
78
- model = SRCNNx4()
79
-
80
- # Try to load pretrained weights; fall back to random init if unavailable
81
- weights_url = (
82
- "https://github.com/yjn870/SRCNN-pytorch/raw/master/models/srcnn_x4.pth"
83
- )
84
- weights_path = os.path.join(CACHE_DIR, "srcnn_x4.pth")
85
- if not os.path.exists(weights_path):
86
- try:
87
- import urllib.request
88
- print(" Downloading SRCNN pretrained weights …")
89
- urllib.request.urlretrieve(weights_url, weights_path)
90
- except Exception as e:
91
- print(f" [WARN] Could not download SRCNN weights ({e}). Using random init.")
92
- weights_path = None
93
-
94
- if weights_path and os.path.exists(weights_path):
95
- state = torch.load(weights_path, map_location="cpu")
96
- try:
97
- model.srcnn.load_state_dict(state, strict=True)
98
- print(" SRCNN pretrained weights loaded βœ“")
99
- except RuntimeError:
100
- model.srcnn.load_state_dict(state, strict=False)
101
- print(" SRCNN weights loaded (partial) βœ“")
102
-
103
- model.eval()
104
- dummy = torch.zeros(1, 3, 128, 128)
105
- torch.onnx.export(
106
- model, dummy, out_path,
107
- opset_version=17,
108
- input_names=["input"], output_names=["output"],
109
- dynamic_axes={"input": {2: "H", 3: "W"}, "output": {2: "H", 3: "W"}},
110
- )
111
- print(f" Exported SRCNN β†’ {out_path}")
112
-
113
-
114
- def _export_hresnet_onnx(out_path: str):
115
- import torch
116
- import torch.nn as nn
117
-
118
- class ResBlock(nn.Module):
119
- def __init__(self, n_feats):
120
- super().__init__()
121
- self.body = nn.Sequential(
122
- nn.Conv2d(n_feats, n_feats, 3, padding=1),
123
- nn.ReLU(inplace=True),
124
- nn.Conv2d(n_feats, n_feats, 3, padding=1),
125
- )
126
- def forward(self, x):
127
- return x + self.body(x)
128
-
129
- class EDSR(nn.Module):
130
- """EDSR-baseline: 16 residual blocks, 64 channels, Γ—4."""
131
- def __init__(self):
132
- super().__init__()
133
- n_feats = 64
134
- self.head = nn.Conv2d(3, n_feats, 3, padding=1)
135
- self.body = nn.Sequential(*[ResBlock(n_feats) for _ in range(16)])
136
- self.body_tail = nn.Conv2d(n_feats, n_feats, 3, padding=1)
137
- # Γ—4 via two Γ—2 pixel-shuffle stages
138
- self.tail = nn.Sequential(
139
- nn.Conv2d(n_feats, n_feats * 4, 3, padding=1),
140
- nn.PixelShuffle(2),
141
- nn.Conv2d(n_feats, n_feats * 4, 3, padding=1),
142
- nn.PixelShuffle(2),
143
- nn.Conv2d(n_feats, 3, 3, padding=1),
144
- )
145
-
146
- def forward(self, x):
147
- x = self.head(x)
148
- res = self.body_tail(self.body(x))
149
- return self.tail(x + res)
150
-
151
- model = EDSR()
152
-
153
- # Try to load pretrained weights from HuggingFace
154
- weights_url = (
155
- "https://huggingface.co/eugenesiow/edsr-base/resolve/main/pytorch_model_4x.pt"
156
- )
157
- weights_path = os.path.join(CACHE_DIR, "edsr_x4.pt")
158
- if not os.path.exists(weights_path):
159
- try:
160
- import urllib.request
161
- print(" Downloading EDSR pretrained weights from HuggingFace …")
162
- urllib.request.urlretrieve(weights_url, weights_path)
163
- except Exception as e:
164
- print(f" [WARN] Could not download EDSR weights ({e}). Using random init.")
165
- weights_path = None
166
-
167
- if weights_path and os.path.exists(weights_path):
168
- state = torch.load(weights_path, map_location="cpu")
169
- if "model" in state: state = state["model"]
170
- if "state_dict" in state: state = state["state_dict"]
171
- state = {k.replace("module.", ""): v for k, v in state.items()}
172
- try:
173
- model.load_state_dict(state, strict=True)
174
- print(" EDSR pretrained weights loaded βœ“")
175
- except RuntimeError:
176
- model.load_state_dict(state, strict=False)
177
- print(" EDSR weights loaded (partial) βœ“")
178
-
179
- model.eval()
180
- dummy = torch.zeros(1, 3, 128, 128)
181
- torch.onnx.export(
182
- model, dummy, out_path,
183
- opset_version=17,
184
- input_names=["input"], output_names=["output"],
185
- dynamic_axes={"input": {2: "H", 3: "W"}, "output": {2: "H", 3: "W"}},
186
- )
187
- print(f" Exported HResNet β†’ {out_path}")
188
-
189
-
190
- # Map model keys that need local export to their export functions
191
- EXPORTERS = {
192
- "srcnn_x4": _export_srcnn_onnx,
193
- "hresnet_x4": _export_hresnet_onnx,
194
  }
195
 
196
 
197
- def _upload_to_hub(local_path: str, filename: str):
198
- """Upload a file to the HF repo so future Space restarts skip the export."""
199
- if not HF_TOKEN:
200
- print(f" [INFO] HF_TOKEN not set — skipping upload of {filename}.")
201
- print(f" Set HF_TOKEN in Space Secrets to persist generated models.")
202
- return
203
- try:
204
- api = HfApi()
205
- api.upload_file(
206
- path_or_fileobj=local_path,
207
- path_in_repo=filename,
208
- repo_id=HF_REPO_ID,
209
- repo_type="model",
210
- token=HF_TOKEN,
211
- )
212
- print(f" Uploaded {filename} → {HF_REPO_ID} ✓")
213
- except Exception as e:
214
- print(f" [WARN] Upload failed for {filename}: {e}")
215
-
216
-
217
  # ===========================================================================
218
- # STEP 2 β€” Fetch or generate all models
219
- # Priority: HF Hub cache β†’ local export β†’ upload back to Hub
220
  # ===========================================================================
 
 
 
 
 
 
 
221
 
222
- MODEL_PATHS = {} # key β†’ local file path ready for onnxruntime
223
-
224
- for key, filename in HF_FILENAMES.items():
225
- local_path = os.path.join(CACHE_DIR, filename)
226
-
227
- # Already cached locally from a previous run this session
228
- if os.path.exists(local_path):
229
- print(f"[CACHE] {MODEL_LABELS[key]} already in /tmp cache.")
230
- MODEL_PATHS[key] = local_path
231
- continue
232
-
233
- # Try downloading from HF Hub first
234
- try:
235
- print(f"Fetching {MODEL_LABELS[key]} from HF Hub …")
236
- dl_path = hf_hub_download(
237
- repo_id=HF_REPO_ID,
238
- filename=filename,
239
- repo_type="model",
240
- token=HF_TOKEN,
241
- local_dir=CACHE_DIR,
242
- )
243
- MODEL_PATHS[key] = dl_path
244
- print(f" βœ“ {dl_path}")
245
- continue
246
- except Exception as e:
247
- print(f" Not found on Hub ({e})")
248
-
249
- # Not on Hub β€” export it locally if we have an exporter for it
250
- if key in EXPORTERS:
251
- print(f"Generating {MODEL_LABELS[key]} via local export …")
252
- try:
253
- EXPORTERS[key](local_path)
254
- MODEL_PATHS[key] = local_path
255
- # Upload back to Hub so next restart skips this step
256
- print(f" Uploading to Hub for future restarts …")
257
- _upload_to_hub(local_path, filename)
258
- except Exception as e:
259
- print(f" [ERROR] Export failed for {key}: {e}")
260
- else:
261
- print(f" [WARN] {key} not on Hub and no exporter available β€” skipping.")
262
 
263
 
264
  # ===========================================================================
265
- # STEP 3 β€” Load ONNX sessions
266
  # ===========================================================================
267
 
268
  sess_opts = ort.SessionOptions()
269
  sess_opts.intra_op_num_threads = 2
270
  sess_opts.inter_op_num_threads = 2
271
 
272
- SESSIONS = {}
273
- INPUT_SHAPES = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
 
275
- for key, path in MODEL_PATHS.items():
 
 
 
 
 
 
 
 
 
 
 
 
276
  try:
277
- sess = ort.InferenceSession(
278
- path,
279
- sess_options=sess_opts,
280
- providers=["CPUExecutionProvider"],
281
- )
282
- meta = sess.get_inputs()[0]
283
- shape = tuple(meta.shape)
284
- h = int(shape[2]) if isinstance(shape[2], int) else 128
285
- w = int(shape[3]) if isinstance(shape[3], int) else 128
286
- SESSIONS[key] = (sess, meta)
287
- INPUT_SHAPES[key] = (h, w)
288
- print(f"Loaded {MODEL_LABELS[key]} tile={h}Γ—{w}")
 
289
  except Exception as e:
290
- print(f"[ERROR] Could not load {key}: {e}")
 
 
291
 
292
 
293
  # ===========================================================================
294
- # STEP 4 β€” Inference helpers
295
  # ===========================================================================
296
 
297
- def run_onnx_tile(sess, meta, tile_np: np.ndarray) -> np.ndarray:
298
- patch = np.transpose(tile_np, (2, 0, 1))[None, ...]
 
299
  out = sess.run(None, {meta.name: patch})[0]
300
- return np.transpose(np.squeeze(out, 0), (1, 2, 0))
301
-
302
-
303
- def tile_upscale_model(input_img: Image.Image, key: str, max_dim: int = 1024) -> Image.Image:
304
- if key not in SESSIONS:
305
- raise ValueError(f"Model '{key}' is not loaded.")
306
-
307
- sess, meta = SESSIONS[key]
308
- H_in, W_in = INPUT_SHAPES[key]
309
- scale = MODEL_SCALES[key]
310
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  w, h = input_img.size
312
  if w > max_dim or h > max_dim:
313
  factor = max_dim / float(max(w, h))
@@ -315,149 +162,189 @@ def tile_upscale_model(input_img: Image.Image, key: str, max_dim: int = 1024) ->
315
 
316
  arr = np.array(input_img.convert("RGB")).astype(np.float32) / 255.0
317
  h_orig, w_orig, _ = arr.shape
318
- tiles_h = math.ceil(h_orig / H_in)
319
- tiles_w = math.ceil(w_orig / W_in)
320
 
321
- arr_padded = np.pad(
322
  arr,
323
- ((0, tiles_h * H_in - h_orig), (0, tiles_w * W_in - w_orig), (0, 0)),
324
  mode="reflect",
325
  )
326
- out_arr = np.zeros(
327
- (tiles_h * H_in * scale, tiles_w * W_in * scale, 3), dtype=np.float32
328
- )
329
  for i in range(tiles_h):
330
  for j in range(tiles_w):
331
- y0, x0 = i * H_in, j * W_in
332
- up_tile = run_onnx_tile(sess, meta, arr_padded[y0:y0+H_in, x0:x0+W_in])
333
- oy0, ox0 = i * H_in * scale, j * W_in * scale
334
- out_arr[oy0:oy0+H_in*scale, ox0:ox0+W_in*scale] = up_tile
335
 
336
- final = np.clip(out_arr[:h_orig*scale, :w_orig*scale], 0.0, 1.0)
 
 
 
 
 
 
 
 
 
337
  return Image.fromarray((final * 255.0).round().astype(np.uint8))
338
 
339
 
340
- def upscale_8x(input_img: Image.Image, key: str) -> Image.Image:
341
- img_4x = tile_upscale_model(input_img, key)
342
- w, h = input_img.size
343
- return img_4x.resize((w * 8, h * 8), Image.LANCZOS)
 
 
 
 
 
 
 
 
 
 
 
 
 
344
 
345
 
346
  # ===========================================================================
347
- # STEP 5 β€” Gradio comparison callback
348
  # ===========================================================================
349
 
350
- def compare_models(
351
- input_img: Image.Image,
352
- use_esrgan_x2: bool,
353
- use_esrgan_x4: bool,
354
- use_srcnn: bool,
355
- use_hresnet: bool,
356
- include_8x: bool,
357
- ):
358
  if input_img is None:
359
- return [None] * 8
360
-
361
- wanted = {
362
- "esrgan_x2": use_esrgan_x2,
363
- "esrgan_x4": use_esrgan_x4,
364
- "srcnn_x4": use_srcnn,
365
- "hresnet_x4": use_hresnet,
366
- }
367
- selection = [k for k, v in wanted.items() if v]
368
- previews = []
369
- downloads = []
370
-
371
- for key in selection:
372
- if key not in SESSIONS:
373
- previews.append(None)
374
- downloads.append(None)
375
- continue
376
- try:
377
- result = (
378
- upscale_8x(input_img, key)
379
- if include_8x and MODEL_SCALES[key] == 4
380
- else tile_upscale_model(input_img, key)
381
- )
382
- tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
383
- result.save(tmp.name, format="PNG")
384
- tmp.close()
385
- previews.append(result)
386
- downloads.append(tmp.name)
387
- except Exception as e:
388
- print(f"[ERROR] {key}: {e}")
389
- previews.append(None)
390
- downloads.append(None)
391
 
392
- while len(previews) < 4: previews.append(None)
393
- while len(downloads) < 4: downloads.append(None)
394
- return previews + downloads
 
 
 
395
 
396
 
397
  # ===========================================================================
398
- # STEP 6 β€” Gradio UI
399
  # ===========================================================================
400
 
401
  css = """
402
- body { font-family: 'Segoe UI', sans-serif; }
403
- .panel-title {
404
- text-align: center; font-weight: 700; font-size: 0.85rem;
405
- letter-spacing: 0.08em; text-transform: uppercase;
406
- margin-bottom: 4px; color: #555;
 
 
407
  }
 
 
 
 
 
 
 
 
408
  #run-btn {
409
- background: linear-gradient(135deg, #1a1a2e, #16213e) !important;
410
- color: #e2e2e2 !important; font-size: 1rem !important;
411
- font-weight: 600 !important; border-radius: 8px !important;
412
- padding: 12px 28px !important;
 
 
 
 
 
413
  }
414
- #run-btn:hover { background: linear-gradient(135deg, #0f3460, #533483) !important; }
415
- .dl-btn button {
416
- background: #f0f4ff !important; border: 1px solid #c5d0f5 !important;
417
- color: #333 !important; font-size: 0.78rem !important;
418
- border-radius: 6px !important; width: 100%;
 
 
 
 
 
 
 
 
 
 
 
 
 
419
  }
420
- .model-toggle label { font-size: 0.9rem; }
421
  """
422
 
423
- PANEL_KEYS = ["esrgan_x2", "esrgan_x4", "srcnn_x4", "hresnet_x4"]
 
 
 
 
 
 
424
 
425
- with gr.Blocks(css=css, title="SpectraGAN β€” Multi-Model Comparison") as demo:
426
 
427
- gr.Markdown("""
428
- # πŸ–ΌοΈ SpectraGAN β€” Multi-Model Upscaler Comparison
429
- Upload an image, select models, and compare results side by side.
 
 
430
  """)
431
 
432
- with gr.Row():
433
- inp_image = gr.Image(type="pil", label="Source Image", scale=2)
434
- with gr.Column(scale=1):
435
- gr.Markdown("### Models to compare")
436
- chk_esrgan_x2 = gr.Checkbox(label="Real-ESRGAN Γ—2", value=True, elem_classes="model-toggle")
437
- chk_esrgan_x4 = gr.Checkbox(label="Real-ESRGAN Γ—4", value=True, elem_classes="model-toggle")
438
- chk_srcnn = gr.Checkbox(label="SRCNN Γ—4", value=True, elem_classes="model-toggle")
439
- chk_hresnet = gr.Checkbox(label="HResNet Γ—4", value=True, elem_classes="model-toggle")
440
- gr.Markdown("### Options")
441
- chk_8x = gr.Checkbox(label="Also apply Γ—8 post-resize on Γ—4 models", value=False)
442
- run_btn = gr.Button("⚑ Run Comparison", elem_id="run-btn")
443
-
444
- gr.Markdown("---\n## Results")
445
- previews = []
446
- dl_btns = []
447
-
448
- with gr.Row():
449
- for key in PANEL_KEYS:
450
- with gr.Column():
451
- gr.HTML(f'<div class="panel-title">{MODEL_LABELS[key]}</div>')
452
- previews.append(gr.Image(type="pil", show_label=False, height=320))
453
- dl_btns.append(gr.DownloadButton(
454
- label="⬇ Download PNG", elem_classes="dl-btn", visible=True
455
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
456
 
457
  run_btn.click(
458
- fn=compare_models,
459
- inputs=[inp_image, chk_esrgan_x2, chk_esrgan_x4, chk_srcnn, chk_hresnet, chk_8x],
460
- outputs=previews + dl_btns,
461
  )
462
 
463
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
5
  from PIL import Image
6
  import gradio as gr
7
  import tempfile
8
+ import gdown
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
 
13
  # ---------------------------------------------------------------------------
14
+ # Paths & constants
15
  # ---------------------------------------------------------------------------
16
+ CACHE_DIR = "/tmp/spectragan"
17
+ os.makedirs(CACHE_DIR, exist_ok=True)
18
 
19
+ # Google Drive IDs for ESRGAN ONNX files
20
+ DRIVE_IDS = {
21
+ "esrgan_x4": "1wDBHad9RCJgJDGsPdapLYl3cr8j-PMJ6",
22
+ "hresnet_x4": "15xmXXZNH2wMyeQv4ie5hagT7eWK9MgP6", # placeholder = ESRGAN x2
 
 
 
 
23
  }
24
 
25
+ # SRCNN .pth lives in the same HF Space repo (uploaded by user)
26
+ SRCNN_PTH = os.path.join(os.path.dirname(__file__), "srcnn_x4.pth")
 
 
 
 
27
 
28
  MODEL_LABELS = {
 
29
  "esrgan_x4": "Real-ESRGAN ×4",
30
  "srcnn_x4": "SRCNN ×4",
31
+ "hresnet_x4": "HResNet ×4", # placeholder until real weights available
32
  }
33
 
34
+ MODEL_SCALES = {
35
+ "esrgan_x4": 4,
36
+ "srcnn_x4": 4,
37
+ "hresnet_x4": 2, # underlying model is ESRGAN x2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  }
39
 
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  # ===========================================================================
42
+ # SRCNN architecture (3-layer, matches the .pth weights)
 
43
  # ===========================================================================
44
+ class SRCNN(nn.Module):
45
+ def __init__(self, num_channels: int = 3):
46
+ super().__init__()
47
+ self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2)
48
+ self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2)
49
+ self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2)
50
+ self.relu = nn.ReLU(inplace=True)
51
 
52
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
53
+ return self.conv3(self.relu(self.conv2(self.relu(self.conv1(x)))))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
 
56
  # ===========================================================================
57
+ # Model loading
58
  # ===========================================================================
59
 
60
  sess_opts = ort.SessionOptions()
61
  sess_opts.intra_op_num_threads = 2
62
  sess_opts.inter_op_num_threads = 2
63
 
64
+ ONNX_SESSIONS = {} # key β†’ (ort.InferenceSession, input_meta)
65
+ SRCNN_MODEL = None
66
+
67
+
68
+ def _load_esrgan_onnx(key: str):
69
+ """Download ESRGAN ONNX from Drive (gdown handles confirmation pages)."""
70
+ filename = f"{key}.onnx"
71
+ dest = os.path.join(CACHE_DIR, filename)
72
+ if not os.path.exists(dest):
73
+ print(f"Downloading {MODEL_LABELS[key]} from Drive …")
74
+ gdown.download(id=DRIVE_IDS[key], output=dest, quiet=False, fuzzy=True)
75
+ if os.path.exists(dest):
76
+ sess = ort.InferenceSession(dest, sess_options=sess_opts,
77
+ providers=["CPUExecutionProvider"])
78
+ meta = sess.get_inputs()[0]
79
+ ONNX_SESSIONS[key] = (sess, meta)
80
+ print(f"Loaded {MODEL_LABELS[key]} βœ“")
81
+ else:
82
+ print(f"[ERROR] Could not load {key} β€” file missing after download.")
83
+
84
 
85
+ def _load_srcnn_pth():
86
+ """Load SRCNN directly from the .pth file in the Space repo."""
87
+ global SRCNN_MODEL
88
+ if not os.path.exists(SRCNN_PTH):
89
+ print(f"[WARN] srcnn_x4.pth not found at {SRCNN_PTH} β€” SRCNN will be skipped.")
90
+ return
91
+ model = SRCNN(num_channels=3)
92
+ state = torch.load(SRCNN_PTH, map_location="cpu")
93
+ # Unwrap common checkpoint wrappers
94
+ for key in ("model", "state_dict", "params"):
95
+ if isinstance(state, dict) and key in state:
96
+ state = state[key]
97
+ state = {k.replace("module.", ""): v for k, v in state.items()}
98
  try:
99
+ model.load_state_dict(state, strict=True)
100
+ except RuntimeError:
101
+ model.load_state_dict(state, strict=False)
102
+ print("[WARN] SRCNN loaded with strict=False (minor key mismatch).")
103
+ model.eval()
104
+ SRCNN_MODEL = model
105
+ print("Loaded SRCNN Γ—4 from .pth βœ“")
106
+
107
+
108
+ # Boot-time loading
109
+ for k in ("esrgan_x4", "hresnet_x4"):
110
+ try:
111
+ _load_esrgan_onnx(k)
112
  except Exception as e:
113
+ print(f"[ERROR] {k}: {e}")
114
+
115
+ _load_srcnn_pth()
116
 
117
 
118
  # ===========================================================================
119
+ # Inference helpers
120
  # ===========================================================================
121
 
122
+ def _onnx_tile(sess, meta, tile: np.ndarray) -> np.ndarray:
123
+ """HWC float32 [0,1] β†’ HWC float32 [0,1]."""
124
+ patch = tile.transpose(2, 0, 1)[None, ...]
125
  out = sess.run(None, {meta.name: patch})[0]
126
+ return out.squeeze(0).transpose(1, 2, 0)
127
+
128
+
129
+ def _srcnn_tile(tile: np.ndarray, scale: int = 4) -> np.ndarray:
130
+ """Bicubic upsample β†’ SRCNN refine. HWC float32 [0,1] β†’ HWC float32."""
131
+ t = torch.from_numpy(tile.transpose(2, 0, 1)).unsqueeze(0) # NCHW
132
+ h, w = t.shape[2], t.shape[3]
133
+ up = F.interpolate(t, size=(h * scale, w * scale),
134
+ mode="bicubic", align_corners=False)
135
+ with torch.no_grad():
136
+ out = SRCNN_MODEL(up)
137
+ return out.squeeze(0).permute(1, 2, 0).numpy()
138
+
139
+
140
+ def upscale(input_img: Image.Image, model_key: str, max_dim: int = 1024) -> Image.Image:
141
+ """
142
+ Tile-based upscale dispatcher.
143
+ Works for both ONNX sessions (ESRGAN) and the torch SRCNN model.
144
+ """
145
+ # Guard
146
+ if model_key == "srcnn_x4" and SRCNN_MODEL is None:
147
+ raise RuntimeError("SRCNN model not loaded.")
148
+ if model_key in ("esrgan_x4", "hresnet_x4") and model_key not in ONNX_SESSIONS:
149
+ raise RuntimeError(f"{MODEL_LABELS[model_key]} not loaded.")
150
+
151
+ scale = MODEL_SCALES[model_key]
152
+
153
+ # Tile size: ESRGAN uses fixed 128Γ—128 LR tiles;
154
+ # SRCNN works on any size (we tile at 128 for consistency).
155
+ TILE = 128
156
+
157
+ # Cap input size
158
  w, h = input_img.size
159
  if w > max_dim or h > max_dim:
160
  factor = max_dim / float(max(w, h))
 
162
 
163
  arr = np.array(input_img.convert("RGB")).astype(np.float32) / 255.0
164
  h_orig, w_orig, _ = arr.shape
165
+ tiles_h = math.ceil(h_orig / TILE)
166
+ tiles_w = math.ceil(w_orig / TILE)
167
 
168
+ arr_pad = np.pad(
169
  arr,
170
+ ((0, tiles_h * TILE - h_orig), (0, tiles_w * TILE - w_orig), (0, 0)),
171
  mode="reflect",
172
  )
173
+ out = np.zeros((tiles_h * TILE * scale, tiles_w * TILE * scale, 3), dtype=np.float32)
174
+
 
175
  for i in range(tiles_h):
176
  for j in range(tiles_w):
177
+ y0, x0 = i * TILE, j * TILE
178
+ tile = arr_pad[y0:y0 + TILE, x0:x0 + TILE]
 
 
179
 
180
+ if model_key == "srcnn_x4":
181
+ up_tile = _srcnn_tile(tile, scale=scale)
182
+ else:
183
+ sess, meta = ONNX_SESSIONS[model_key]
184
+ up_tile = _onnx_tile(sess, meta, tile)
185
+
186
+ oy0, ox0 = i * TILE * scale, j * TILE * scale
187
+ out[oy0:oy0 + TILE * scale, ox0:ox0 + TILE * scale] = up_tile
188
+
189
+ final = np.clip(out[:h_orig * scale, :w_orig * scale], 0.0, 1.0)
190
  return Image.fromarray((final * 255.0).round().astype(np.uint8))
191
 
192
 
193
+ def make_comparison_png(original: Image.Image, upscaled: Image.Image) -> str:
194
+ """
195
+ Save a side-by-side PNG (original | upscaled, same display height)
196
+ that the ImageSlider widget will use.
197
+ """
198
+ up_w, up_h = upscaled.size
199
+ orig_resized = original.resize((up_w, up_h), Image.LANCZOS)
200
+
201
+ tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
202
+ orig_resized.save(tmp.name) # left image for slider
203
+ tmp.close()
204
+
205
+ tmp2 = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
206
+ upscaled.save(tmp2.name) # right image for slider
207
+ tmp2.close()
208
+
209
+ return tmp.name, tmp2.name
210
 
211
 
212
  # ===========================================================================
213
+ # Gradio callback
214
  # ===========================================================================
215
 
216
+ def run_upscale(input_img: Image.Image, model_name: str):
 
 
 
 
 
 
 
217
  if input_img is None:
218
+ return None, None, None
219
+
220
+ # Map display label β†’ key
221
+ key = next(k for k, v in MODEL_LABELS.items() if v == model_name)
222
+
223
+ result = upscale(input_img, key)
224
+ orig_path, up_path = make_comparison_png(input_img, result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
+ # Also save download copy
227
+ dl_tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
228
+ result.save(dl_tmp.name, format="PNG")
229
+ dl_tmp.close()
230
+
231
+ return (orig_path, up_path), result, dl_tmp.name
232
 
233
 
234
  # ===========================================================================
235
+ # Gradio UI
236
  # ===========================================================================
237
 
238
  css = """
239
+ @import url('https://fonts.googleapis.com/css2?family=DM+Sans:wght@400;600;700&display=swap');
240
+
241
+ body, .gradio-container { font-family: 'DM Sans', sans-serif !important; }
242
+
243
+ #title {
244
+ text-align: center;
245
+ padding: 24px 0 8px;
246
  }
247
+ #title h1 {
248
+ font-size: 2rem;
249
+ font-weight: 700;
250
+ letter-spacing: -0.5px;
251
+ margin: 0;
252
+ }
253
+ #title p { color: #666; margin: 4px 0 0; }
254
+
255
  #run-btn {
256
+ background: linear-gradient(135deg, #0f0c29, #302b63, #24243e) !important;
257
+ color: #fff !important;
258
+ font-weight: 700 !important;
259
+ font-size: 1rem !important;
260
+ border-radius: 10px !important;
261
+ padding: 14px 0 !important;
262
+ width: 100%;
263
+ letter-spacing: 0.03em;
264
+ transition: opacity 0.2s;
265
  }
266
+ #run-btn:hover { opacity: 0.85; }
267
+
268
+ #dl-btn button {
269
+ background: #f4f4f4 !important;
270
+ border: 1px solid #ddd !important;
271
+ color: #333 !important;
272
+ border-radius: 8px !important;
273
+ width: 100%;
274
+ font-size: 0.85rem !important;
275
+ }
276
+
277
+ .section-label {
278
+ font-size: 0.75rem;
279
+ font-weight: 700;
280
+ letter-spacing: 0.1em;
281
+ text-transform: uppercase;
282
+ color: #999;
283
+ margin-bottom: 6px;
284
  }
 
285
  """
286
 
287
+ available_models = [v for k, v in MODEL_LABELS.items()
288
+ if k == "srcnn_x4" and SRCNN_MODEL is not None
289
+ or k in ONNX_SESSIONS]
290
+
291
+ # Always show all three in dropdown regardless of load state
292
+ # (shows error in output if model failed to load)
293
+ dropdown_choices = list(MODEL_LABELS.values())
294
 
295
+ with gr.Blocks(css=css, title="SpectraGAN Upscaler") as demo:
296
 
297
+ gr.HTML("""
298
+ <div id="title">
299
+ <h1>🖼️ SpectraGAN Upscaler</h1>
300
+ <p>Choose a model, upscale your image, and drag the slider to compare.</p>
301
+ </div>
302
  """)
303
 
304
+ with gr.Row(equal_height=True):
305
+
306
+ # ── Left panel: controls ───────────────────────────────────────────
307
+ with gr.Column(scale=1, min_width=260):
308
+ gr.HTML('<div class="section-label">Source Image</div>')
309
+ inp_image = gr.Image(type="pil", show_label=False, height=260)
310
+
311
+ gr.HTML('<div class="section-label" style="margin-top:16px">Model</div>')
312
+ model_dropdown = gr.Dropdown(
313
+ choices=dropdown_choices,
314
+ value=dropdown_choices[0],
315
+ show_label=False,
316
+ )
317
+
318
+ run_btn = gr.Button("⚡ Upscale", elem_id="run-btn")
319
+
320
+ dl_btn = gr.DownloadButton(
321
+ label="⬇ Download upscaled PNG",
322
+ elem_id="dl-btn",
323
+ visible=True,
324
+ )
325
+
326
+ # ── Right panel: results ───────────────────────────────────────────
327
+ with gr.Column(scale=2):
328
+
329
+ gr.HTML('<div class="section-label">Before / After</div>')
330
+ slider = gr.ImageSlider(
331
+ label="Drag to compare",
332
+ show_label=False,
333
+ height=420,
334
+ type="filepath",
335
+ )
336
+
337
+ gr.HTML('<div class="section-label" style="margin-top:16px">Upscaled Preview</div>')
338
+ out_preview = gr.Image(
339
+ type="pil",
340
+ show_label=False,
341
+ height=200,
342
+ )
343
 
344
  run_btn.click(
345
+ fn=run_upscale,
346
+ inputs=[inp_image, model_dropdown],
347
+ outputs=[slider, out_preview, dl_btn],
348
  )
349
 
350
+ demo.launch(server_name="0.0.0.0", server_port=7860)