JS6969 committed on
Commit
9bf416a
·
verified ·
1 Parent(s): 01d8863

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -150
app.py CHANGED
@@ -1,6 +1,6 @@
1
- # app.py — Mjölnir · Upscale Images (Real-ESRGAN) — GPU/CPU safe
2
 
3
- # ---- TorchVision shim (keeps basicsr happy if torchvision isn't installed) ----
4
  import sys, types
5
  try:
6
  import torchvision.transforms.functional_tensor as _ft # noqa: F401
@@ -17,32 +17,29 @@ except Exception:
17
  return torch.stack([gray, gray, gray], dim=-3) if num_output_channels == 3 else gray.unsqueeze(-3)
18
  _mod.rgb_to_grayscale = rgb_to_grayscale
19
  sys.modules["torchvision.transforms.functional_tensor"] = _mod
20
- # ------------------------------------------------------------------------------
21
 
22
  import os, time, zipfile, tempfile, shutil, base64
23
  from pathlib import Path
24
- from typing import List, Optional
25
- import re
26
  import gradio as gr
27
  import numpy as np
28
  import cv2
29
  from PIL import Image
30
 
31
- import torch
 
 
 
32
  from basicsr.archs.rrdbnet_arch import RRDBNet as _RRDBNet
33
  from basicsr.utils.download_util import load_file_from_url
34
  from realesrgan import RealESRGANer
35
  from realesrgan.archs.srvgg_arch import SRVGGNetCompact
36
 
37
  # ────────────────────────────────────────────────────────
38
- # Small utils
39
  # ────────────────────────────────────────────────────────
40
-
41
- def have_gpu() -> bool:
42
- return torch.cuda.is_available()
43
-
44
- print("βœ… GPU available" if have_gpu() else "⚠️ No GPU detected. Will use CPU (slow).")
45
-
46
  def try_load_logo_b64() -> str:
47
  try:
48
  with open("bifrost_logo.png", "rb") as f:
@@ -58,12 +55,16 @@ def render_logo_html(px: int = 96) -> str:
58
  {img}
59
  <div>
60
  <div style="font-size:1.6rem;font-weight:800;">MjΓΆlnir Β· Upscale Images</div>
61
- <div style="opacity:0.8;">Real-ESRGAN (batch click with progress)</div>
62
  </div>
63
  </div>
64
  <hr>
65
  """
66
 
 
 
 
 
67
  _num = re.compile(r'(\d+)')
68
  def _natural_key(p: Path | str):
69
  s = str(p)
@@ -77,9 +78,8 @@ def sample_paths(paths: List[Path] | List[str], n: int = 30) -> List[str]:
77
  step = (total - 1) / (n - 1); idxs = [round(i * step) for i in range(n)]
78
  out, seen = [], set()
79
  for i in idxs:
80
- i = int(i)
81
- if i not in seen:
82
- out.append(str(paths[i])); seen.add(i)
83
  return out
84
 
85
  def render_progress(pct: float, label: str = "") -> str:
@@ -88,10 +88,27 @@ def render_progress(pct: float, label: str = "") -> str:
88
  <div style="height:100%;width:{pct:.1f}%;background:#3b82f6;"></div></div>
89
  <div style="font-size:12px;opacity:.8;margin-top:4px;">{label} {pct:.1f}%</div>'''
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  # ────────────────────────────────────────────────────────
92
- # Real-ESRGAN wiring (GPU/CPU safe)
93
  # ────────────────────────────────────────────────────────
94
-
95
  def build_rrdb(scale: int, num_block: int):
96
  return _RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=num_block, num_grow_ch=32, scale=scale)
97
 
@@ -100,128 +117,63 @@ def _weights_dir() -> str:
100
  os.makedirs(wdir, exist_ok=True)
101
  return wdir
102
 
103
- def map_ui_model_to_internal(ui_name: str) -> str:
104
- return {
105
- "RealESRGAN_x4plus": "x4plus",
106
- "RealESRGAN_x4plus_anime_6B": "x4plus-anime",
107
- "RealESRGAN_x2plus": "x2plus",
108
- "RealESRNet_x4plus": "x4plus", # fallback to RRDB x4
109
- "realesr-general-x4v3": "general-x4v3", # SRVGG
110
- }.get(ui_name, "x4plus")
111
-
112
- def clamp_scale_for_model(outscale: int, model_id: str) -> int:
113
- if model_id == "x2plus": # true 2x model
114
- return 2
115
- return 4 # rest are 4x
116
-
117
- def get_realesrganer(
118
- model_id: str,
119
- tile: int,
120
- half: bool,
121
- device: str,
122
- denoise_strength: float = 0.5
123
- ) -> RealESRGANer:
124
  """
125
- Returns a RealESRGANer for:
126
- - x4plus (RRDB)
127
- - x4plus-anime (RRDB)
128
- - x2plus (RRDB)
129
- - general-x4v3 (SRVGG, supports denoise_strength via DNI weights)
130
  """
131
  wdir = _weights_dir()
132
-
133
  if model_id == "x4plus":
134
- model = build_rrdb(scale=4, num_block=23)
135
- netscale = 4
136
  url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
137
  model_path = os.path.join(wdir, "RealESRGAN_x4plus.pth")
138
  if not os.path.isfile(model_path):
139
  load_file_from_url(url=url, model_dir=wdir, progress=True)
140
- model_path_arg = model_path
141
- dni_weight = None
142
 
143
- elif model_id == "x4plus-anime":
144
- model = build_rrdb(scale=4, num_block=6)
145
- netscale = 4
146
  url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
147
  model_path = os.path.join(wdir, "RealESRGAN_x4plus_anime_6B.pth")
148
  if not os.path.isfile(model_path):
149
  load_file_from_url(url=url, model_dir=wdir, progress=True)
150
- model_path_arg = model_path
151
- dni_weight = None
152
 
153
- elif model_id == "x2plus":
154
- model = build_rrdb(scale=2, num_block=23)
155
- netscale = 2
156
  url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
157
  model_path = os.path.join(wdir, "RealESRGAN_x2plus.pth")
158
  if not os.path.isfile(model_path):
159
  load_file_from_url(url=url, model_dir=wdir, progress=True)
160
- model_path_arg = model_path
161
- dni_weight = None
162
 
163
- elif model_id == "general-x4v3":
164
- # SRVGG + two weights for DNI blend (denoise control)
 
165
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
166
  netscale = 4
167
- base_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/"
168
- base_path = os.path.join(wdir, "realesr-general-x4v3.pth")
169
- wdn_path = os.path.join(wdir, "realesr-general-wdn-x4v3.pth")
170
- if not os.path.isfile(base_path):
171
- load_file_from_url(url=base_url + "realesr-general-x4v3.pth", model_dir=wdir, progress=True)
172
- if not os.path.isfile(wdn_path):
173
- load_file_from_url(url=base_url + "realesr-general-wdn-x4v3.pth", model_dir=wdir, progress=True)
174
- model_path_arg = [base_path, wdn_path]
175
- # blend base vs denoised
176
- d = float(denoise_strength)
177
- d = max(0.0, min(1.0, d))
178
- dni_weight = [1.0 - d, d]
179
-
180
- else:
181
- raise ValueError(f"Unknown model_id: {model_id}")
182
-
183
- # Final device policy
184
- device = device if device in ("cuda", "cpu") else ("cuda" if torch.cuda.is_available() else "cpu")
185
- half = bool(half and device == "cuda")
186
-
187
- return RealESRGANer(
188
- scale=netscale,
189
- model_path=model_path_arg,
190
- dni_weight=dni_weight,
191
- model=model,
192
- tile=int(tile or 256),
193
- tile_pad=10,
194
- pre_pad=0,
195
- half=half,
196
- device=device,
197
- )
198
 
199
- # ────────────────────────────────────────────────────────
200
- # Batch upscaling helpers
201
- # ────────────────────────────────────────────────────────
202
 
203
- def _ensure_dir(p: Path) -> Path:
204
- p.mkdir(parents=True, exist_ok=True); return p
205
-
206
- def _save_zip_of_dir(dir_path: Path, zip_path: Path) -> str:
207
- with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
208
- for p in sorted(dir_path.glob("*.*"), key=_natural_key):
209
- if p.suffix.lower() in [".jpg", ".jpeg", ".png"]:
210
- zf.write(p, p.name)
211
- return str(zip_path)
212
-
213
- def _list_image_paths_from_upload(files: List[gr.File] | None) -> List[str]:
214
- if not files: return []
215
- return [str(Path(f.name)) for f in files if Path(f.name).suffix.lower() in [".jpg",".jpeg",".png"]]
216
 
217
- def _build_gallery_from_dir(dir_path: Path, n: int = 30) -> List[str]:
218
- paths = sorted(list(dir_path.glob("*.jpg")) + list(dir_path.glob("*.png")), key=_natural_key)
219
- return sample_paths(paths, n)
220
 
221
  # ────────────────────────────────────────────────────────
222
- # Step 2 · Prepare + Process (generator for streaming)
223
  # ────────────────────────────────────────────────────────
224
-
225
  def step2_prepare_sources(frames_list, uploaded_imgs, max_images):
226
  src = _list_image_paths_from_upload(uploaded_imgs) or (frames_list or [])
227
  if not src:
@@ -237,56 +189,58 @@ def step2_prepare_sources(frames_list, uploaded_imgs, max_images):
237
  total = len(src); done_idx = 0
238
  return src, str(out_dir), done_idx, total, f"Sources loaded: {total} image(s). Click 'Process Next Batch'.", render_progress(0.0, "Ready")
239
 
 
 
 
 
240
  def step2_process_next_batch(
241
  up_src_paths, up_out_dir, up_done_idx, up_total,
242
  ui_model_name, outscale, tile, precision, denoise_strength, face_enhance, batch_size,
243
- force_cpu
244
  ):
 
 
 
 
 
245
  if not up_src_paths or not up_out_dir:
246
  yield None, None, "Load sources first.", render_progress(0.0, "Idle"), up_done_idx, up_out_dir
247
  return
248
 
 
249
  model_id = map_ui_model_to_internal(ui_model_name)
250
  scale = clamp_scale_for_model(int(outscale or 4), model_id)
251
-
252
- # Device policy
253
- device = "cpu" if force_cpu else ("cuda" if torch.cuda.is_available() else "cpu")
254
- # Precision policy
255
- if precision == "half":
256
- use_half = (device == "cuda")
257
- elif precision == "full":
258
- use_half = False
259
- else: # auto
260
- use_half = (device == "cuda")
261
-
262
  tile = int(tile or 256)
263
  batch_size = max(1, int(batch_size or 8))
 
264
 
265
- # Build upsampler (handles general-x4v3 as well)
266
- upsampler = get_realesrganer(
267
- model_id=model_id,
 
 
 
 
268
  tile=tile,
269
- half=use_half,
270
- device=device,
271
- denoise_strength=float(denoise_strength or 0.5)
 
272
  )
273
 
274
- # Optional: GFPGAN face enhancer
275
  face_enhancer = None
276
  if face_enhance:
277
  try:
278
  from gfpgan import GFPGANer
279
  face_enhancer = GFPGANer(
280
  model_path="https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
281
- upscale=scale,
282
- arch="clean",
283
- channel_multiplier=2,
284
- bg_upsampler=upsampler
285
  )
286
  except Exception as e:
287
  print("GFPGAN load failed:", e)
288
  face_enhancer = None
289
 
 
290
  start = int(up_done_idx or 0)
291
  end = min(start + batch_size, int(up_total or 0))
292
  out_dir = Path(up_out_dir)
@@ -307,23 +261,27 @@ def step2_process_next_batch(
307
  img = im.convert("RGB")
308
  cv_img = np.array(img)
309
  if face_enhancer:
310
- _, _, output = face_enhancer.enhance(cv_img, has_aligned=False, only_center_face=False, paste_back=True)
 
 
311
  else:
312
- # For general-x4v3, denoise_strength is already in DNI weights
313
- output, _ = upsampler.enhance(cv_img, outscale=scale)
314
  Image.fromarray(output).save(out_dir / (Path(fp).stem + ".jpg"), quality=95)
315
  except Exception as e:
316
  print("Upscale error:", e)
317
 
 
318
  elapsed = time.time() - t0
319
  pct_batch = (idx / total_in_batch) * 100.0
320
  eta = (total_in_batch - idx) * (elapsed / max(1, idx))
321
  label = (f"Batch: {idx}/{total_in_batch} Β· ~{eta:.1f}s ETA Β· "
322
- f"global {start+idx}/{up_total} (x{scale}, model={ui_model_name}, device={device}, half={use_half})")
323
  gallery = _build_gallery_from_dir(out_dir, 30)
324
  zip_file = _save_zip_of_dir(out_dir, Path(out_dir.parent) / "upscaled.zip")
325
  yield gallery, zip_file, label, render_progress(pct_batch, f"Upscaling {pct_batch:.0f}% (batch)"), start+idx, up_out_dir
326
 
 
327
  next_idx = end
328
  pct_global = (next_idx / up_total) * 100.0 if up_total else 100.0
329
  gallery = _build_gallery_from_dir(out_dir, 30)
@@ -333,14 +291,13 @@ def step2_process_next_batch(
333
  # ────────────────────────────────────────────────────────
334
  # UI
335
  # ────────────────────────────────────────────────────────
336
-
337
  def build_ui():
338
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
339
  gr.HTML(render_logo_html(88))
340
- gr.Markdown("Upload images and upscale with Real-ESRGAN. Process in batches with live progress. GPU if available, CPU fallback.")
341
 
342
  # States
343
- frames_state = gr.State([]) # kept for parity
344
  up_src_paths_state = gr.State([])
345
  up_out_dir_state = gr.State("")
346
  up_done_idx_state = gr.State(0)
@@ -356,38 +313,44 @@ def build_ui():
356
  value="RealESRGAN_x4plus"
357
  )
358
  denoise_strength = gr.Slider(0, 1, value=0.5, step=0.1, label="Denoise (only general-x4v3)")
359
- outscale = gr.Slider(1, 6, value=4, step=1, label="Resolution upscale (model-limited)")
360
  face_enhance = gr.Checkbox(value=False, label="Face Enhancement (GFPGAN)")
361
  with gr.Row():
362
  tile = gr.Number(value=256, label="Tile size (try 128 if OOM; 0=auto)")
363
- precision = gr.Dropdown(["auto", "half", "full"], value="auto", label="Precision (GPU=half, CPU=full)")
364
- force_cpu = gr.Checkbox(value=False, label="Zero-GPU Mode (force CPU)")
365
  with gr.Row():
366
  batch_size = gr.Number(value=12, precision=0, label="Batch size per click")
367
  max_images = gr.Number(value=0, precision=0, label="Max images to process (0 = all)")
368
 
369
  with gr.Row():
370
  btn_prepare = gr.Button("Load / Reset Sources", variant="secondary")
371
- btn_next = gr.Button("Process Next Batch", variant="primary")
372
 
373
  prog = gr.HTML(render_progress(0.0, "Idle"))
374
- gallery_up = gr.Gallery(label="Upscaled preview (30 sampled)", columns=6, height=480)
375
  zip_up = gr.File(label="Download upscaled ZIP")
376
  details = gr.Markdown("")
377
 
 
378
  btn_prepare.click(
379
  step2_prepare_sources,
380
  inputs=[frames_state, imgs_override, max_images],
381
  outputs=[up_src_paths_state, up_out_dir_state, up_done_idx_state, up_total_state, details, prog]
382
  )
383
 
 
384
  btn_next.click(
385
  step2_process_next_batch,
386
  inputs=[up_src_paths_state, up_out_dir_state, up_done_idx_state, up_total_state,
387
- ui_model_name, outscale, tile, precision, denoise_strength, face_enhance, batch_size, force_cpu],
388
  outputs=[gallery_up, zip_up, details, prog, up_done_idx_state, up_out_dir_state]
389
  )
390
 
 
 
 
 
 
391
  return demo
392
 
393
  if __name__ == "__main__":
 
1
+ # app.py — Mjölnir · Upscale Images (ZeroGPU-ready, batch click, with logo)
2
 
3
+ # ---- TorchVision shim (so basicsr can import even without torchvision) ----
4
  import sys, types
5
  try:
6
  import torchvision.transforms.functional_tensor as _ft # noqa: F401
 
17
  return torch.stack([gray, gray, gray], dim=-3) if num_output_channels == 3 else gray.unsqueeze(-3)
18
  _mod.rgb_to_grayscale = rgb_to_grayscale
19
  sys.modules["torchvision.transforms.functional_tensor"] = _mod
20
+ # --------------------------------------------------------------------------
21
 
22
  import os, time, zipfile, tempfile, shutil, base64
23
  from pathlib import Path
24
+ from typing import List, Optional, Tuple
25
+
26
  import gradio as gr
27
  import numpy as np
28
  import cv2
29
  from PIL import Image
30
 
31
+ # ZeroGPU hook
32
+ import spaces
33
+
34
+ # Real-ESRGAN / basicsr (importing these at top is OK; just don't touch CUDA here)
35
  from basicsr.archs.rrdbnet_arch import RRDBNet as _RRDBNet
36
  from basicsr.utils.download_util import load_file_from_url
37
  from realesrgan import RealESRGANer
38
  from realesrgan.archs.srvgg_arch import SRVGGNetCompact
39
 
40
  # ────────────────────────────────────────────────────────
41
+ # Branding (logo)
42
  # ────────────────────────────────────────────────────────
 
 
 
 
 
 
43
  def try_load_logo_b64() -> str:
44
  try:
45
  with open("bifrost_logo.png", "rb") as f:
 
55
  {img}
56
  <div>
57
  <div style="font-size:1.6rem;font-weight:800;">MjΓΆlnir Β· Upscale Images</div>
58
+ <div style="opacity:0.8;">Real-ESRGAN Β· batch click Β· live progress Β· ZeroGPU</div>
59
  </div>
60
  </div>
61
  <hr>
62
  """
63
 
64
+ # ────────────────────────────────────────────────────────
65
+ # Small helpers
66
+ # ────────────────────────────────────────────────────────
67
+ import re
68
  _num = re.compile(r'(\d+)')
69
  def _natural_key(p: Path | str):
70
  s = str(p)
 
78
  step = (total - 1) / (n - 1); idxs = [round(i * step) for i in range(n)]
79
  out, seen = [], set()
80
  for i in idxs:
81
+ if int(i) not in seen:
82
+ out.append(str(paths[int(i)])); seen.add(int(i))
 
83
  return out
84
 
85
  def render_progress(pct: float, label: str = "") -> str:
 
88
  <div style="height:100%;width:{pct:.1f}%;background:#3b82f6;"></div></div>
89
  <div style="font-size:12px;opacity:.8;margin-top:4px;">{label} {pct:.1f}%</div>'''
90
 
91
def _ensure_dir(p: Path) -> Path:
    """Create directory ``p`` (including missing parents) and hand it back.

    A directory that already exists is not treated as an error, so the
    call can be repeated safely.
    """
    p.mkdir(parents=True, exist_ok=True)
    return p
93
+
94
def _save_zip_of_dir(dir_path: Path, zip_path: Path) -> str:
    """Pack the images found directly inside ``dir_path`` into ``zip_path``.

    The archive is (re)created in write mode with DEFLATE compression.
    Only entries whose suffix is .jpg, .jpeg or .png (case-insensitive)
    are stored, each under its bare file name, in the order produced by
    sorting with ``_natural_key``.

    Returns:
        ``zip_path`` converted to ``str``.
    """
    image_suffixes = (".jpg", ".jpeg", ".png")
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as archive:
        candidates = list(dir_path.glob("*.*"))
        candidates.sort(key=_natural_key)
        for entry in candidates:
            if entry.suffix.lower() not in image_suffixes:
                continue
            archive.write(entry, entry.name)
    return str(zip_path)
100
+
101
def _list_image_paths_from_upload(files: Optional[list]) -> List[str]:
    """Return the paths of uploaded files that carry an image suffix.

    Args:
        files: The value delivered by the upload component, or ``None``.
            Each entry may be a path (``str`` / ``os.PathLike``) or an
            object exposing its path through a ``name`` attribute
            (e.g. a tempfile wrapper). Entries of any other shape are
            skipped instead of raising.

    Returns:
        String paths, in upload order, of the entries whose suffix is
        .jpg, .jpeg or .png (case-insensitive). Empty when ``files`` is
        ``None`` or empty.
    """
    if not files:
        return []

    allowed = {".jpg", ".jpeg", ".png"}
    selected: List[str] = []
    for item in files:
        # Check for real paths first: ``Path`` objects also have a ``.name``
        # attribute, but it only holds the basename.
        if isinstance(item, (str, os.PathLike)):
            raw = item
        else:
            raw = getattr(item, "name", None)
        if not isinstance(raw, (str, os.PathLike)):
            continue
        path = Path(raw)
        if path.suffix.lower() in allowed:
            selected.append(str(path))
    return selected
104
+
105
def _build_gallery_from_dir(dir_path: Path, n: int = 30) -> List[str]:
    """Pick up to ``n`` preview paths from the images in ``dir_path``.

    Gathers the ``*.jpg`` and ``*.png`` entries directly inside
    ``dir_path``, orders them with ``_natural_key`` and delegates the
    selection to ``sample_paths``.
    """
    found = [*dir_path.glob("*.jpg"), *dir_path.glob("*.png")]
    found.sort(key=_natural_key)
    return sample_paths(found, n)
108
+
109
  # ────────────────────────────────────────────────────────
110
+ # Models & weight management (CPU-safe; no CUDA used here)
111
  # ────────────────────────────────────────────────────────
 
112
def build_rrdb(scale: int, num_block: int):
    """Construct an RRDBNet for the given upscale factor and block count.

    The remaining constructor arguments are fixed: 3 input channels,
    3 output channels, 64 features and 32 growth channels.
    """
    fixed_kwargs = dict(num_in_ch=3, num_out_ch=3, num_feat=64, num_grow_ch=32)
    return _RRDBNet(num_block=num_block, scale=scale, **fixed_kwargs)
114
 
 
117
  os.makedirs(wdir, exist_ok=True)
118
  return wdir
119
 
120
def ensure_weights(
    model_id: str, denoise_strength: Optional[float] = None
) -> Tuple[object, int, str | List[str], Optional[list]]:
    """Build the network for ``model_id`` and make sure its weights are on disk.

    No CUDA work happens here: the function only constructs the network
    object and downloads missing checkpoint files into ``_weights_dir()``.

    Args:
        model_id: One of "x4plus", "x4plus-anime", "x2plus" or
            "general-x4v3".
        denoise_strength: Only consulted for "general-x4v3". ``None``
            (the default) returns the single base checkpoint. A number is
            clamped to [0, 1]; unless it equals 1, the "wdn" companion
            checkpoint is fetched as well and a two-way blend is returned,
            using the same ordering as the upstream Real-ESRGAN inference
            script (weight ``d`` on the base checkpoint, ``1 - d`` on wdn).

    Returns:
        ``(model, netscale, model_path, dni_weight)``. ``model_path`` is a
        single path and ``dni_weight`` is ``None``, except for the blended
        general-x4v3 case, where they are ``[base_path, wdn_path]`` and
        ``[d, 1 - d]``. These are intended to be passed to ``RealESRGANer``
        as ``model_path`` and ``dni_weight``.

    Raises:
        ValueError: If ``model_id`` is not one of the known ids.
    """
    release_root = "https://github.com/xinntao/Real-ESRGAN/releases/download/"
    wdir = _weights_dir()

    def _fetch(asset: str) -> str:
        """Download ``asset`` into ``wdir`` if absent; return its local path."""
        local_path = os.path.join(wdir, os.path.basename(asset))
        if not os.path.isfile(local_path):
            load_file_from_url(url=release_root + asset, model_dir=wdir, progress=True)
        return local_path

    # RRDB-based models: id -> (netscale, num_block, release asset).
    rrdb_specs = {
        "x4plus": (4, 23, "v0.1.0/RealESRGAN_x4plus.pth"),
        "x4plus-anime": (4, 6, "v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"),
        "x2plus": (2, 23, "v0.2.1/RealESRGAN_x2plus.pth"),
    }
    if model_id in rrdb_specs:
        netscale, num_block, asset = rrdb_specs[model_id]
        model = build_rrdb(scale=netscale, num_block=num_block)
        return model, netscale, _fetch(asset), None

    if model_id == "general-x4v3":
        # general-x4v3 uses the compact SRVGG architecture, not RRDB.
        model = SRVGGNetCompact(
            num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu'
        )
        netscale = 4
        base_pth = _fetch("v0.2.5.0/realesr-general-x4v3.pth")
        if denoise_strength is None:
            return model, netscale, base_pth, None

        strength = max(0.0, min(1.0, float(denoise_strength)))
        if strength == 1.0:
            # Full weight on the base checkpoint: nothing to blend.
            return model, netscale, base_pth, None
        wdn_pth = _fetch("v0.2.5.0/realesr-general-wdn-x4v3.pth")
        return model, netscale, [base_pth, wdn_pth], [strength, 1.0 - strength]

    raise ValueError(f"Unknown model_id: {model_id}")
 
 
161
 
162
def map_ui_model_to_internal(ui_name: str) -> str:
    """Translate a model name shown in the UI into the internal model id.

    Any name that is not in the table resolves to "x4plus".
    """
    aliases = {
        "RealESRGAN_x4plus": "x4plus",
        "RealESRGAN_x4plus_anime_6B": "x4plus-anime",
        "RealESRGAN_x2plus": "x2plus",
        "RealESRNet_x4plus": "x4plus",            # fallback
        "realesr-general-x4v3": "general-x4v3",   # SRVGG
    }
    try:
        return aliases[ui_name]
    except KeyError:
        return "x4plus"
 
 
 
 
 
170
 
171
def clamp_scale_for_model(outscale: int, model_id: str) -> int:
    """Clamp the requested output scale to what ``model_id`` supports.

    The native factor is 2 for "x2plus" and 4 for every other model id.
    A request between 1 and the native factor is honoured; a larger
    request is capped at the native factor. A request that is missing,
    non-numeric or below 1 falls back to the native factor.

    Args:
        outscale: Scale requested by the user.
        model_id: Internal model id.

    Returns:
        The scale to use, always within ``[1, native]``.
    """
    native = 2 if model_id == "x2plus" else 4
    try:
        requested = int(outscale)
    except (TypeError, ValueError):
        return native
    if requested < 1:
        return native
    return min(requested, native)
 
173
 
174
  # ────────────────────────────────────────────────────────
175
+ # Step 2 · Prepare sources (CPU)
176
  # ────────────────────────────────────────────────────────
 
177
  def step2_prepare_sources(frames_list, uploaded_imgs, max_images):
178
  src = _list_image_paths_from_upload(uploaded_imgs) or (frames_list or [])
179
  if not src:
 
189
  total = len(src); done_idx = 0
190
  return src, str(out_dir), done_idx, total, f"Sources loaded: {total} image(s). Click 'Process Next Batch'.", render_progress(0.0, "Ready")
191
 
192
+ # ────────────────────────────────────────────────────────
193
+ # Step 2 · Process next batch (GPU) — ZeroGPU entry point
194
+ # ────────────────────────────────────────────────────────
195
+ @spaces.GPU
196
  def step2_process_next_batch(
197
  up_src_paths, up_out_dir, up_done_idx, up_total,
198
  ui_model_name, outscale, tile, precision, denoise_strength, face_enhance, batch_size,
 
199
  ):
200
+ """
201
+ Runs on ZeroGPU. Heavy parts (model load + enhance) are done inside this function.
202
+ Yields progress after each image in the current batch.
203
+ """
204
+ # Validate inputs
205
  if not up_src_paths or not up_out_dir:
206
  yield None, None, "Load sources first.", render_progress(0.0, "Idle"), up_done_idx, up_out_dir
207
  return
208
 
209
+ # Resolve model & scale
210
  model_id = map_ui_model_to_internal(ui_model_name)
211
  scale = clamp_scale_for_model(int(outscale or 4), model_id)
 
 
 
 
 
 
 
 
 
 
 
212
  tile = int(tile or 256)
213
  batch_size = max(1, int(batch_size or 8))
214
+ use_half = (precision == "half") # we'll honor this on CUDA only
215
 
216
+ # Ensure weights & build model (still CPU-safe) then instantiate ESRGANer on GPU
217
+ model, netscale, model_path, dni_weight = ensure_weights(model_id)
218
+ upsampler = RealESRGANer(
219
+ scale=netscale,
220
+ model_path=model_path,
221
+ dni_weight=dni_weight,
222
+ model=model,
223
  tile=tile,
224
+ tile_pad=10,
225
+ pre_pad=10,
226
+ half=use_half, # when ZeroGPU gives CUDA, this enables fp16
227
+ gpu_id=0 # request the single available GPU
228
  )
229
 
230
+ # Optional face enhancer (kept off by default as it adds weight download)
231
  face_enhancer = None
232
  if face_enhance:
233
  try:
234
  from gfpgan import GFPGANer
235
  face_enhancer = GFPGANer(
236
  model_path="https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
237
+ upscale=scale, arch="clean", channel_multiplier=2, bg_upsampler=upsampler
 
 
 
238
  )
239
  except Exception as e:
240
  print("GFPGAN load failed:", e)
241
  face_enhancer = None
242
 
243
+ # Batch window
244
  start = int(up_done_idx or 0)
245
  end = min(start + batch_size, int(up_total or 0))
246
  out_dir = Path(up_out_dir)
 
261
  img = im.convert("RGB")
262
  cv_img = np.array(img)
263
  if face_enhancer:
264
+ _, _, output = face_enhancer.enhance(
265
+ cv_img, has_aligned=False, only_center_face=False, paste_back=True
266
+ )
267
  else:
268
+ # denoise_strength has effect with general-x4v3; harmless otherwise
269
+ output, _ = upsampler.enhance(cv_img, outscale=scale, denoise_strength=float(denoise_strength or 0.5))
270
  Image.fromarray(output).save(out_dir / (Path(fp).stem + ".jpg"), quality=95)
271
  except Exception as e:
272
  print("Upscale error:", e)
273
 
274
+ # Progress for THIS batch
275
  elapsed = time.time() - t0
276
  pct_batch = (idx / total_in_batch) * 100.0
277
  eta = (total_in_batch - idx) * (elapsed / max(1, idx))
278
  label = (f"Batch: {idx}/{total_in_batch} Β· ~{eta:.1f}s ETA Β· "
279
+ f"global {start+idx}/{up_total} (x{scale}, model={ui_model_name}, tile={tile}, half={use_half})")
280
  gallery = _build_gallery_from_dir(out_dir, 30)
281
  zip_file = _save_zip_of_dir(out_dir, Path(out_dir.parent) / "upscaled.zip")
282
  yield gallery, zip_file, label, render_progress(pct_batch, f"Upscaling {pct_batch:.0f}% (batch)"), start+idx, up_out_dir
283
 
284
+ # Batch complete
285
  next_idx = end
286
  pct_global = (next_idx / up_total) * 100.0 if up_total else 100.0
287
  gallery = _build_gallery_from_dir(out_dir, 30)
 
291
  # ────────────────────────────────────────────────────────
292
  # UI
293
  # ────────────────────────────────────────────────────────
 
294
  def build_ui():
295
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
296
  gr.HTML(render_logo_html(88))
297
+ gr.Markdown("Upload images and upscale with Real-ESRGAN. **ZeroGPU** spins up a GPU **only** during batch processing.")
298
 
299
  # States
300
+ frames_state = gr.State([]) # present for parity; not used here
301
  up_src_paths_state = gr.State([])
302
  up_out_dir_state = gr.State("")
303
  up_done_idx_state = gr.State(0)
 
313
  value="RealESRGAN_x4plus"
314
  )
315
  denoise_strength = gr.Slider(0, 1, value=0.5, step=0.1, label="Denoise (only general-x4v3)")
316
+ outscale = gr.Slider(1, 6, value=4, step=1, label="Resolution upscale")
317
  face_enhance = gr.Checkbox(value=False, label="Face Enhancement (GFPGAN)")
318
  with gr.Row():
319
  tile = gr.Number(value=256, label="Tile size (try 128 if OOM; 0=auto)")
320
+ precision = gr.Dropdown(["auto", "half", "full"], value="half", label="Precision (GPU: half, CPU ignored)")
 
321
  with gr.Row():
322
  batch_size = gr.Number(value=12, precision=0, label="Batch size per click")
323
  max_images = gr.Number(value=0, precision=0, label="Max images to process (0 = all)")
324
 
325
  with gr.Row():
326
  btn_prepare = gr.Button("Load / Reset Sources", variant="secondary")
327
+ btn_next = gr.Button("Process Next Batch (uses GPU)", variant="primary")
328
 
329
  prog = gr.HTML(render_progress(0.0, "Idle"))
330
+ gallery_up = gr.Gallery(label="Upscaled preview (sampled 30)", columns=6, height=480)
331
  zip_up = gr.File(label="Download upscaled ZIP")
332
  details = gr.Markdown("")
333
 
334
+ # 1) load/reset sources (CPU)
335
  btn_prepare.click(
336
  step2_prepare_sources,
337
  inputs=[frames_state, imgs_override, max_images],
338
  outputs=[up_src_paths_state, up_out_dir_state, up_done_idx_state, up_total_state, details, prog]
339
  )
340
 
341
+ # 2) process one batch per click (GPU)
342
  btn_next.click(
343
  step2_process_next_batch,
344
  inputs=[up_src_paths_state, up_out_dir_state, up_done_idx_state, up_total_state,
345
+ ui_model_name, outscale, tile, precision, denoise_strength, face_enhance, batch_size],
346
  outputs=[gallery_up, zip_up, details, prog, up_done_idx_state, up_out_dir_state]
347
  )
348
 
349
+ gr.Markdown(
350
+ "> ℹ️ **ZeroGPU tips**: Larger tiles are faster but use more VRAM. If you hit OOM, try `tile=128`, "
351
+ "`batch size=4–8`, and keep `Precision=half`."
352
+ )
353
+
354
  return demo
355
 
356
  if __name__ == "__main__":