Jonny001 commited on
Commit
75bd11c
·
verified ·
1 Parent(s): 2346975

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -59
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
-
3
  import cv2
4
  import gradio as gr
5
  import torch
@@ -8,44 +7,41 @@ from gfpgan.utils import GFPGANer
8
  from realesrgan.utils import RealESRGANer
9
  import spaces
10
 
11
- os.system("pip freeze")
12
- # download weights
13
- if not os.path.exists('realesr-general-x4v3.pth'):
14
- os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
15
- if not os.path.exists('GFPGANv1.2.pth'):
16
- os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth -P .")
17
- if not os.path.exists('GFPGANv1.3.pth'):
18
- os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P .")
19
- if not os.path.exists('GFPGANv1.4.pth'):
20
- os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P .")
21
- if not os.path.exists('RestoreFormer.pth'):
22
- os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth -P .")
23
- if not os.path.exists('CodeFormer.pth'):
24
- os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth -P .")
25
-
26
-
27
- # background enhancer with RealESRGAN
28
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
29
  model_path = 'realesr-general-x4v3.pth'
30
- half = True if torch.cuda.is_available() else False
31
  upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
32
 
33
  os.makedirs('output', exist_ok=True)
34
 
35
-
36
- # def inference(img, version, scale, weight):
37
  @spaces.GPU(enable_queue=True)
38
  def inference(img, version, scale):
39
- # weight /= 100
40
  print(img, version, scale)
41
  if scale > 4:
42
- scale = 4 # avoid too large scale value
 
43
  try:
44
  extension = os.path.splitext(os.path.basename(str(img)))[1]
45
  img = cv2.imread(img, cv2.IMREAD_UNCHANGED)
 
46
  if len(img.shape) == 3 and img.shape[2] == 4:
47
  img_mode = 'RGBA'
48
- elif len(img.shape) == 2: # for gray inputs
49
  img_mode = None
50
  img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
51
  else:
@@ -53,70 +49,62 @@ def inference(img, version, scale):
53
 
54
  h, w = img.shape[0:2]
55
  if h > 3500 or w > 3500:
56
- print('too large size')
57
  return None, None
58
-
59
  if h < 300:
60
  img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
61
 
 
62
  if version == 'v1.2':
63
- face_enhancer = GFPGANer(
64
- model_path='GFPGANv1.2.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
65
  elif version == 'v1.3':
66
- face_enhancer = GFPGANer(
67
- model_path='GFPGANv1.3.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
68
  elif version == 'v1.4':
69
- face_enhancer = GFPGANer(
70
- model_path='GFPGANv1.4.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
71
  elif version == 'RestoreFormer':
72
- face_enhancer = GFPGANer(
73
- model_path='RestoreFormer.pth', upscale=2, arch='RestoreFormer', channel_multiplier=2, bg_upsampler=upsampler)
74
- # elif version == 'CodeFormer':
75
- # face_enhancer = GFPGANer(
76
- # model_path='CodeFormer.pth', upscale=2, arch='CodeFormer', channel_multiplier=2, bg_upsampler=upsampler)
77
 
78
  try:
79
- # _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True, weight=weight)
80
  _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
81
  except RuntimeError as error:
82
- print('Error', error)
 
83
 
84
  try:
85
  if scale != 2:
86
  interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
87
- h, w = img.shape[0:2]
88
  output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
89
  except Exception as error:
90
- print('wrong scale input.', error)
91
- if img_mode == 'RGBA': # RGBA images should be saved in png format
92
- extension = 'png'
93
- else:
94
- extension = 'jpg'
95
  save_path = f'output/out.{extension}'
96
  cv2.imwrite(save_path, output)
97
 
98
  output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
99
  return output, save_path
 
100
  except Exception as error:
101
- print('global exception', error)
102
  return None, None
103
 
 
 
104
 
105
- description = "⚠ Sorry for the inconvenience. The Space is currently running on the CPU, which might affect performance. We appreciate your understanding."
106
-
107
  demo = gr.Interface(
108
- inference, [
 
109
  gr.Image(type="filepath", label="Input"),
110
- # gr.Radio(['v1.2', 'v1.3', 'v1.4', 'RestoreFormer', 'CodeFormer'], type="value", value='v1.4', label='version'),
111
- gr.Radio(['v1.2', 'v1.3', 'v1.4', 'RestoreFormer'], type="value", value='v1.4', label='version'),
112
  gr.Number(label="Rescaling factor", value=2),
113
- # gr.Slider(0, 100, label='Weight, only for CodeFormer. 0 for better quality, 100 for better identity', value=50)
114
- ], [
115
- gr.Image(type="numpy", label="Output (The whole image)"),
116
- gr.File(label="Download the output image")
117
  ],
118
  description=description,
119
- # examples=[['AI-generate.jpg', 'v1.4', 2, 50], ['lincoln.jpg', 'v1.4', 2, 50], ['Blake_Lively.jpg', 'v1.4', 2, 50],
120
- # ['10045.png', 'v1.4', 2, 50]]).launch()
121
-
122
- demo.queue(max_size=50).launch()
 
1
  import os
 
2
  import cv2
3
  import gradio as gr
4
  import torch
 
7
  from realesrgan.utils import RealESRGANer
8
  import spaces
9
 
10
+ # Download model weights if not already present
11
+ weights = {
12
+ 'realesr-general-x4v3.pth': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
13
+ 'GFPGANv1.2.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth',
14
+ 'GFPGANv1.3.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
15
+ 'GFPGANv1.4.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth',
16
+ 'RestoreFormer.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth',
17
+ 'CodeFormer.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth'
18
+ }
19
+
20
+ for filename, url in weights.items():
21
+ if not os.path.exists(filename):
22
+ os.system(f"wget {url} -P .")
23
+
24
+ # Initialize background enhancer (RealESRGAN)
 
 
25
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
26
  model_path = 'realesr-general-x4v3.pth'
27
+ half = torch.cuda.is_available()
28
  upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
29
 
30
  os.makedirs('output', exist_ok=True)
31
 
 
 
32
  @spaces.GPU(enable_queue=True)
33
  def inference(img, version, scale):
 
34
  print(img, version, scale)
35
  if scale > 4:
36
+ scale = 4
37
+
38
  try:
39
  extension = os.path.splitext(os.path.basename(str(img)))[1]
40
  img = cv2.imread(img, cv2.IMREAD_UNCHANGED)
41
+
42
  if len(img.shape) == 3 and img.shape[2] == 4:
43
  img_mode = 'RGBA'
44
+ elif len(img.shape) == 2:
45
  img_mode = None
46
  img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
47
  else:
 
49
 
50
  h, w = img.shape[0:2]
51
  if h > 3500 or w > 3500:
52
+ print('Too large size')
53
  return None, None
54
+
55
  if h < 300:
56
  img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
57
 
58
+ # Load selected model
59
  if version == 'v1.2':
60
+ face_enhancer = GFPGANer(model_path='GFPGANv1.2.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
 
61
  elif version == 'v1.3':
62
+ face_enhancer = GFPGANer(model_path='GFPGANv1.3.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
 
63
  elif version == 'v1.4':
64
+ face_enhancer = GFPGANer(model_path='GFPGANv1.4.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
 
65
  elif version == 'RestoreFormer':
66
+ face_enhancer = GFPGANer(model_path='RestoreFormer.pth', upscale=2, arch='RestoreFormer', channel_multiplier=2, bg_upsampler=upsampler)
 
 
 
 
67
 
68
  try:
 
69
  _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
70
  except RuntimeError as error:
71
+ print('Enhancement error:', error)
72
+ return None, None
73
 
74
  try:
75
  if scale != 2:
76
  interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
77
+ h, w = img.shape[:2]
78
  output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
79
  except Exception as error:
80
+ print('Rescale error:', error)
81
+
82
+ extension = 'png' if img_mode == 'RGBA' else 'jpg'
 
 
83
  save_path = f'output/out.{extension}'
84
  cv2.imwrite(save_path, output)
85
 
86
  output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
87
  return output, save_path
88
+
89
  except Exception as error:
90
+ print('Global exception:', error)
91
  return None, None
92
 
93
+ # UI Description
94
+ description = "⚠ Currently running on CPU — expect slower performance. Thank you for your patience."
95
 
96
+ # Gradio Interface
 
97
  demo = gr.Interface(
98
+ fn=inference,
99
+ inputs=[
100
  gr.Image(type="filepath", label="Input"),
101
+ gr.Radio(['v1.2', 'v1.3', 'v1.4', 'RestoreFormer'], type="value", value='v1.4', label='Model Version'),
 
102
  gr.Number(label="Rescaling factor", value=2),
103
+ ],
104
+ outputs=[
105
+ gr.Image(type="numpy", label="Enhanced Output"),
106
+ gr.File(label="Download Enhanced Image")
107
  ],
108
  description=description,
109
+ # examples=[['AI-generate.jpg', 'v1.4', 2], ['lincoln.jpg', 'v1.4', 2], ['Blake_Lively.jpg', 'v1.4', 2], ['10045.png', 'v1.4', 2]]
110
+ ).queue(max_size=50).launch()