dotmet commited on
Commit
72fcf3a
·
1 Parent(s): 3ec7bf9

Upload interface.py

Browse files
Files changed (1) hide show
  1. interface.py +140 -0
interface.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ from PIL import Image
3
+ import glob
4
+ import os
5
+ from basicsr.archs.rrdbnet_arch import RRDBNet
6
+ from basicsr.utils.download_util import load_file_from_url
7
+
8
+ from realesrgan import RealESRGANer
9
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
10
+
11
+ def realEsrgan(model_name="RealESRGAN_x4plus_anime_6B",
12
+ model_path = None,
13
+ input_dir = 'inputs',
14
+ output_dir = 'results',
15
+ denoise_strength = 0.5,
16
+ outscale = 4,
17
+ suffix = 'out',
18
+ tile = 200,
19
+ tile_pad = 10,
20
+ pre_pad = 0,
21
+ face_enhance = True,
22
+ alpha_upsampler = 'realsrgan',
23
+ out_ext = 'auto',
24
+ fp32 = True,
25
+ gpu_id = None,
26
+ ):
27
+
28
+ # determine models according to model names
29
+ model_name = model_name.split('.')[0]
30
+ if model_name == 'RealESRGAN_x4plus': # x4 RRDBNet model
31
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
32
+ netscale = 4
33
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
34
+ elif model_name == 'RealESRNet_x4plus': # x4 RRDBNet model
35
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
36
+ netscale = 4
37
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
38
+ elif model_name == 'RealESRGAN_x4plus_anime_6B': # x4 RRDBNet model with 6 blocks
39
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
40
+ netscale = 4
41
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
42
+ elif model_name == 'RealESRGAN_x2plus': # x2 RRDBNet model
43
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
44
+ netscale = 2
45
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
46
+ elif model_name == 'realesr-animevideov3': # x4 VGG-style model (XS size)
47
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
48
+ netscale = 4
49
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth']
50
+ elif model_name == 'realesr-general-x4v3': # x4 VGG-style model (S size)
51
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
52
+ netscale = 4
53
+ file_url = [
54
+ 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
55
+ 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
56
+ ]
57
+
58
+ # determine model paths
59
+ if model_path is None:
60
+ model_path = os.path.join('weights', model_name + '.pth')
61
+ if not os.path.isfile(model_path):
62
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
63
+ for url in file_url:
64
+ # model_path will be updated
65
+ model_path = load_file_from_url(
66
+ url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
67
+
68
+ # use dni to control the denoise strength
69
+ dni_weight = None
70
+ if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
71
+ wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
72
+ model_path = [model_path, wdn_model_path]
73
+ dni_weight = [denoise_strength, 1 - denoise_strength]
74
+
75
+ # restorer
76
+ upsampler = RealESRGANer(
77
+ scale=netscale,
78
+ model_path=model_path,
79
+ dni_weight=dni_weight,
80
+ model=model,
81
+ tile=tile,
82
+ tile_pad=tile_pad,
83
+ pre_pad=pre_pad,
84
+ half=not fp32,
85
+ gpu_id=gpu_id)
86
+
87
+ if face_enhance: # Use GFPGAN for face enhancement
88
+ from gfpgan import GFPGANer
89
+ face_enhancer = GFPGANer(
90
+ model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
91
+ upscale=outscale,
92
+ arch='clean',
93
+ channel_multiplier=2,
94
+ bg_upsampler=upsampler)
95
+ os.makedirs(output_dir, exist_ok=True)
96
+
97
+ if os.path.isfile(input_dir):
98
+ paths = [input_dir]
99
+ else:
100
+ paths = sorted(glob.glob(os.path.join(input_dir, '*')))
101
+
102
+ Imgs = []
103
+ for idx, path in enumerate(paths):
104
+ imgname, extension = os.path.splitext(os.path.basename(path))
105
+ print(f'Scaling x{outscale}:', path)
106
+
107
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
108
+ if len(img.shape) == 3 and img.shape[2] == 4:
109
+ img_mode = 'RGBA'
110
+ else:
111
+ img_mode = None
112
+
113
+ try:
114
+ if face_enhance:
115
+ _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
116
+ else:
117
+ output, _ = upsampler.enhance(img, outscale=outscale)
118
+ except RuntimeError as error:
119
+ print('Error', error)
120
+ print('If you encounter CUDA or RAM out of memory, try to set --tile with a smaller number.')
121
+ else:
122
+ if out_ext == 'auto':
123
+ extension = extension[1:]
124
+ else:
125
+ extension = out_ext
126
+ if img_mode == 'RGBA': # RGBA images should be saved in png format
127
+ extension = 'png'
128
+ if suffix == '':
129
+ save_path = os.path.join(output_dir, f'{imgname}.{extension}')
130
+ else:
131
+ save_path = os.path.join(output_dir, f'{imgname}_{suffix}.{extension}')
132
+
133
+ cv2.imwrite(save_path, output)
134
+
135
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
136
+ img = Image.fromarray(img)
137
+ Imgs.append(img)
138
+
139
+ return Imgs
140
+