Jonny001 commited on
Commit
2d2df45
·
verified ·
1 Parent(s): e0230fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -46
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import cv2
3
  import gradio as gr
4
  import torch
@@ -7,38 +8,40 @@ from gfpgan.utils import GFPGANer
7
  from realesrgan.utils import RealESRGANer
8
  import spaces
9
 
10
- # Download model weights if not already present
11
- weights = {
12
- 'realesr-general-x4v3.pth': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
13
- 'GFPGANv1.2.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth',
14
- 'GFPGANv1.3.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
15
- 'GFPGANv1.4.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth',
16
- 'RestoreFormer.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth',
17
- 'CodeFormer.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth'
18
- }
19
-
20
- for filename, url in weights.items():
21
- if not os.path.exists(filename):
22
- os.system(f"wget {url} -P .")
23
-
24
- # Initialize background enhancer (RealESRGAN)
 
 
25
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
26
  model_path = 'realesr-general-x4v3.pth'
27
- half = torch.cuda.is_available()
28
  upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
29
 
30
  os.makedirs('output', exist_ok=True)
31
 
 
 
32
  @spaces.GPU(enable_queue=True)
33
  def inference(img, version, scale):
34
  print(img, version, scale)
35
  if scale > 4:
36
  scale = 4
37
-
38
  try:
39
  extension = os.path.splitext(os.path.basename(str(img)))[1]
40
  img = cv2.imread(img, cv2.IMREAD_UNCHANGED)
41
-
42
  if len(img.shape) == 3 and img.shape[2] == 4:
43
  img_mode = 'RGBA'
44
  elif len(img.shape) == 2:
@@ -49,62 +52,63 @@ def inference(img, version, scale):
49
 
50
  h, w = img.shape[0:2]
51
  if h > 3500 or w > 3500:
52
- print('Too large size')
53
  return None, None
54
-
55
  if h < 300:
56
  img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
57
 
58
- # Load selected model
59
  if version == 'v1.2':
60
- face_enhancer = GFPGANer(model_path='GFPGANv1.2.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
 
61
  elif version == 'v1.3':
62
- face_enhancer = GFPGANer(model_path='GFPGANv1.3.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
 
63
  elif version == 'v1.4':
64
- face_enhancer = GFPGANer(model_path='GFPGANv1.4.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
 
65
  elif version == 'RestoreFormer':
66
- face_enhancer = GFPGANer(model_path='RestoreFormer.pth', upscale=2, arch='RestoreFormer', channel_multiplier=2, bg_upsampler=upsampler)
67
-
68
  try:
69
  _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
70
  except RuntimeError as error:
71
- print('Enhancement error:', error)
72
- return None, None
73
 
74
  try:
75
  if scale != 2:
76
  interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
77
- h, w = img.shape[:2]
78
  output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
79
  except Exception as error:
80
- print('Rescale error:', error)
81
-
82
- extension = 'png' if img_mode == 'RGBA' else 'jpg'
 
 
83
  save_path = f'output/out.{extension}'
84
  cv2.imwrite(save_path, output)
85
 
86
  output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
87
  return output, save_path
88
-
89
  except Exception as error:
90
- print('Global exception:', error)
91
  return None, None
92
 
93
- # UI Description
94
- description = "⚠ Currently running on CPU — expect slower performance. Thank you for your patience."
95
 
96
- # Gradio Interface
 
97
  demo = gr.Interface(
98
- fn=inference,
99
- inputs=[
100
  gr.Image(type="filepath", label="Input"),
101
- gr.Radio(['v1.2', 'v1.3', 'v1.4', 'RestoreFormer'], type="value", value='v1.4', label='Model Version'),
102
  gr.Number(label="Rescaling factor", value=2),
103
- ],
104
- outputs=[
105
- gr.Image(type="numpy", label="Enhanced Output"),
106
- gr.File(label="Download Enhanced Image")
107
  ],
108
  description=description,
109
-
110
- ).queue(max_size=50).launch()
 
 
1
  import os
2
+
3
  import cv2
4
  import gradio as gr
5
  import torch
 
8
  from realesrgan.utils import RealESRGANer
9
  import spaces
10
 
11
+ os.system("pip freeze")
12
+ # download weights
13
+ if not os.path.exists('realesr-general-x4v3.pth'):
14
+ os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
15
+ if not os.path.exists('GFPGANv1.2.pth'):
16
+ os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth -P .")
17
+ if not os.path.exists('GFPGANv1.3.pth'):
18
+ os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P .")
19
+ if not os.path.exists('GFPGANv1.4.pth'):
20
+ os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P .")
21
+ if not os.path.exists('RestoreFormer.pth'):
22
+ os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth -P .")
23
+ if not os.path.exists('CodeFormer.pth'):
24
+ os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth -P .")
25
+
26
+
27
+ # background enhancer with RealESRGAN
28
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
29
  model_path = 'realesr-general-x4v3.pth'
30
+ half = True if torch.cuda.is_available() else False
31
  upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
32
 
33
  os.makedirs('output', exist_ok=True)
34
 
35
+
36
+ # def inference(img, version, scale, weight):
37
  @spaces.GPU(enable_queue=True)
38
  def inference(img, version, scale):
39
  print(img, version, scale)
40
  if scale > 4:
41
  scale = 4
 
42
  try:
43
  extension = os.path.splitext(os.path.basename(str(img)))[1]
44
  img = cv2.imread(img, cv2.IMREAD_UNCHANGED)
 
45
  if len(img.shape) == 3 and img.shape[2] == 4:
46
  img_mode = 'RGBA'
47
  elif len(img.shape) == 2:
 
52
 
53
  h, w = img.shape[0:2]
54
  if h > 3500 or w > 3500:
55
+ print('too large size')
56
  return None, None
57
+
58
  if h < 300:
59
  img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
60
 
 
61
  if version == 'v1.2':
62
+ face_enhancer = GFPGANer(
63
+ model_path='GFPGANv1.2.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
64
  elif version == 'v1.3':
65
+ face_enhancer = GFPGANer(
66
+ model_path='GFPGANv1.3.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
67
  elif version == 'v1.4':
68
+ face_enhancer = GFPGANer(
69
+ model_path='GFPGANv1.4.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
70
  elif version == 'RestoreFormer':
71
+ face_enhancer = GFPGANer(
72
+ model_path='RestoreFormer.pth', upscale=2, arch='RestoreFormer', channel_multiplier=2, bg_upsampler=upsampler)
73
  try:
74
  _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
75
  except RuntimeError as error:
76
+ print('Error', error)
 
77
 
78
  try:
79
  if scale != 2:
80
  interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
81
+ h, w = img.shape[0:2]
82
  output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
83
  except Exception as error:
84
+ print('wrong scale input.', error)
85
+ if img_mode == 'RGBA':
86
+ extension = 'png'
87
+ else:
88
+ extension = 'jpg'
89
  save_path = f'output/out.{extension}'
90
  cv2.imwrite(save_path, output)
91
 
92
  output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
93
  return output, save_path
 
94
  except Exception as error:
95
+ print('global exception', error)
96
  return None, None
97
 
 
 
98
 
99
+ description = "⚠ Sorry for the inconvenience. The Space is currently running on the CPU, which might affect performance. We appreciate your understanding."
100
+
101
  demo = gr.Interface(
102
+ inference, [
 
103
  gr.Image(type="filepath", label="Input"),
104
+ gr.Radio(['v1.2', 'v1.3', 'v1.4', 'RestoreFormer'], type="value", value='v1.4', label='version'),
105
  gr.Number(label="Rescaling factor", value=2),
106
+
107
+ ], [
108
+ gr.Image(type="numpy", label="Output (The whole image)"),
109
+ gr.File(label="Download the output image")
110
  ],
111
  description=description,
112
+
113
+
114
+ demo.queue(max_size=50).launch()