Jonny001 commited on
Commit
3b94f47
·
verified ·
1 Parent(s): 51cb097

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -80
app.py CHANGED
@@ -5,126 +5,107 @@ import torch
5
  from basicsr.archs.srvgg_arch import SRVGGNetCompact
6
  from gfpgan.utils import GFPGANer
7
  from realesrgan.utils import RealESRGANer
8
- import spaces
9
 
10
- # Ensure weights directory
11
- WEIGHTS_DIR = "weights"
12
- os.makedirs(WEIGHTS_DIR, exist_ok=True)
13
-
14
- # List of required model weights
15
- weights = {
16
- 'realesr-general-x4v3.pth': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
17
- 'GFPGANv1.2.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth',
18
- 'GFPGANv1.3.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
19
- 'GFPGANv1.4.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth',
20
- 'RestoreFormer.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth',
21
- 'CodeFormer.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth'
22
- }
23
-
24
- # Download missing weights
25
- for file_name, url in weights.items():
26
- file_path = os.path.join(WEIGHTS_DIR, file_name)
27
- if not os.path.exists(file_path):
28
- os.system(f"wget -O {file_path} {url}")
29
-
30
- # Load Real-ESRGAN model
31
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
32
- model_path = os.path.join(WEIGHTS_DIR, 'realesr-general-x4v3.pth')
33
- half = torch.cuda.is_available()
34
-
35
- upsampler = RealESRGANer(
36
- scale=4,
37
- model_path=model_path,
38
- model=model,
39
- tile=0,
40
- tile_pad=10,
41
- pre_pad=0,
42
- half=half
43
- )
44
 
45
  os.makedirs('output', exist_ok=True)
46
 
47
- @spaces.GPU(enable_queue=True)
 
48
  def inference(img, version, scale):
49
  print(img, version, scale)
50
  if scale > 4:
51
  scale = 4
52
-
53
  try:
54
  extension = os.path.splitext(os.path.basename(str(img)))[1]
55
  img = cv2.imread(img, cv2.IMREAD_UNCHANGED)
56
-
57
  if len(img.shape) == 3 and img.shape[2] == 4:
58
  img_mode = 'RGBA'
59
  elif len(img.shape) == 2:
60
- img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
61
  img_mode = None
 
62
  else:
63
  img_mode = None
64
 
65
- h, w = img.shape[:2]
66
  if h > 3500 or w > 3500:
67
  print('too large size')
68
  return None, None
69
-
70
  if h < 300:
71
  img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
72
 
73
- model_path_map = {
74
- 'v1.2': os.path.join(WEIGHTS_DIR, 'GFPGANv1.2.pth'),
75
- 'v1.3': os.path.join(WEIGHTS_DIR, 'GFPGANv1.3.pth'),
76
- 'v1.4': os.path.join(WEIGHTS_DIR, 'GFPGANv1.4.pth'),
77
- 'RestoreFormer': os.path.join(WEIGHTS_DIR, 'RestoreFormer.pth')
78
- }
79
-
80
- arch_map = {
81
- 'v1.2': 'clean',
82
- 'v1.3': 'clean',
83
- 'v1.4': 'clean',
84
- 'RestoreFormer': 'RestoreFormer'
85
- }
86
-
87
- face_enhancer = GFPGANer(
88
- model_path=model_path_map[version],
89
- upscale=2,
90
- arch=arch_map[version],
91
- channel_multiplier=2,
92
- bg_upsampler=upsampler
93
- )
94
-
95
- _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
96
-
97
- if scale != 2:
98
- interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
99
- h, w = img.shape[:2]
100
- output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
101
-
102
- extension = 'png' if img_mode == 'RGBA' else 'jpg'
103
  save_path = f'output/out.{extension}'
104
  cv2.imwrite(save_path, output)
105
 
106
  output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
107
  return output, save_path
108
-
109
  except Exception as error:
110
  print('global exception', error)
111
  return None, None
112
 
113
- description = "⚠ Sorry for the inconvenience. The Space is currently running on the CPU, which might affect performance. We appreciate your understanding."
 
114
 
115
  demo = gr.Interface(
116
- fn=inference,
117
- inputs=[
118
  gr.Image(type="filepath", label="Input"),
119
  gr.Radio(['v1.2', 'v1.3', 'v1.4', 'RestoreFormer'], type="value", value='v1.4', label='version'),
120
- gr.Number(label="Rescaling factor", value=2)
121
- ],
122
- outputs=[
123
  gr.Image(type="numpy", label="Output (The whole image)"),
124
  gr.File(label="Download the output image")
125
  ],
126
  description=description,
127
- theme="Yntec/HaleyCH_Theme_Orange"
128
- )
129
 
130
- demo.queue(max_size=50).launch()
 
 
5
  from basicsr.archs.srvgg_arch import SRVGGNetCompact
6
  from gfpgan.utils import GFPGANer
7
  from realesrgan.utils import RealESRGANer
 
8
 
9
+ os.system("pip freeze")
10
+ # download weights
11
+ if not os.path.exists('realesr-general-x4v3.pth'):
12
+ os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
13
+ if not os.path.exists('GFPGANv1.2.pth'):
14
+ os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth -P .")
15
+ if not os.path.exists('GFPGANv1.3.pth'):
16
+ os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P .")
17
+ if not os.path.exists('GFPGANv1.4.pth'):
18
+ os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P .")
19
+ if not os.path.exists('RestoreFormer.pth'):
20
+ os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth -P .")
21
+ if not os.path.exists('CodeFormer.pth'):
22
+ os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth -P .")
23
+
24
+
25
+ # background enhancer with RealESRGAN
 
 
 
 
26
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
27
+ model_path = 'realesr-general-x4v3.pth'
28
+ half = True if torch.cuda.is_available() else False
29
+ upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
 
 
 
 
 
 
 
 
 
30
 
31
  os.makedirs('output', exist_ok=True)
32
 
33
+
34
+ # def inference(img, version, scale, weight):
35
  def inference(img, version, scale):
36
  print(img, version, scale)
37
  if scale > 4:
38
  scale = 4
 
39
  try:
40
  extension = os.path.splitext(os.path.basename(str(img)))[1]
41
  img = cv2.imread(img, cv2.IMREAD_UNCHANGED)
 
42
  if len(img.shape) == 3 and img.shape[2] == 4:
43
  img_mode = 'RGBA'
44
  elif len(img.shape) == 2:
 
45
  img_mode = None
46
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
47
  else:
48
  img_mode = None
49
 
50
+ h, w = img.shape[0:2]
51
  if h > 3500 or w > 3500:
52
  print('too large size')
53
  return None, None
54
+
55
  if h < 300:
56
  img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
57
 
58
+ if version == 'v1.2':
59
+ face_enhancer = GFPGANer(
60
+ model_path='GFPGANv1.2.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
61
+ elif version == 'v1.3':
62
+ face_enhancer = GFPGANer(
63
+ model_path='GFPGANv1.3.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
64
+ elif version == 'v1.4':
65
+ face_enhancer = GFPGANer(
66
+ model_path='GFPGANv1.4.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
67
+ elif version == 'RestoreFormer':
68
+ face_enhancer = GFPGANer(
69
+ model_path='RestoreFormer.pth', upscale=2, arch='RestoreFormer', channel_multiplier=2, bg_upsampler=upsampler)
70
+
71
+ try:
72
+ _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
73
+ except RuntimeError as error:
74
+ print('Error', error)
75
+
76
+ try:
77
+ if scale != 2:
78
+ interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
79
+ h, w = img.shape[0:2]
80
+ output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
81
+ except Exception as error:
82
+ print('wrong scale input.', error)
83
+ if img_mode == 'RGBA':
84
+ extension = 'png'
85
+ else:
86
+ extension = 'jpg'
 
87
  save_path = f'output/out.{extension}'
88
  cv2.imwrite(save_path, output)
89
 
90
  output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
91
  return output, save_path
 
92
  except Exception as error:
93
  print('global exception', error)
94
  return None, None
95
 
96
+
97
+ description = r""" """
98
 
99
  demo = gr.Interface(
100
+ inference, [
 
101
  gr.Image(type="filepath", label="Input"),
102
  gr.Radio(['v1.2', 'v1.3', 'v1.4', 'RestoreFormer'], type="value", value='v1.4', label='version'),
103
+ gr.Number(label="Rescaling factor", value=2),
104
+ ], [
 
105
  gr.Image(type="numpy", label="Output (The whole image)"),
106
  gr.File(label="Download the output image")
107
  ],
108
  description=description,
 
 
109
 
110
+ )
111
+ demo.queue().launch()