JS6969 committed on
Commit
2c452c9
·
verified ·
1 Parent(s): 78e2590

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +200 -38
app.py CHANGED
@@ -9,8 +9,22 @@
9
  # - Prefix defaults to input video filename if left blank
10
  # =============================
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  import os
13
  import re
 
14
  import json
15
  import math
16
  import time
@@ -18,6 +32,7 @@ import shutil
18
  import zipfile
19
  import tempfile
20
  import subprocess
 
21
  from pathlib import Path
22
  from typing import List, Optional, Tuple
23
 
@@ -25,6 +40,12 @@ import gradio as gr
25
  import numpy as np
26
  from PIL import Image
27
 
 
 
 
 
 
 
28
  # ─────────────────────────────────────────────────────────────
29
  # System checks & deps
30
  # ─────────────────────────────────────────────────────────────
@@ -47,19 +68,30 @@ MISSING_MSG = (
47
  else:
48
  MISSING_MSG = ""
49
 
50
- # Try to import Real-ESRGAN stack
51
- try:
52
- from realesrgan import RealESRGANer
53
- from basicsr.archs.rrdbnet_arch import RRDBNet
54
- HAVE_REALESRGAN = True
55
- except Exception as e:
56
- HAVE_REALESRGAN = False
57
- REAL_ERR = str(e)
58
-
59
  # ─────────────────────────────────────────────────────────────
60
  # Helpers
61
  # ─────────────────────────────────────────────────────────────
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def sample_paths(paths: List[Path] | List[str], n: int = 30) -> List[str]:
64
  """Return up to n items sampled evenly across the list, preserving order (as strings)."""
65
  if not paths:
@@ -185,30 +217,124 @@ def build_ffmpeg_extract(
185
  return cmd
186
 
187
 
188
- def get_realesrganer(model_name: str, scale: int, tile: int, half: bool, device: str = "cpu"):
189
- if not HAVE_REALESRGAN:
190
- raise RuntimeError("realesrgan is not installed. See requirements.txt (realesrgan, basicsr, torch, numpy, scipy, scikit-image).")
191
- if model_name in ("x4plus", "x4plus-anime"):
192
- model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
193
- model_scale = 4
194
- elif model_name == "x2plus":
195
- model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
196
- model_scale = 2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  else:
198
- raise ValueError("Unknown Real-ESRGAN model")
199
- if scale not in (2, 4):
200
- scale = model_scale
 
 
 
 
 
 
 
 
201
  upsampler = RealESRGANer(
202
- scale=model_scale,
203
- model_path=None,
 
204
  model=model,
205
- tile=tile or 0,
206
  tile_pad=10,
207
- pre_pad=0,
208
- half=half,
209
- device=device,
210
  )
211
- return upsampler
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
 
214
  def render_progress(pct: float, label: str = "") -> str:
@@ -323,16 +449,15 @@ def save_uploaded_images(files: List[gr.File] | None, prefix: str = "upload") ->
323
 
324
  def step2_upscale(
325
  frames_list: List[str] | None,
326
- model_name: str,
327
- scale: int,
328
  tile: int,
329
  precision: str,
330
  prog_html: str,
331
  uploaded_imgs: List[gr.File] | None,
 
 
332
  ):
333
- if not HAVE_REALESRGAN:
334
- msg = "Real-ESRGAN not available. Ensure requirements.txt includes: --prefer-binary, numpy==1.26.4, scipy==1.11.4, scikit-image==0.22.0, opencv-python-headless, torch==2.2.2, realesrgan==0.3.0, basicsr==1.4.2, pillow, gradio."
335
- return None, None, msg, prog_html
336
 
337
  # decide source: uploaded images take priority, else frames from step 1
338
  if uploaded_imgs and len(uploaded_imgs) > 0:
@@ -344,6 +469,10 @@ def step2_upscale(
344
  if not src_paths:
345
  return None, None, "No images provided. Upload files or run Step 1 first.", prog_html
346
 
 
 
 
 
347
  device = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
348
  half = (precision == "half") and (device == "cuda")
349
  upsampler = get_realesrganer(model_name, scale, tile, half, device=device)
@@ -377,6 +506,9 @@ def step2_upscale(
377
  for p in up_paths:
378
  zf.write(p, p.name)
379
 
 
 
 
380
  return gallery, str(zip_path), f"Upscaled: {len(up_paths)}", render_progress(100.0, "Upscaling complete")
381
 
382
  # ───────────────── Encode (Step 3) — supports uploaded frames/ZIP & optional audio source
@@ -649,16 +781,46 @@ def build_ui():
649
  # STEP 2 — Upscale
650
  with gr.Tab("Step 2 · Upscale Frames"):
651
  if not HAVE_REALESRGAN:
652
- gr.Markdown("⚠️ Upscaling disabled. Install dependencies in requirements.txt (see notes in code). Error: " + (REAL_ERR if 'REAL_ERR' in globals() else ""))
 
653
  gr.Markdown("Use frames from Step 1 **or** upload images below.")
654
- imgs_override = gr.Files(label="Upload images to upscale (JPG/PNG)", file_types=[".jpg", ".jpeg", ".png"], type="filepath")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
655
  with gr.Row():
656
- model_name = gr.Dropdown(["x4plus", "x4plus-anime", "x2plus"], value="x4plus", label="Model")
657
- scale = gr.Dropdown([2, 4], value=4, label="Output scale")
658
  tile = gr.Number(value=0, label="Tile size (0 = auto)")
659
  precision = gr.Dropdown(["auto", "half", "full"], value="auto", label="Precision (GPU=half, CPU=full)")
 
660
  with gr.Row():
661
  btn_upscale = gr.Button("Step 2: Upscale", variant="primary")
 
662
  prog2 = gr.HTML(render_progress(0.0, "Idle"))
663
  gallery_up = gr.Gallery(label="Upscaled preview (30 sampled)", columns=6, height=480)
664
  zip_up = gr.File(label="Download upscaled ZIP")
@@ -666,7 +828,7 @@ def build_ui():
666
 
667
  btn_upscale.click(
668
  step2_upscale,
669
- inputs=[frames_state, model_name, scale, tile, precision, prog2, imgs_override],
670
  outputs=[gallery_up, zip_up, details2, prog2],
671
  )
672
 
 
9
  # - Prefix defaults to input video filename if left blank
10
  # =============================
11
 
12
+
13
import sys
import types

# Compatibility shim: when ``torchvision.transforms.functional_tensor`` cannot
# be imported, register a stand-in module under that name in ``sys.modules``
# that exposes ``rgb_to_grayscale`` taken from ``torchvision.transforms.functional``.
# NOTE(review): presumably required by the basicsr / realesrgan imports further
# down, which would otherwise fail on newer torchvision releases — confirm.
try:
    import torchvision.transforms.functional_tensor as _ft  # noqa: F401
except Exception:
    # Deliberately broad: any failure to import the legacy module triggers the
    # fallback; if torchvision itself is unusable the import below still raises.
    from torchvision.transforms import functional as _F

    _mod = types.ModuleType("torchvision.transforms.functional_tensor")
    _mod.rgb_to_grayscale = _F.rgb_to_grayscale
    sys.modules["torchvision.transforms.functional_tensor"] = _mod
21
+
22
+ # ────────────────────────────────────────────────────────
23
+ # Standard imports
24
+ # ────────────────────────────────────────────────────────
25
  import os
26
  import re
27
+ import cv2
28
  import json
29
  import math
30
  import time
 
32
  import zipfile
33
  import tempfile
34
  import subprocess
35
+ import inspect
36
  from pathlib import Path
37
  from typing import List, Optional, Tuple
38
 
 
40
  import numpy as np
41
  from PIL import Image
42
 
43
+ from basicsr.archs.rrdbnet_arch import RRDBNet as _RRDBNet
44
+ from basicsr.utils.download_util import load_file_from_url
45
+
46
+ from realesrgan import RealESRGANer
47
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
48
+
49
  # ─────────────────────────────────────────────────────────────
50
  # System checks & deps
51
  # ─────────────────────────────────────────────────────────────
 
68
  else:
69
  MISSING_MSG = ""
70
 
 
 
 
 
 
 
 
 
 
71
  # ─────────────────────────────────────────────────────────────
72
  # Helpers
73
  # ─────────────────────────────────────────────────────────────
74
 
75
+ # Map UI model names (demo) to our internal model IDs
76
# UI (demo) model names -> internal model IDs.
_UI_MODEL_TO_INTERNAL = {
    "RealESRGAN_x4plus": "x4plus",
    "RealESRGAN_x4plus_anime_6B": "x4plus-anime",
    "RealESRGAN_x2plus": "x2plus",
    # Unsupported in our current RRDBNet wiring – fallback:
    "RealESRNet_x4plus": "x4plus",
    "realesr-general-x4v3": "x4plus",
}


def map_ui_model_to_internal(ui_name: str) -> str:
    """Translate a model name shown in the UI into an internal model ID.

    Args:
        ui_name: Model name as selected in the UI dropdown.

    Returns:
        The matching internal ID (``"x4plus"``, ``"x4plus-anime"`` or
        ``"x2plus"``). Names that are unknown, or that have no dedicated
        internal wiring, fall back to ``"x4plus"``.
    """
    return _UI_MODEL_TO_INTERNAL.get(ui_name, "x4plus")
86
+
87
def clamp_scale_for_model(outscale: int, model_id: str) -> int:
    """Return the output scale that ``model_id`` actually supports.

    The current models are x2 or x4 only, so the requested ``outscale`` does
    not influence the result; it is accepted so callers can pass the UI value
    straight through.

    Args:
        outscale: Scale requested by the user (ignored).
        model_id: Internal model ID, e.g. ``"x2plus"`` or ``"x4plus"``.

    Returns:
        ``2`` for ``"x2plus"``; ``4`` for every other model ID.
    """
    if model_id == "x2plus":
        return 2
    # For x4plus / x4plus-anime, force 4 (requests such as 5-6 are ignored).
    return 4
93
+
94
+
95
  def sample_paths(paths: List[Path] | List[str], n: int = 30) -> List[str]:
96
  """Return up to n items sampled evenly across the list, preserving order (as strings)."""
97
  if not paths:
 
217
  return cmd
218
 
219
 
220
def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
    """Upscale one image with a Real-ESRGAN model and return it for display.

    The function picks the network and weight URLs for ``model_name``,
    downloads any missing weight file into a ``weights`` directory next to
    this file, runs the upsampler (optionally wrapped by GFPGAN), writes the
    result to ``output_<random>.<ext>`` in the current working directory and
    records that filename in the module-level ``last_file``.

    NOTE(review): relies on ``build_rrdb`` and ``rnd_string`` being defined
    elsewhere in this module — they are not visible in this part of the file.

    Args:
        img: Input image; converted with ``np.array``. Presumably a PIL image
            in grayscale, RGB or RGBA channel order — confirm against callers.
        model_name: One of ``RealESRGAN_x4plus``, ``RealESRNet_x4plus``,
            ``RealESRGAN_x4plus_anime_6B``, ``RealESRGAN_x2plus`` or
            ``realesr-general-x4v3``.
        denoise_strength: Blend weight between the base and WDN weights; only
            used for ``realesr-general-x4v3``.
        face_enhance: When truthy, run GFPGAN with the upsampler as its
            background upsampler.
        outscale: Final output scale passed to the enhancer.

    Returns:
        The enhanced image as an RGB (or RGBA) ndarray, or ``None`` when
        ``img`` is ``None`` or enhancement fails with a ``RuntimeError``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    global last_file

    if img is None:
        return None

    # ----- Select backbone + weights -----
    if model_name == 'RealESRGAN_x4plus':
        model = build_rrdb(scale=4, num_block=23)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']

    elif model_name == 'RealESRNet_x4plus':
        model = build_rrdb(scale=4, num_block=23)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']

    elif model_name == 'RealESRGAN_x4plus_anime_6B':
        model = build_rrdb(scale=4, num_block=6)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']

    elif model_name == 'RealESRGAN_x2plus':
        model = build_rrdb(scale=2, num_block=23)
        netscale = 2
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']

    elif model_name == 'realesr-general-x4v3':
        model = SRVGGNetCompact(
            num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu'
        )
        netscale = 4
        file_url = [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
        ]
    else:
        raise ValueError(f"Unknown model: {model_name}")

    # ----- Ensure weights on disk -----
    root_dir = os.path.dirname(os.path.abspath(__file__))
    weights_dir = os.path.join(root_dir, 'weights')
    os.makedirs(weights_dir, exist_ok=True)

    for url in file_url:
        local_path = os.path.join(weights_dir, os.path.basename(url))
        if not os.path.isfile(local_path):
            load_file_from_url(url=url, model_dir=weights_dir, progress=True)

    if model_name == 'realesr-general-x4v3':
        # This model blends two weight files; denoise_strength is the share
        # given to the WDN weights.
        base_path = os.path.join(weights_dir, 'realesr-general-x4v3.pth')
        wdn_path = os.path.join(weights_dir, 'realesr-general-wdn-x4v3.pth')
        model_path = [base_path, wdn_path]
        denoise_strength = float(denoise_strength)
        dni_weight = [1.0 - denoise_strength, denoise_strength]  # base, WDN
    else:
        model_path = os.path.join(weights_dir, f"{model_name}.pth")
        dni_weight = None

    # ----- CUDA / precision / tiling -----
    # The upsampler runs on PyTorch, so PyTorch decides whether a GPU (and
    # therefore half precision) is usable. OpenCV's CUDA module is unrelated
    # and reports no devices on the standard pip builds.
    try:
        import torch
        use_cuda = bool(torch.cuda.is_available())
    except ImportError:
        use_cuda = False
    gpu_id = 0 if use_cuda else None

    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=256,  # VRAM-safe default; lower to 128 if OOM
        tile_pad=10,
        pre_pad=10,
        half=use_cuda,
        gpu_id=gpu_id,
    )

    # ----- Optional face enhancement -----
    face_enhancer = None
    if face_enhance:
        # Imported lazily so gfpgan is only required when the option is used.
        from gfpgan import GFPGANer
        face_enhancer = GFPGANer(
            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
            upscale=outscale,
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=upsampler,
        )

    # ----- PIL -> cv2 -----
    cv_img = np.array(img)
    if cv_img.ndim == 2:
        # Single-channel input: expand to BGR so the conversion cannot fail.
        cv_img = cv2.cvtColor(cv_img, cv2.COLOR_GRAY2BGR)
    elif cv_img.ndim == 3 and cv_img.shape[2] == 4:
        cv_img = cv2.cvtColor(cv_img, cv2.COLOR_RGBA2BGRA)
    else:
        cv_img = cv2.cvtColor(cv_img, cv2.COLOR_RGB2BGR)

    # ----- Enhance -----
    try:
        if face_enhancer:
            _, _, output = face_enhancer.enhance(
                cv_img, has_aligned=False, only_center_face=False, paste_back=True
            )
        else:
            output, _ = upsampler.enhance(cv_img, outscale=int(outscale))
    except RuntimeError as error:
        print('Error', error)
        print('Tip: If you hit CUDA OOM, try a smaller tile size (e.g., 128).')
        return None

    # ----- cv2 -> display ndarray, also save -----
    if output.ndim == 3 and output.shape[2] == 4:
        display_img = cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
        extension = 'png'  # keeps the alpha channel
    else:
        display_img = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
        extension = 'jpg'

    out_filename = f"output_{rnd_string(8)}.{extension}"
    try:
        saved = cv2.imwrite(out_filename, output)
    except Exception as e:
        # Saving is best-effort: the in-memory result is still returned.
        print("Save error:", e)
    else:
        if saved:
            last_file = out_filename
        else:
            # imwrite reports most failures by returning False, not raising.
            print("Save error:", f"could not write {out_filename}")

    return display_img
338
 
339
 
340
  def render_progress(pct: float, label: str = "") -> str:
 
449
 
450
  def step2_upscale(
451
  frames_list: List[str] | None,
452
+ ui_model_name: str,
453
+ outscale: int,
454
  tile: int,
455
  precision: str,
456
  prog_html: str,
457
  uploaded_imgs: List[gr.File] | None,
458
+ denoise_strength: float = 0.5, # NEW: accepted but currently unused
459
+ face_enhance: bool = False, # NEW: accepted but currently unused (needs GFPGAN)
460
  ):
 
 
 
461
 
462
  # decide source: uploaded images take priority, else frames from step 1
463
  if uploaded_imgs and len(uploaded_imgs) > 0:
 
469
  if not src_paths:
470
  return None, None, "No images provided. Upload files or run Step 1 first.", prog_html
471
 
472
+ # Map demo model -> internal, clamp scale to supported values
473
+ model_id = map_ui_model_to_internal(ui_model_name)
474
+ scale = clamp_scale_for_model(int(outscale or 4), model_id)
475
+
476
  device = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
477
  half = (precision == "half") and (device == "cuda")
478
  upsampler = get_realesrganer(model_name, scale, tile, half, device=device)
 
506
  for p in up_paths:
507
  zf.write(p, p.name)
508
 
509
+ detail = (f"Upscaled: {len(up_paths)} | Model: {ui_model_name}→{model_id} | "
510
+ f"Scale: x{scale} | Tile: {tile} | Precision: {precision}")
511
+
512
  return gallery, str(zip_path), f"Upscaled: {len(up_paths)}", render_progress(100.0, "Upscaling complete")
513
 
514
  # ───────────────── Encode (Step 3) — supports uploaded frames/ZIP & optional audio source
 
781
  # STEP 2 — Upscale
782
  with gr.Tab("Step 2 · Upscale Frames"):
783
  if not HAVE_REALESRGAN:
784
+ gr.Markdown("⚠️ Upscaling disabled. Install dependencies in requirements.txt (realesrgan, basicsr, torch, etc.).")
785
+
786
  gr.Markdown("Use frames from Step 1 **or** upload images below.")
787
+ imgs_override = gr.Files(
788
+ label="Upload images to upscale (JPG/PNG)",
789
+ file_types=[".jpg", ".jpeg", ".png"],
790
+ type="filepath"
791
+ )
792
+
793
+ with gr.Accordion("Upscaling options", open=True):
794
+ with gr.Row():
795
+ ui_model_name = gr.Dropdown(
796
+ label="Upscaler model",
797
+ choices=[
798
+ "RealESRGAN_x4plus",
799
+ "RealESRNet_x4plus",
800
+ "RealESRGAN_x4plus_anime_6B",
801
+ "RealESRGAN_x2plus",
802
+ "realesr-general-x4v3",
803
+ ],
804
+ value="RealESRGAN_x4plus",
805
+ show_label=True
806
+ )
807
+ denoise_strength = gr.Slider(
808
+ label="Denoise Strength (only for realesr-general-x4v3)",
809
+ minimum=0, maximum=1, step=0.1, value=0.5
810
+ )
811
+ outscale = gr.Slider(
812
+ label="Resolution upscale",
813
+ minimum=1, maximum=6, step=1, value=4, show_label=True
814
+ )
815
+ face_enhance = gr.Checkbox(label="Face Enhancement (GFPGAN)", value=False)
816
+
817
  with gr.Row():
 
 
818
  tile = gr.Number(value=0, label="Tile size (0 = auto)")
819
  precision = gr.Dropdown(["auto", "half", "full"], value="auto", label="Precision (GPU=half, CPU=full)")
820
+
821
  with gr.Row():
822
  btn_upscale = gr.Button("Step 2: Upscale", variant="primary")
823
+
824
  prog2 = gr.HTML(render_progress(0.0, "Idle"))
825
  gallery_up = gr.Gallery(label="Upscaled preview (30 sampled)", columns=6, height=480)
826
  zip_up = gr.File(label="Download upscaled ZIP")
 
828
 
829
  btn_upscale.click(
830
  step2_upscale,
831
+ inputs=[frames_state, ui_model_name, outscale, tile, precision, prog2, imgs_override, denoise_strength, face_enhance],
832
  outputs=[gallery_up, zip_up, details2, prog2],
833
  )
834