JS6969 commited on
Commit
93e3fa3
Β·
verified Β·
1 Parent(s): f9b1d15

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -52
app.py CHANGED
@@ -1,8 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import cv2
3
  import numpy
4
  import os
5
  import random
 
6
  from basicsr.archs.rrdbnet_arch import RRDBNet
7
  from basicsr.utils.download_util import load_file_from_url
8
 
@@ -20,14 +39,11 @@ img_mode = "RGBA"
20
  # Utilities
21
  # ────────────────────────────────────────────────────────
22
  def rnd_string(x: int) -> str:
23
- """Returns a string of 'x' random characters."""
24
  characters = "abcdefghijklmnopqrstuvwxyz_0123456789"
25
- result = "".join((random.choice(characters)) for _ in range(x))
26
- return result
27
 
28
 
29
  def reset():
30
- """Resets the Image components and deletes the last processed image."""
31
  global last_file
32
  if last_file:
33
  try:
@@ -40,10 +56,6 @@ def reset():
40
 
41
 
42
  def has_transparency(img):
43
- """
44
- Check for transparency in a PIL image.
45
- https://stackoverflow.com/questions/43864101/python-pil-check-if-image-is-transparent
46
- """
47
  if img.info.get("transparency", None) is not None:
48
  return True
49
  if img.mode == "P":
@@ -59,19 +71,13 @@ def has_transparency(img):
59
 
60
 
61
  def image_properties(img):
62
- """Return resolution & color mode of the input image; set global img_mode."""
63
  global img_mode
64
  if img:
65
- if has_transparency(img):
66
- img_mode = "RGBA"
67
- else:
68
- img_mode = "RGB"
69
- properties = f"Resolution: Width: {img.size[0]}, Height: {img.size[1]} | Color Mode: {img_mode}"
70
- return properties
71
 
72
 
73
  def model_tip_text(model_name: str) -> str:
74
- """Return human-friendly guidance for the chosen model."""
75
  tips = {
76
  "RealESRGAN_x4plus": (
77
  "**RealESRGAN_x4plus (4Γ—)** β€” Best for photoreal images (portraits, landscapes). "
@@ -101,50 +107,40 @@ def model_tip_text(model_name: str) -> str:
101
  # Core upscaling
102
  # ────────────────────────────────────────────────────────
103
  def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
104
- """Real-ESRGAN function to restore (and upscale) images with robust defaults."""
105
  if img is None:
106
  return
107
 
108
  # ----- Select backbone + weights -----
109
- if model_name == 'RealESRGAN_x4plus': # x4 RRDBNet model
110
- model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
111
- netscale = 4
112
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
113
 
114
- elif model_name == 'RealESRNet_x4plus': # x4 RRDBNet model
115
- model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
116
- netscale = 4
117
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
118
 
119
- elif model_name == 'RealESRGAN_x4plus_anime_6B': # x4 RRDBNet model with 6 blocks
120
- model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
121
- netscale = 4
122
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
123
 
124
- elif model_name == 'RealESRGAN_x2plus': # x2 RRDBNet model
125
- model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
126
- netscale = 2
127
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
128
 
129
- elif model_name == 'realesr-general-x4v3': # x4 VGG-style model (S size)
130
- model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
131
- netscale = 4
132
- # We'll ensure BOTH base and WDN weights exist; order matters for DNI.
133
  file_url = [
134
  'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
135
  'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth'
136
  ]
137
-
138
  else:
139
  raise ValueError(f"Unknown model: {model_name}")
140
 
141
- # ----- Ensure weights are on disk -----
142
- # For the general-x4v3 case we download both; for others single file is fine.
143
  ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
144
  weights_dir = os.path.join(ROOT_DIR, 'weights')
145
  os.makedirs(weights_dir, exist_ok=True)
146
 
147
- # Track model paths
148
  local_paths = []
149
  for url in file_url:
150
  fname = os.path.basename(url)
@@ -153,27 +149,22 @@ def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
153
  local_path = load_file_from_url(url=url, model_dir=weights_dir, progress=True)
154
  local_paths.append(local_path)
155
 
156
- # Default path(s)
157
  if model_name == 'realesr-general-x4v3':
158
- # Order: [base, wdn] then set DNI weights accordingly
159
  base_path = os.path.join(weights_dir, 'realesr-general-x4v3.pth')
160
- wdn_path = os.path.join(weights_dir, 'realesr-general-wdn-x4v3.pth')
161
  model_path = [base_path, wdn_path]
162
  denoise_strength = float(denoise_strength)
163
- # Weight for WDN equals denoise_strength (cleaner); base gets the remainder
164
- dni_weight = [1.0 - denoise_strength, denoise_strength]
165
  else:
166
  model_path = os.path.join(weights_dir, f"{model_name}.pth")
167
  dni_weight = None
168
 
169
  # ----- CUDA / precision / tiling -----
170
- # Be defensive: cv2.cuda may not exist in CPU-only builds.
171
  use_cuda = False
172
  try:
173
  use_cuda = hasattr(cv2, "cuda") and cv2.cuda.getCudaEnabledDeviceCount() > 0
174
  except Exception:
175
  use_cuda = False
176
-
177
  gpu_id = 0 if use_cuda else None
178
 
179
  upsampler = RealESRGANer(
@@ -181,10 +172,10 @@ def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
181
  model_path=model_path,
182
  dni_weight=dni_weight,
183
  model=model,
184
- tile=256, # Safe VRAM default; increase if you have headroom
185
  tile_pad=10,
186
  pre_pad=10,
187
- half=bool(use_cuda), # FP16 on GPU
188
  gpu_id=gpu_id
189
  )
190
 
@@ -200,7 +191,7 @@ def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
200
  bg_upsampler=upsampler
201
  )
202
 
203
- # ----- Convert PIL -> cv2 (handle RGB/RGBA) -----
204
  cv_img = numpy.array(img)
205
  if cv_img.ndim == 3 and cv_img.shape[2] == 4:
206
  cv_img = cv2.cvtColor(cv_img, cv2.COLOR_RGBA2BGRA)
@@ -218,7 +209,7 @@ def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
218
  print('Tip: If you hit CUDA OOM, try a smaller tile size (e.g., 128).')
219
  return None
220
 
221
- # ----- cv2 -> RGBA/RGB for Gradio, also save -----
222
  if output.ndim == 3 and output.shape[2] == 4:
223
  display_img = cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
224
  extension = 'png'
@@ -234,7 +225,7 @@ def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
234
  except Exception as e:
235
  print("Save error:", e)
236
 
237
- return display_img # ndarray so Gradio displays immediately
238
 
239
 
240
  # ────────────────────────────────────────────────────────
@@ -255,7 +246,7 @@ def main():
255
  "RealESRGAN_x2plus",
256
  "realesr-general-x4v3",
257
  ],
258
- value="RealESRGAN_x4plus", # photoreal default
259
  show_label=True
260
  )
261
  denoise_strength = gr.Slider(
@@ -268,7 +259,6 @@ def main():
268
  )
269
  face_enhance = gr.Checkbox(label="Face Enhancement (GFPGAN)", value=False)
270
 
271
- # Model tips panel (auto-updates)
272
  model_tips = gr.Markdown(model_tip_text("RealESRGAN_x4plus"))
273
 
274
  with gr.Row():
@@ -281,7 +271,6 @@ def main():
281
  reset_btn = gr.Button("Remove images")
282
  restore_btn = gr.Button("Upscale")
283
 
284
- # Event listeners:
285
  input_image.change(fn=image_properties, inputs=input_image, outputs=input_image_properties)
286
  model_name.change(fn=model_tip_text, inputs=model_name, outputs=model_tips)
287
 
 
1
+ # ────────────────────────────────────────────────────────
2
+ # TorchVision compat shim (MUST be before importing basicsr)
3
+ # Fixes: ModuleNotFoundError: torchvision.transforms.functional_tensor
4
+ # ────────────────────────────────────────────────────────
5
+ import sys, types
6
+ try:
7
+ # If old path exists, do nothing
8
+ import torchvision.transforms.functional_tensor as _ft # noqa: F401
9
+ except Exception:
10
+ # Map to the new API location
11
+ from torchvision.transforms import functional as _F
12
+ _mod = types.ModuleType("torchvision.transforms.functional_tensor")
13
+ _mod.rgb_to_grayscale = _F.rgb_to_grayscale
14
+ sys.modules["torchvision.transforms.functional_tensor"] = _mod
15
+
16
+ # ────────────────────────────────────────────────────────
17
+ # Standard imports
18
+ # ────────────────────────────────────────────────────────
19
  import gradio as gr
20
  import cv2
21
  import numpy
22
  import os
23
  import random
24
+
25
  from basicsr.archs.rrdbnet_arch import RRDBNet
26
  from basicsr.utils.download_util import load_file_from_url
27
 
 
39
  # Utilities
40
  # ────────────────────────────────────────────────────────
41
  def rnd_string(x: int) -> str:
 
42
  characters = "abcdefghijklmnopqrstuvwxyz_0123456789"
43
+ return "".join((random.choice(characters)) for _ in range(x))
 
44
 
45
 
46
  def reset():
 
47
  global last_file
48
  if last_file:
49
  try:
 
56
 
57
 
58
  def has_transparency(img):
 
 
 
 
59
  if img.info.get("transparency", None) is not None:
60
  return True
61
  if img.mode == "P":
 
71
 
72
 
73
  def image_properties(img):
 
74
  global img_mode
75
  if img:
76
+ img_mode = "RGBA" if has_transparency(img) else "RGB"
77
+ return f"Resolution: Width: {img.size[0]}, Height: {img.size[1]} | Color Mode: {img_mode}"
 
 
 
 
78
 
79
 
80
  def model_tip_text(model_name: str) -> str:
 
81
  tips = {
82
  "RealESRGAN_x4plus": (
83
  "**RealESRGAN_x4plus (4Γ—)** β€” Best for photoreal images (portraits, landscapes). "
 
107
  # Core upscaling
108
  # ────────────────────────────────────────────────────────
109
  def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
 
110
  if img is None:
111
  return
112
 
113
  # ----- Select backbone + weights -----
114
+ if model_name == 'RealESRGAN_x4plus':
115
+ model = RRDBNet(3, 3, 64, 23, 32, scale=4); netscale = 4
 
116
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
117
 
118
+ elif model_name == 'RealESRNet_x4plus':
119
+ model = RRDBNet(3, 3, 64, 23, 32, scale=4); netscale = 4
 
120
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
121
 
122
+ elif model_name == 'RealESRGAN_x4plus_anime_6B':
123
+ model = RRDBNet(3, 3, 64, 6, 32, scale=4); netscale = 4
 
124
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
125
 
126
+ elif model_name == 'RealESRGAN_x2plus':
127
+ model = RRDBNet(3, 3, 64, 23, 32, scale=2); netscale = 2
 
128
  file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
129
 
130
+ elif model_name == 'realesr-general-x4v3':
131
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu'); netscale = 4
 
 
132
  file_url = [
133
  'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
134
  'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth'
135
  ]
 
136
  else:
137
  raise ValueError(f"Unknown model: {model_name}")
138
 
139
+ # ----- Ensure weights on disk -----
 
140
  ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
141
  weights_dir = os.path.join(ROOT_DIR, 'weights')
142
  os.makedirs(weights_dir, exist_ok=True)
143
 
 
144
  local_paths = []
145
  for url in file_url:
146
  fname = os.path.basename(url)
 
149
  local_path = load_file_from_url(url=url, model_dir=weights_dir, progress=True)
150
  local_paths.append(local_path)
151
 
 
152
  if model_name == 'realesr-general-x4v3':
 
153
  base_path = os.path.join(weights_dir, 'realesr-general-x4v3.pth')
154
+ wdn_path = os.path.join(weights_dir, 'realesr-general-wdn-x4v3.pth')
155
  model_path = [base_path, wdn_path]
156
  denoise_strength = float(denoise_strength)
157
+ dni_weight = [1.0 - denoise_strength, denoise_strength] # base, WDN
 
158
  else:
159
  model_path = os.path.join(weights_dir, f"{model_name}.pth")
160
  dni_weight = None
161
 
162
  # ----- CUDA / precision / tiling -----
 
163
  use_cuda = False
164
  try:
165
  use_cuda = hasattr(cv2, "cuda") and cv2.cuda.getCudaEnabledDeviceCount() > 0
166
  except Exception:
167
  use_cuda = False
 
168
  gpu_id = 0 if use_cuda else None
169
 
170
  upsampler = RealESRGANer(
 
172
  model_path=model_path,
173
  dni_weight=dni_weight,
174
  model=model,
175
+ tile=256,
176
  tile_pad=10,
177
  pre_pad=10,
178
+ half=bool(use_cuda),
179
  gpu_id=gpu_id
180
  )
181
 
 
191
  bg_upsampler=upsampler
192
  )
193
 
194
+ # ----- PIL -> cv2 -----
195
  cv_img = numpy.array(img)
196
  if cv_img.ndim == 3 and cv_img.shape[2] == 4:
197
  cv_img = cv2.cvtColor(cv_img, cv2.COLOR_RGBA2BGRA)
 
209
  print('Tip: If you hit CUDA OOM, try a smaller tile size (e.g., 128).')
210
  return None
211
 
212
+ # ----- cv2 -> display ndarray, also save -----
213
  if output.ndim == 3 and output.shape[2] == 4:
214
  display_img = cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
215
  extension = 'png'
 
225
  except Exception as e:
226
  print("Save error:", e)
227
 
228
+ return display_img
229
 
230
 
231
  # ────────────────────────────────────────────────────────
 
246
  "RealESRGAN_x2plus",
247
  "realesr-general-x4v3",
248
  ],
249
+ value="RealESRGAN_x4plus",
250
  show_label=True
251
  )
252
  denoise_strength = gr.Slider(
 
259
  )
260
  face_enhance = gr.Checkbox(label="Face Enhancement (GFPGAN)", value=False)
261
 
 
262
  model_tips = gr.Markdown(model_tip_text("RealESRGAN_x4plus"))
263
 
264
  with gr.Row():
 
271
  reset_btn = gr.Button("Remove images")
272
  restore_btn = gr.Button("Upscale")
273
 
 
274
  input_image.change(fn=image_properties, inputs=input_image, outputs=input_image_properties)
275
  model_name.change(fn=model_tip_text, inputs=model_name, outputs=model_tips)
276