ParamDev committed on
Commit
7c042df
·
verified ·
1 Parent(s): 1808718

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +322 -237
app.py CHANGED
@@ -1,46 +1,28 @@
1
  import os
2
  import math
3
- import uuid
4
  import numpy as np
5
  import onnxruntime as ort
6
  from PIL import Image
7
  import gradio as gr
8
  import tempfile
9
- import requests
10
 
11
  # ---------------------------------------------------------------------------
12
- # Directory & model paths
13
  # ---------------------------------------------------------------------------
14
- MODEL_DIR = "model"
15
- os.makedirs(MODEL_DIR, exist_ok=True)
16
-
17
- MODEL_PATHS = {
18
- "esrgan_x2": os.path.join(MODEL_DIR, "Real-ESRGAN_x2plus.onnx"),
19
- "esrgan_x4": os.path.join(MODEL_DIR, "Real-ESRGAN-x4plus.onnx"),
20
- "srcnn_x4": os.path.join(MODEL_DIR, "SRCNN_x4.onnx"),
21
- "hresnet_x4": os.path.join(MODEL_DIR, "HResNet_x4.onnx"),
22
- }
23
 
24
  # ---------------------------------------------------------------------------
25
- # Google Drive file IDs
26
- # TODO: Replace the SRCNN / HResNet IDs with your own uploaded ONNX exports.
27
- # Steps to export SRCNN to ONNX:
28
- # import torch; from srcnn import SRCNN
29
- # model = SRCNN(); model.load_state_dict(torch.load("srcnn.pth"))
30
- # dummy = torch.randn(1, 3, 128, 128)
31
- # torch.onnx.export(model, dummy, "SRCNN_x4.onnx",
32
- # input_names=["input"], output_names=["output"],
33
- # dynamic_axes={"input":{2:"H",3:"W"},"output":{2:"H",3:"W"}})
34
- # Same pattern applies for HResNet / EDSR.
35
  # ---------------------------------------------------------------------------
36
- DRIVE_IDS = {
37
- "esrgan_x2": "15xmXXZNH2wMyeQv4ie5hagT7eWK9MgP6",
38
- "esrgan_x4": "1wDBHad9RCJgJDGsPdapLYl3cr8j-PMJ6",
39
- "srcnn_x4": "YOUR_SRCNN_DRIVE_FILE_ID_HERE", # <-- replace
40
- "hresnet_x4": "YOUR_HRESNET_DRIVE_FILE_ID_HERE", # <-- replace
41
  }
42
 
43
- # Scale factor each model produces
44
  MODEL_SCALES = {
45
  "esrgan_x2": 2,
46
  "esrgan_x4": 4,
@@ -48,7 +30,6 @@ MODEL_SCALES = {
48
  "hresnet_x4": 4,
49
  }
50
 
51
- # Human-readable labels shown in the UI
52
  MODEL_LABELS = {
53
  "esrgan_x2": "Real-ESRGAN ×2",
54
  "esrgan_x4": "Real-ESRGAN ×4",
@@ -56,167 +37,316 @@ MODEL_LABELS = {
56
  "hresnet_x4": "HResNet ×4",
57
  }
58
 
59
- # ---------------------------------------------------------------------------
60
- # SR3 NOTE
61
- # ---------------------------------------------------------------------------
62
- # SR3 (Super-Resolution via Repeated Refinement, Ho et al. 2022) is a
63
- # *diffusion model* and cannot be exported to ONNX in the same way as
64
- # feed-forward CNNs. To integrate SR3:
65
- # 1. Clone the official repo: https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement
66
- # 2. Place your checkpoint in model/sr3_x4.pth
67
- # 3. Add a `run_sr3(image: Image) -> Image` function that runs the
68
- # denoising loop using PyTorch directly (add `torch` to requirements).
69
- # 4. Wire `run_sr3` into `upscale_one_model` for key "sr3_x4".
70
- # SR3 is intentionally omitted from the ONNX pipeline to avoid misleading
71
- # model shape assumptions.
72
- # ---------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
 
74
 
75
- # ---------------------------------------------------------------------------
76
- # Google Drive downloader
77
- # ---------------------------------------------------------------------------
78
- def download_from_drive(file_id: str, dest_path: str):
79
- URL = "https://drive.google.com/uc?export=download"
80
- session = requests.Session()
81
- response = session.get(URL, params={"id": file_id}, stream=True)
82
- token = None
83
- for key, value in response.cookies.items():
84
- if key.startswith("download_warning"):
85
- token = value
86
- break
87
- if token:
88
- response = session.get(URL, params={"id": file_id, "confirm": token}, stream=True)
89
- os.makedirs(os.path.dirname(dest_path), exist_ok=True)
90
- with open(dest_path, "wb") as f:
91
- for chunk in response.iter_content(chunk_size=32768):
92
- if chunk:
93
- f.write(chunk)
94
- print(f"Downloaded → {dest_path}")
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
- # ---------------------------------------------------------------------------
98
- # LOCAL_ONLY mode
99
- # ---------------------------------------------------------------------------
100
- # Set to True after running export_models.py to load ONNX files straight
101
- # from the model/ directory without needing any Google Drive IDs.
102
- # Set to False (and fill in DRIVE_IDS) to download from Drive as before.
103
- LOCAL_ONLY = True
104
 
105
- # ---------------------------------------------------------------------------
106
- # Download models if missing
107
- # ---------------------------------------------------------------------------
108
- for key, path in MODEL_PATHS.items():
109
- if os.path.isfile(path):
110
- continue # already on disk — nothing to do
111
 
112
- if LOCAL_ONLY:
113
- print(f"[INFO] {key}: file not found at {path}. "
114
- "Run export_models.py first, or set LOCAL_ONLY=False to use Drive.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  continue
116
 
117
- file_id = DRIVE_IDS[key]
118
- if file_id.startswith("YOUR_"):
119
- print(f"[WARN] Skipping {key}: Google Drive ID not set. "
120
- "Run export_models.py or update DRIVE_IDS in app.py.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  else:
122
- print(f"Downloading {MODEL_LABELS[key]} ")
123
- download_from_drive(file_id, path)
124
 
125
 
126
- # ---------------------------------------------------------------------------
127
- # Load ONNX sessions (only for models that have been downloaded)
128
- # ---------------------------------------------------------------------------
 
129
  sess_opts = ort.SessionOptions()
130
  sess_opts.intra_op_num_threads = 2
131
  sess_opts.inter_op_num_threads = 2
132
 
133
- SESSIONS = {} # key → ort.InferenceSession
134
- INPUT_SHAPES = {} # key → (H_in, W_in)
135
 
136
  for key, path in MODEL_PATHS.items():
137
- if os.path.isfile(path):
138
- try:
139
- sess = ort.InferenceSession(
140
- path,
141
- sess_options=sess_opts,
142
- providers=["CPUExecutionProvider"]
143
- )
144
- meta = sess.get_inputs()[0]
145
- shape = tuple(meta.shape)
146
- # shape is (1, 3, H, W) for fixed-size models, or (1,3,None,None) for dynamic
147
- h = int(shape[2]) if shape[2] is not None and str(shape[2]).isdigit() else 128
148
- w = int(shape[3]) if shape[3] is not None and str(shape[3]).isdigit() else 128
149
- SESSIONS[key] = (sess, meta)
150
- INPUT_SHAPES[key] = (h, w)
151
- print(f"Loaded {MODEL_LABELS[key]} tile={h}×{w}")
152
- except Exception as e:
153
- print(f"[ERROR] Could not load {key}: {e}")
154
- else:
155
- print(f"[INFO] {key} not available will be skipped in comparisons.")
156
-
157
 
158
- # ---------------------------------------------------------------------------
159
- # Tile-based upscale for a single ONNX model
160
- # ---------------------------------------------------------------------------
161
  def run_onnx_tile(sess, meta, tile_np: np.ndarray) -> np.ndarray:
162
- """Run one tile through any ONNX session. tile_np is HWC float32 [0,1]."""
163
- patch = np.transpose(tile_np, (2, 0, 1))[None, ...] # NCHW
164
  out = sess.run(None, {meta.name: patch})[0]
165
- out = np.squeeze(out, axis=0)
166
- return np.transpose(out, (1, 2, 0)) # back to HWC
167
 
168
 
169
  def tile_upscale_model(input_img: Image.Image, key: str, max_dim: int = 1024) -> Image.Image:
170
- """
171
- Upscale *input_img* using the ONNX model identified by *key*.
172
- Returns a PIL Image at the upscaled resolution.
173
- """
174
  if key not in SESSIONS:
175
- raise ValueError(f"Model '{key}' is not loaded. Check the Drive ID / path.")
176
 
177
- sess, meta = SESSIONS[key]
178
- H_in, W_in = INPUT_SHAPES[key]
179
- scale = MODEL_SCALES[key]
180
 
181
- # Optionally cap input size to avoid OOM on large images
182
  w, h = input_img.size
183
  if w > max_dim or h > max_dim:
184
- factor = max_dim / float(max(w, h))
185
  input_img = input_img.resize((int(w * factor), int(h * factor)), Image.LANCZOS)
186
 
187
- arr = np.array(input_img.convert("RGB")).astype(np.float32) / 255.0
188
  h_orig, w_orig, _ = arr.shape
 
 
189
 
190
- tiles_h = math.ceil(h_orig / H_in)
191
- tiles_w = math.ceil(w_orig / W_in)
192
- pad_h = tiles_h * H_in - h_orig
193
- pad_w = tiles_w * W_in - w_orig
194
-
195
- arr_padded = np.pad(arr, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect")
196
- out_arr = np.zeros((tiles_h * H_in * scale, tiles_w * W_in * scale, 3), dtype=np.float32)
197
-
198
  for i in range(tiles_h):
199
  for j in range(tiles_w):
200
  y0, x0 = i * H_in, j * W_in
201
- tile = arr_padded[y0:y0 + H_in, x0:x0 + W_in, :]
202
- up_tile = run_onnx_tile(sess, meta, tile)
203
- oy0, ox0 = i * H_in * scale, j * W_in * scale
204
- out_arr[oy0:oy0 + H_in * scale, ox0:ox0 + W_in * scale, :] = up_tile
205
 
206
- final = np.clip(out_arr[0:h_orig * scale, 0:w_orig * scale, :], 0.0, 1.0)
207
  return Image.fromarray((final * 255.0).round().astype(np.uint8))
208
 
209
 
210
- def upscale_8x_from_4x(input_img: Image.Image, key: str) -> Image.Image:
211
- """Run ×4 model, then bicubic ×2 to reach ×8."""
212
  img_4x = tile_upscale_model(input_img, key)
213
  w, h = input_img.size
214
  return img_4x.resize((w * 8, h * 8), Image.LANCZOS)
215
 
216
 
217
- # ---------------------------------------------------------------------------
218
- # Core comparison function (called by the Gradio button)
219
- # ---------------------------------------------------------------------------
 
220
  def compare_models(
221
  input_img: Image.Image,
222
  use_esrgan_x2: bool,
@@ -226,36 +356,32 @@ def compare_models(
226
  include_8x: bool,
227
  ):
228
  if input_img is None:
229
- return [None] * 8 # 4 preview + 4 download slots
230
-
231
- selection = []
232
- if use_esrgan_x2: selection.append("esrgan_x2")
233
- if use_esrgan_x4: selection.append("esrgan_x4")
234
- if use_srcnn: selection.append("srcnn_x4")
235
- if use_hresnet: selection.append("hresnet_x4")
236
-
237
- previews = []
238
- downloads = []
 
239
 
240
  for key in selection:
241
  if key not in SESSIONS:
242
  previews.append(None)
243
- downloads.append(gr.DownloadButton(label=f"{MODEL_LABELS[key]} – not loaded",
244
- visible=True, value=None))
245
  continue
246
-
247
  try:
248
- if include_8x and MODEL_SCALES[key] == 4:
249
- result = upscale_8x_from_4x(input_img, key)
250
- suffix = "×8"
251
- else:
252
- result = tile_upscale_model(input_img, key)
253
- suffix = f"×{MODEL_SCALES[key]}"
254
-
255
  tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
256
  result.save(tmp.name, format="PNG")
257
  tmp.close()
258
-
259
  previews.append(result)
260
  downloads.append(tmp.name)
261
  except Exception as e:
@@ -263,55 +389,38 @@ def compare_models(
263
  previews.append(None)
264
  downloads.append(None)
265
 
266
- # Pad to always return exactly 4 preview + 4 download values
267
  while len(previews) < 4: previews.append(None)
268
  while len(downloads) < 4: downloads.append(None)
 
269
 
270
- return previews + downloads # 8-element list
271
 
 
 
 
272
 
273
- # ---------------------------------------------------------------------------
274
- # Gradio UI – side-by-side comparison layout
275
- # ---------------------------------------------------------------------------
276
  css = """
277
  body { font-family: 'Segoe UI', sans-serif; }
278
-
279
  .panel-title {
280
- text-align: center;
281
- font-weight: 700;
282
- font-size: 0.85rem;
283
- letter-spacing: 0.08em;
284
- text-transform: uppercase;
285
- margin-bottom: 4px;
286
- color: #555;
287
  }
288
-
289
  #run-btn {
290
  background: linear-gradient(135deg, #1a1a2e, #16213e) !important;
291
- color: #e2e2e2 !important;
292
- font-size: 1rem !important;
293
- font-weight: 600 !important;
294
- border-radius: 8px !important;
295
  padding: 12px 28px !important;
296
  }
297
- #run-btn:hover {
298
- background: linear-gradient(135deg, #0f3460, #533483) !important;
299
- }
300
-
301
  .dl-btn button {
302
- background: #f0f4ff !important;
303
- border: 1px solid #c5d0f5 !important;
304
- color: #333 !important;
305
- font-size: 0.78rem !important;
306
- border-radius: 6px !important;
307
- width: 100%;
308
  }
309
-
310
  .model-toggle label { font-size: 0.9rem; }
311
  """
312
 
313
- ALL_KEYS = ["esrgan_x2", "esrgan_x4", "srcnn_x4", "hresnet_x4"]
314
- PANEL_KEYS = ALL_KEYS # order for the 4 comparison panels
315
 
316
  with gr.Blocks(css=css, title="SpectraGAN — Multi-Model Comparison") as demo:
317
 
@@ -320,59 +429,35 @@ with gr.Blocks(css=css, title="SpectraGAN — Multi-Model Comparison") as demo:
320
  Upload an image, select models, and compare results side by side.
321
  """)
322
 
323
- # ── Input row ──────────────────────────────────────────────────────────
324
  with gr.Row():
325
  inp_image = gr.Image(type="pil", label="Source Image", scale=2)
326
-
327
  with gr.Column(scale=1):
328
  gr.Markdown("### Models to compare")
329
- chk_esrgan_x2 = gr.Checkbox(label="Real-ESRGAN ×2", value=True, elem_classes="model-toggle")
330
- chk_esrgan_x4 = gr.Checkbox(label="Real-ESRGAN ×4", value=True, elem_classes="model-toggle")
331
- chk_srcnn = gr.Checkbox(label="SRCNN ×4", value=True, elem_classes="model-toggle")
332
- chk_hresnet = gr.Checkbox(label="HResNet ×4", value=True, elem_classes="model-toggle")
333
-
334
  gr.Markdown("### Options")
335
- chk_8x = gr.Checkbox(label="Also apply ×8 post-resize on ×4 models", value=False)
336
-
337
  run_btn = gr.Button("⚡ Run Comparison", elem_id="run-btn")
338
 
339
- # ── Comparison grid ────────────────────────────────────────────────────
340
- gr.Markdown("---")
341
- gr.Markdown("## Results")
342
-
343
- previews = []
344
- dl_btns = []
345
 
346
  with gr.Row():
347
  for key in PANEL_KEYS:
348
  with gr.Column():
349
  gr.HTML(f'<div class="panel-title">{MODEL_LABELS[key]}</div>')
350
- img_out = gr.Image(
351
- type="pil",
352
- label=MODEL_LABELS[key],
353
- show_label=False,
354
- height=320,
355
- )
356
- dl_out = gr.DownloadButton(
357
- label="⬇ Download PNG",
358
- elem_classes="dl-btn",
359
- visible=True,
360
- )
361
- previews.append(img_out)
362
- dl_btns.append(dl_out)
363
-
364
- # ── Wire up ─────────────────────────────────────────────────────────────
365
  run_btn.click(
366
  fn=compare_models,
367
- inputs=[
368
- inp_image,
369
- chk_esrgan_x2,
370
- chk_esrgan_x4,
371
- chk_srcnn,
372
- chk_hresnet,
373
- chk_8x,
374
- ],
375
- outputs=previews + dl_btns, # 8 outputs: 4 images + 4 download buttons
376
  )
377
 
378
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import os
2
  import math
 
3
  import numpy as np
4
  import onnxruntime as ort
5
  from PIL import Image
6
  import gradio as gr
7
  import tempfile
8
+ from huggingface_hub import hf_hub_download, HfApi
9
 
10
  # ---------------------------------------------------------------------------
11
# ---------------------------------------------------------------------------
# CONFIG — only these two lines ever need changing
# ---------------------------------------------------------------------------
HF_REPO_ID = "ParamDev/spectraGAN"     # HF model repo that stores the ONNX files
HF_TOKEN = os.environ.get("HF_TOKEN")  # set this in Space → Settings → Secrets

# ---------------------------------------------------------------------------
# Filenames used in the HF repo
# ---------------------------------------------------------------------------
HF_FILENAMES = {
    "esrgan_x2": "Real-ESRGAN_x2plus.onnx",
    "esrgan_x4": "Real-ESRGAN-x4plus.onnx",
    "srcnn_x4": "SRCNN_x4.onnx",
    "hresnet_x4": "HResNet_x4.onnx",
}

# Upscale factor each model produces (key → integer scale).
MODEL_SCALES = {
    "esrgan_x2": 2,
    "esrgan_x4": 4,
    "srcnn_x4": 4,  # NOTE(review): entry collapsed in diff view — reconstructed from key/label; verify
    "hresnet_x4": 4,
}

# Human-readable labels shown in the UI.
MODEL_LABELS = {
    "esrgan_x2": "Real-ESRGAN ×2",
    "esrgan_x4": "Real-ESRGAN ×4",
    "srcnn_x4": "SRCNN ×4",
    "hresnet_x4": "HResNet ×4",
}

# Local scratch directory (writable in HF Spaces)
CACHE_DIR = "/tmp/spectragan_models"
os.makedirs(CACHE_DIR, exist_ok=True)
43
+
44
+
45
# ===========================================================================
# STEP 1 — ONNX export for SRCNN and HResNet
# Only runs if the file is missing from both local cache and the HF repo.
# Uses torch, which is available in HF Spaces by default.
# ===========================================================================

def _export_srcnn_onnx(out_path: str):
    """Build SRCNN (bicubic ×4 upsample + 3-layer CNN), optionally load
    pretrained weights, and export the model to ONNX at *out_path*.

    The exported graph takes NCHW float input with dynamic H/W axes. Weight
    download failures are non-fatal: the model falls back to random init so
    the pipeline still produces a loadable ONNX file.
    """
    import torch
    import torch.nn as nn

    class SRCNN(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
            self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=2)
            self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)
            self.relu = nn.ReLU(inplace=True)

        def forward(self, x):
            return self.conv3(self.relu(self.conv2(self.relu(self.conv1(x)))))

    class SRCNNx4(nn.Module):
        """Bicubic ×4 upsample → SRCNN refinement."""
        def __init__(self):
            super().__init__()
            self.srcnn = SRCNN()

        def forward(self, x):
            up = torch.nn.functional.interpolate(
                x, scale_factor=4, mode="bicubic", align_corners=False
            )
            return self.srcnn(up)

    model = SRCNNx4()

    # Try to fetch pretrained weights; fall back to random init if unavailable.
    weights_url = (
        "https://github.com/yjn870/SRCNN-pytorch/raw/master/models/srcnn_x4.pth"
    )
    weights_path = os.path.join(CACHE_DIR, "srcnn_x4.pth")
    if not os.path.exists(weights_path):
        try:
            import urllib.request
            print("  Downloading SRCNN pretrained weights …")
            urllib.request.urlretrieve(weights_url, weights_path)
        except Exception as e:
            print(f"  [WARN] Could not download SRCNN weights ({e}). Using random init.")
            weights_path = None

    if weights_path and os.path.exists(weights_path):
        state = torch.load(weights_path, map_location="cpu")
        # Bug fix: load_state_dict(strict=False) still raises RuntimeError on
        # *shape* mismatches (the published SRCNN checkpoints are trained on
        # the 1-channel Y plane, ours is 3-channel), which aborted the whole
        # export. Copy only tensors whose name AND shape match instead.
        own = model.srcnn.state_dict()
        usable = {k: v for k, v in state.items()
                  if k in own and own[k].shape == v.shape}
        model.srcnn.load_state_dict(usable, strict=False)
        print(f"  SRCNN weights: {len(usable)}/{len(own)} matching tensors loaded")

    model.eval()
    dummy = torch.zeros(1, 3, 128, 128)
    torch.onnx.export(
        model, dummy, out_path,
        opset_version=17,
        input_names=["input"], output_names=["output"],
        dynamic_axes={"input": {2: "H", 3: "W"}, "output": {2: "H", 3: "W"}},
    )
    print(f"  Exported SRCNN → {out_path}")
112
 
 
 
 
 
 
 
 
113
 
114
def _export_hresnet_onnx(out_path: str):
    """Build an EDSR-baseline ×4 network (served as "HResNet"), optionally
    load pretrained weights, and export it to ONNX at *out_path*.

    The exported graph takes NCHW float input with dynamic H/W axes. Weight
    download or layout problems are non-fatal: unusable tensors are skipped
    and the remaining (possibly random) weights are exported anyway.
    """
    import torch
    import torch.nn as nn

    class ResBlock(nn.Module):
        """Two 3×3 convs with a ReLU, added back onto the input (residual)."""
        def __init__(self, n_feats):
            super().__init__()
            self.body = nn.Sequential(
                nn.Conv2d(n_feats, n_feats, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(n_feats, n_feats, 3, padding=1),
            )

        def forward(self, x):
            return x + self.body(x)

    class EDSR(nn.Module):
        """EDSR-baseline: 16 residual blocks, 64 channels, ×4."""
        def __init__(self):
            super().__init__()
            n_feats = 64
            self.head = nn.Conv2d(3, n_feats, 3, padding=1)
            self.body = nn.Sequential(*[ResBlock(n_feats) for _ in range(16)])
            self.body_tail = nn.Conv2d(n_feats, n_feats, 3, padding=1)
            # ×4 via two ×2 pixel-shuffle stages
            self.tail = nn.Sequential(
                nn.Conv2d(n_feats, n_feats * 4, 3, padding=1),
                nn.PixelShuffle(2),
                nn.Conv2d(n_feats, n_feats * 4, 3, padding=1),
                nn.PixelShuffle(2),
                nn.Conv2d(n_feats, 3, 3, padding=1),
            )

        def forward(self, x):
            x = self.head(x)
            res = self.body_tail(self.body(x))
            return self.tail(x + res)

    model = EDSR()

    # Try to load pretrained weights from HuggingFace; random init on failure.
    weights_url = (
        "https://huggingface.co/eugenesiow/edsr-base/resolve/main/pytorch_model_4x.pt"
    )
    weights_path = os.path.join(CACHE_DIR, "edsr_x4.pt")
    if not os.path.exists(weights_path):
        try:
            import urllib.request
            print("  Downloading EDSR pretrained weights from HuggingFace …")
            urllib.request.urlretrieve(weights_url, weights_path)
        except Exception as e:
            print(f"  [WARN] Could not download EDSR weights ({e}). Using random init.")
            weights_path = None

    if weights_path and os.path.exists(weights_path):
        state = torch.load(weights_path, map_location="cpu")
        # Unwrap common checkpoint containers and DataParallel prefixes.
        if "model" in state: state = state["model"]
        if "state_dict" in state: state = state["state_dict"]
        state = {k.replace("module.", ""): v for k, v in state.items()}
        # Bug fix: load_state_dict(strict=False) still raises RuntimeError on
        # *shape* mismatches, aborting the export for checkpoints with a
        # different layout; and the old "loaded (partial) ✓" message claimed
        # success even when no tensor matched. Copy only name+shape matches
        # and report the real count.
        own = model.state_dict()
        usable = {k: v for k, v in state.items()
                  if k in own and own[k].shape == v.shape}
        model.load_state_dict(usable, strict=False)
        print(f"  EDSR weights: {len(usable)}/{len(own)} matching tensors loaded")

    model.eval()
    dummy = torch.zeros(1, 3, 128, 128)
    torch.onnx.export(
        model, dummy, out_path,
        opset_version=17,
        input_names=["input"], output_names=["output"],
        dynamic_axes={"input": {2: "H", 3: "W"}, "output": {2: "H", 3: "W"}},
    )
    print(f"  Exported HResNet → {out_path}")
188
+
189
+
190
# Map model keys that need local export to their export functions.
# Keys absent here can only be obtained via the HF Hub download path below.
EXPORTERS = {
    "srcnn_x4": _export_srcnn_onnx,
    "hresnet_x4": _export_hresnet_onnx,
}
195
+
196
+
197
def _upload_to_hub(local_path: str, filename: str):
    """Upload a generated file to the HF repo so future Space restarts can
    download it instead of re-exporting.

    Requires HF_TOKEN with write access to HF_REPO_ID; skipped with a hint
    when the token is not configured. Upload failures are logged, not raised.
    """
    if not HF_TOKEN:
        # Fix: these messages printed the literal "(unknown)" instead of the
        # actual filename, making the logs useless for debugging.
        print(f"  [INFO] HF_TOKEN not set — skipping upload of {filename}.")
        print("  Set HF_TOKEN in Space Secrets to persist generated models.")
        return
    try:
        api = HfApi()
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=filename,
            repo_id=HF_REPO_ID,
            repo_type="model",
            token=HF_TOKEN,
        )
        print(f"  Uploaded {filename} → {HF_REPO_ID} ✓")
    except Exception as e:
        print(f"  [WARN] Upload failed for {filename}: {e}")
215
+
216
+
217
# ===========================================================================
# STEP 2 — Fetch or generate all models
# Priority: local /tmp cache → HF Hub download → local export (then the
# exported file is uploaded back to the Hub so the next restart skips it).
# ===========================================================================

MODEL_PATHS = {}  # key → local file path ready for onnxruntime

for key, filename in HF_FILENAMES.items():
    local_path = os.path.join(CACHE_DIR, filename)

    # Already cached locally from a previous run this session.
    if os.path.exists(local_path):
        print(f"[CACHE] {MODEL_LABELS[key]} already in /tmp cache.")
        MODEL_PATHS[key] = local_path
        continue

    # Try downloading from HF Hub first.
    try:
        print(f"Fetching {MODEL_LABELS[key]} from HF Hub ")
        dl_path = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=filename,
            repo_type="model",
            token=HF_TOKEN,
            local_dir=CACHE_DIR,
        )
        MODEL_PATHS[key] = dl_path
        print(f"  ✓ {dl_path}")
        continue
    except Exception as e:
        # Any failure (missing file, auth, network) falls through to export.
        print(f"  Not found on Hub ({e})")

    # Not on Hub — export it locally if we have an exporter for it.
    if key in EXPORTERS:
        print(f"Generating {MODEL_LABELS[key]} via local export …")
        try:
            EXPORTERS[key](local_path)
            MODEL_PATHS[key] = local_path
            # Upload back to Hub so next restart skips this step
            print(f"  Uploading to Hub for future restarts …")
            _upload_to_hub(local_path, filename)
        except Exception as e:
            # Export failures leave the key out of MODEL_PATHS; STEP 3 then
            # simply never loads a session for it and the UI shows no result.
            print(f"  [ERROR] Export failed for {key}: {e}")
    else:
        print(f"  [WARN] {key} not on Hub and no exporter available — skipping.")
 
262
 
263
 
264
# ===========================================================================
# STEP 3 — Load ONNX sessions
# ===========================================================================

sess_opts = ort.SessionOptions()
# Keep CPU thread usage modest (shared/low-core Space hardware).
sess_opts.intra_op_num_threads = 2
sess_opts.inter_op_num_threads = 2

SESSIONS = {}      # key → (ort.InferenceSession, first-input metadata)
INPUT_SHAPES = {}  # key → (H_in, W_in) tile size fed to the model

for key, path in MODEL_PATHS.items():
    try:
        sess = ort.InferenceSession(
            path,
            sess_options=sess_opts,
            providers=["CPUExecutionProvider"],
        )
        meta = sess.get_inputs()[0]
        shape = tuple(meta.shape)
        # Fixed-size models report ints in (N, C, H, W); dynamic models report
        # symbolic dim names — fall back to a 128×128 tile in that case.
        h = int(shape[2]) if isinstance(shape[2], int) else 128
        w = int(shape[3]) if isinstance(shape[3], int) else 128
        SESSIONS[key] = (sess, meta)
        INPUT_SHAPES[key] = (h, w)
        print(f"Loaded {MODEL_LABELS[key]} tile={h}×{w}")
    except Exception as e:
        # A corrupt/incompatible file just drops out of SESSIONS; the UI
        # later shows "not loaded" for this model instead of crashing.
        print(f"[ERROR] Could not load {key}: {e}")
291
+
292
+
293
# ===========================================================================
# STEP 4 — Inference helpers
# ===========================================================================

def run_onnx_tile(sess, meta, tile_np: np.ndarray) -> np.ndarray:
    """Push one HWC float tile through *sess* and return the HWC result."""
    nchw = tile_np.transpose(2, 0, 1)[np.newaxis]
    result = sess.run(None, {meta.name: nchw})[0]
    return result[0].transpose(1, 2, 0)
 
301
 
302
 
303
def tile_upscale_model(input_img: Image.Image, key: str, max_dim: int = 1024) -> Image.Image:
    """Upscale *input_img* with the ONNX model *key*, tile by tile.

    The image is padded up to a whole number of H_in×W_in tiles, each tile is
    run through the model, and the upscaled tiles are stitched together and
    cropped back to the exact output size.

    Args:
        input_img: source PIL image (converted to RGB internally).
        key: model key present in SESSIONS / INPUT_SHAPES / MODEL_SCALES.
        max_dim: inputs larger than this on either side are downscaled first
            to bound memory use.

    Returns:
        A PIL image at `scale`× the (possibly capped) input resolution.

    Raises:
        ValueError: if *key* has no loaded ONNX session.
    """
    if key not in SESSIONS:
        raise ValueError(f"Model '{key}' is not loaded.")

    sess, meta = SESSIONS[key]
    H_in, W_in = INPUT_SHAPES[key]
    scale = MODEL_SCALES[key]

    # Cap very large inputs to avoid OOM on CPU-only hardware.
    w, h = input_img.size
    if w > max_dim or h > max_dim:
        factor = max_dim / float(max(w, h))
        input_img = input_img.resize((int(w * factor), int(h * factor)), Image.LANCZOS)

    arr = np.array(input_img.convert("RGB")).astype(np.float32) / 255.0
    h_orig, w_orig, _ = arr.shape
    tiles_h = math.ceil(h_orig / H_in)
    tiles_w = math.ceil(w_orig / W_in)
    pad_h = tiles_h * H_in - h_orig
    pad_w = tiles_w * W_in - w_orig

    # Bug fix: mode="reflect" requires the pad to be strictly smaller than the
    # padded dimension, so images smaller than one tile crashed. Fall back to
    # "edge" in that case; output is unchanged whenever "reflect" was legal.
    pad_mode = "reflect" if (pad_h < h_orig and pad_w < w_orig) else "edge"
    arr_padded = np.pad(arr, ((0, pad_h), (0, pad_w), (0, 0)), mode=pad_mode)

    out_arr = np.zeros(
        (tiles_h * H_in * scale, tiles_w * W_in * scale, 3), dtype=np.float32
    )
    for i in range(tiles_h):
        for j in range(tiles_w):
            y0, x0 = i * H_in, j * W_in
            up_tile = run_onnx_tile(sess, meta, arr_padded[y0:y0+H_in, x0:x0+W_in])
            oy0, ox0 = i * H_in * scale, j * W_in * scale
            out_arr[oy0:oy0+H_in*scale, ox0:ox0+W_in*scale] = up_tile

    # Crop the padding away and convert back to 8-bit.
    final = np.clip(out_arr[:h_orig*scale, :w_orig*scale], 0.0, 1.0)
    return Image.fromarray((final * 255.0).round().astype(np.uint8))
338
 
339
 
340
def upscale_8x(input_img: Image.Image, key: str) -> Image.Image:
    """Reach ×8: run the ×4 model, then LANCZOS-resize the result to 8× the
    original dimensions."""
    orig_w, orig_h = input_img.size
    quadrupled = tile_upscale_model(input_img, key)
    return quadrupled.resize((orig_w * 8, orig_h * 8), Image.LANCZOS)
344
 
345
 
346
# ===========================================================================
# STEP 5 — Gradio comparison callback
# ===========================================================================

def compare_models(
    input_img: Image.Image,
    use_esrgan_x2: bool,
    use_esrgan_x4: bool,
    use_srcnn: bool,
    use_hresnet: bool,
    include_8x: bool,
):
    """Run each selected model and return 4 preview images + 4 download paths.

    The 8 return values map positionally onto the four fixed result panels
    (esrgan_x2, esrgan_x4, srcnn_x4, hresnet_x4); deselected or unavailable
    models yield None in their own slot.
    """
    if input_img is None:
        return [None] * 8

    wanted = {
        "esrgan_x2": use_esrgan_x2,
        "esrgan_x4": use_esrgan_x4,
        "srcnn_x4": use_srcnn,
        "hresnet_x4": use_hresnet,
    }
    previews = []
    downloads = []

    # Bug fix: results used to be appended in selection order and padded with
    # None at the end, so deselecting an earlier model shifted every later
    # result into the wrong labeled panel. Emit exactly one slot per panel.
    for key in wanted:
        if not wanted[key] or key not in SESSIONS:
            previews.append(None)
            downloads.append(None)
            continue
        try:
            result = (
                upscale_8x(input_img, key)
                if include_8x and MODEL_SCALES[key] == 4
                else tile_upscale_model(input_img, key)
            )
            # Save to a temp PNG so the DownloadButton has a real file path.
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
            result.save(tmp.name, format="PNG")
            tmp.close()
            previews.append(result)
            downloads.append(tmp.name)
        except Exception as e:
            print(f"[ERROR] {key}: {e}")
            previews.append(None)
            downloads.append(None)

    return previews + downloads  # 8-element list: 4 previews + 4 download paths
395
 
 
396
 
397
+ # ===========================================================================
398
+ # STEP 6 — Gradio UI
399
+ # ===========================================================================
400
 
 
 
 
401
# Custom CSS for the side-by-side comparison layout.
css = """
body { font-family: 'Segoe UI', sans-serif; }
.panel-title {
    text-align: center; font-weight: 700; font-size: 0.85rem;
    letter-spacing: 0.08em; text-transform: uppercase;
    margin-bottom: 4px; color: #555;
}
#run-btn {
    background: linear-gradient(135deg, #1a1a2e, #16213e) !important;
    color: #e2e2e2 !important; font-size: 1rem !important;
    font-weight: 600 !important; border-radius: 8px !important;
    padding: 12px 28px !important;
}
#run-btn:hover { background: linear-gradient(135deg, #0f3460, #533483) !important; }
.dl-btn button {
    background: #f0f4ff !important; border: 1px solid #c5d0f5 !important;
    color: #333 !important; font-size: 0.78rem !important;
    border-radius: 6px !important; width: 100%;
}
.model-toggle label { font-size: 0.9rem; }
"""

# Fixed order of the four result panels; compare_models emits its 8 outputs
# in this same key order.
PANEL_KEYS = ["esrgan_x2", "esrgan_x4", "srcnn_x4", "hresnet_x4"]
 
424
 
425
  with gr.Blocks(css=css, title="SpectraGAN — Multi-Model Comparison") as demo:
426
 
 
429
  Upload an image, select models, and compare results side by side.
430
  """)
431
 
 
432
  with gr.Row():
433
  inp_image = gr.Image(type="pil", label="Source Image", scale=2)
 
434
  with gr.Column(scale=1):
435
  gr.Markdown("### Models to compare")
436
+ chk_esrgan_x2 = gr.Checkbox(label="Real-ESRGAN ×2", value=True, elem_classes="model-toggle")
437
+ chk_esrgan_x4 = gr.Checkbox(label="Real-ESRGAN ×4", value=True, elem_classes="model-toggle")
438
+ chk_srcnn = gr.Checkbox(label="SRCNN ×4", value=True, elem_classes="model-toggle")
439
+ chk_hresnet = gr.Checkbox(label="HResNet ×4", value=True, elem_classes="model-toggle")
 
440
  gr.Markdown("### Options")
441
+ chk_8x = gr.Checkbox(label="Also apply ×8 post-resize on ×4 models", value=False)
 
442
  run_btn = gr.Button("⚡ Run Comparison", elem_id="run-btn")
443
 
444
+ gr.Markdown("---\n## Results")
445
+ previews = []
446
+ dl_btns = []
 
 
 
447
 
448
  with gr.Row():
449
  for key in PANEL_KEYS:
450
  with gr.Column():
451
  gr.HTML(f'<div class="panel-title">{MODEL_LABELS[key]}</div>')
452
+ previews.append(gr.Image(type="pil", show_label=False, height=320))
453
+ dl_btns.append(gr.DownloadButton(
454
+ label="⬇ Download PNG", elem_classes="dl-btn", visible=True
455
+ ))
456
+
 
 
 
 
 
 
 
 
 
 
457
  run_btn.click(
458
  fn=compare_models,
459
+ inputs=[inp_image, chk_esrgan_x2, chk_esrgan_x4, chk_srcnn, chk_hresnet, chk_8x],
460
+ outputs=previews + dl_btns,
 
 
 
 
 
 
 
461
  )
462
 
463
  demo.launch(server_name="0.0.0.0", server_port=7860)