dxcanh committed on
Commit
3e4e5ab
·
verified ·
1 Parent(s): 6c5f92e

Upload 2 files

Browse files
Real-ESRGAN/inference_realesrgan.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import cv2
3
+ import glob
4
+ import os
5
+ from basicsr.archs.rrdbnet_arch import RRDBNet
6
+ from basicsr.utils.download_util import load_file_from_url
7
+
8
+ from realesrgan import RealESRGANer
9
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
10
+
11
+
12
def main():
    """Inference demo for Real-ESRGAN.

    Parses the command line, builds the network selected by ``--model_name``,
    locates (or downloads) its weights, and upsamples every image found at
    ``--input`` (a single file or a folder) into ``--output``.

    Exits with an argparse error when ``--model_name`` is not one of the known
    models. Files that OpenCV cannot decode are skipped with a message.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
    parser.add_argument(
        '-n',
        '--model_name',
        type=str,
        default='RealESRGAN_x4plus',
        help=('Model names: RealESRGAN_x4plus | RealESRNet_x4plus | RealESRGAN_x4plus_anime_6B | RealESRGAN_x2plus | '
              'realesr-animevideov3 | realesr-general-x4v3'))
    parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
    parser.add_argument(
        '-dn',
        '--denoise_strength',
        type=float,
        default=0.5,
        help=('Denoise strength. 0 for weak denoise (keep noise), 1 for strong denoise ability. '
              'Only used for the realesr-general-x4v3 model'))
    parser.add_argument('-s', '--outscale', type=float, default=4, help='The final upsampling scale of the image')
    parser.add_argument(
        '--model_path', type=str, default=None, help='[Option] Model path. Usually, you do not need to specify it')
    parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
    parser.add_argument('-t', '--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
    parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
    parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
    parser.add_argument('--face_enhance', action='store_true', help='Use GFPGAN to enhance face')
    parser.add_argument(
        '--fp32', action='store_true', help='Use fp32 precision during inference. Default: fp16 (half precision).')
    parser.add_argument(
        '--alpha_upsampler',
        type=str,
        default='realesrgan',
        help='The upsampler for the alpha channels. Options: realesrgan | bicubic')
    parser.add_argument(
        '--ext',
        type=str,
        default='auto',
        help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
    parser.add_argument(
        '-g', '--gpu-id', type=int, default=None, help='gpu device to use (default=None) can be 0,1,2 for multi-gpu')

    args = parser.parse_args()

    # determine models according to model names
    args.model_name = args.model_name.split('.')[0]
    if args.model_name == 'RealESRGAN_x4plus':  # x4 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
    elif args.model_name == 'RealESRNet_x4plus':  # x4 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
    elif args.model_name == 'RealESRGAN_x4plus_anime_6B':  # x4 RRDBNet model with 6 blocks
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
    elif args.model_name == 'RealESRGAN_x2plus':  # x2 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        netscale = 2
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
    elif args.model_name == 'realesr-animevideov3':  # x4 VGG-style model (XS size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth']
    elif args.model_name == 'realesr-general-x4v3':  # x4 VGG-style model (S size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        netscale = 4
        file_url = [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
        ]
    else:
        # Without this branch an unknown name only fails later with an unbound-variable error.
        parser.error(f'Unknown model name: {args.model_name}')

    # determine model paths
    if args.model_path is not None:
        model_path = args.model_path
    else:
        model_path = os.path.join('weights', args.model_name + '.pth')
        if not os.path.isfile(model_path):
            root_dir = os.path.dirname(os.path.abspath(__file__))
            for url in file_url:
                # model_path is overwritten on each iteration; the last URL's file is the one used
                model_path = load_file_from_url(
                    url=url, model_dir=os.path.join(root_dir, 'weights'), progress=True, file_name=None)

    # use dni to control the denoise strength
    dni_weight = None
    if args.model_name == 'realesr-general-x4v3' and args.denoise_strength != 1:
        wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
        model_path = [model_path, wdn_model_path]
        dni_weight = [args.denoise_strength, 1 - args.denoise_strength]

    # restorer
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=args.tile,
        tile_pad=args.tile_pad,
        pre_pad=args.pre_pad,
        half=not args.fp32,
        gpu_id=args.gpu_id)

    face_enhancer = None
    if args.face_enhance:  # Use GFPGAN for face enhancement
        from gfpgan import GFPGANer
        face_enhancer = GFPGANer(
            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth',
            upscale=args.outscale,
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=upsampler)
    os.makedirs(args.output, exist_ok=True)

    if os.path.isfile(args.input):
        paths = [args.input]
    else:
        paths = sorted(glob.glob(os.path.join(args.input, '*')))

    for idx, path in enumerate(paths):
        imgname, extension = os.path.splitext(os.path.basename(path))
        print('Testing', idx, imgname)

        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            print(f'Skipping {path}: it could not be read as an image.')
            continue
        img_mode = 'RGBA' if len(img.shape) == 3 and img.shape[2] == 4 else None

        try:
            if args.face_enhance:
                _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
            else:
                output, _ = upsampler.enhance(img, outscale=args.outscale)
        except RuntimeError as error:
            print('Error', error)
            print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
        else:
            if args.ext == 'auto':
                extension = extension[1:]  # drop the leading dot
            else:
                extension = args.ext
            if img_mode == 'RGBA':  # RGBA images should be saved in png format
                extension = 'png'
            if args.suffix == '':
                save_path = os.path.join(args.output, f'{imgname}.{extension}')
            else:
                save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{extension}')
            cv2.imwrite(save_path, output)


if __name__ == '__main__':
    main()
Real-ESRGAN/inference_realesrgan_video.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import cv2
3
+ import glob
4
+ import mimetypes
5
+ import numpy as np
6
+ import os
7
+ import shutil
8
+ import subprocess
9
+ import torch
10
+ from basicsr.archs.rrdbnet_arch import RRDBNet
11
+ from basicsr.utils.download_util import load_file_from_url
12
+ from os import path as osp
13
+ from tqdm import tqdm
14
+
15
+ from realesrgan import RealESRGANer
16
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
17
+
18
try:
    import ffmpeg
except ImportError:
    # ffmpeg-python is missing: install it on first use, then retry the import.
    # pip does not support being driven in-process (pip.main), so it is run as a
    # subprocess of the interpreter that is executing this script.
    import sys
    subprocess.call([sys.executable, '-m', 'pip', 'install', '--user', 'ffmpeg-python'])
    import ffmpeg
24
+
25
+
26
def get_video_meta_info(video_path):
    """Probe a video file and return the metadata used by the pipeline.

    Args:
        video_path: Path of the video passed to ``ffmpeg.probe``.

    Returns:
        dict with keys ``width``, ``height``, ``fps`` (float), ``audio`` (the
        ffmpeg audio stream of the input, or None when the file has no audio
        stream) and ``nb_frames`` (int). All video fields come from the first
        video stream reported by the probe.

    Raises:
        IndexError: if the probe reports no video stream.
        KeyError: if the frame count is missing and no duration is available
            to estimate it from.
    """
    from fractions import Fraction

    probe = ffmpeg.probe(video_path)
    streams = probe['streams']
    video = [stream for stream in streams if stream['codec_type'] == 'video'][0]
    has_audio = any(stream['codec_type'] == 'audio' for stream in streams)

    # avg_frame_rate is a rational string such as '30000/1001'; parse it rather than eval() it.
    fps = float(Fraction(video['avg_frame_rate']))

    if 'nb_frames' in video:
        nb_frames = int(video['nb_frames'])
    else:
        # Some containers do not report a per-stream frame count; estimate it from the duration.
        duration = video.get('duration') or probe.get('format', {}).get('duration')
        if duration is None:
            raise KeyError('nb_frames')
        nb_frames = int(float(duration) * fps)

    ret = {}
    ret['width'] = video['width']
    ret['height'] = video['height']
    ret['fps'] = fps
    ret['audio'] = ffmpeg.input(video_path).audio if has_audio else None
    ret['nb_frames'] = nb_frames
    return ret
37
+
38
+
39
def get_sub_video(args, num_process, process_idx):
    """Cut out the slice of the input video that one worker should process.

    Args:
        args: Parsed options; reads ``input``, ``output``, ``video_name`` and ``ffmpeg_bin``.
        num_process: Total number of workers the video is split across.
        process_idx: Zero-based index of the worker requesting its slice.

    Returns:
        ``args.input`` unchanged when there is a single worker; otherwise the path
        of the temporary ``.mp4`` written under ``<output>/<video_name>_inp_tmp_videos``.
    """
    if num_process == 1:
        return args.input

    meta = get_video_meta_info(args.input)
    duration = int(meta['nb_frames'] / meta['fps'])  # whole seconds
    part_time = duration // num_process
    print(f'duration: {duration}, part_time: {part_time}')

    tmp_dir = osp.join(args.output, f'{args.video_name}_inp_tmp_videos')
    os.makedirs(tmp_dir, exist_ok=True)
    out_path = osp.join(tmp_dir, f'{process_idx:03d}.mp4')

    # Pass an argument list (no shell) so paths containing spaces or shell
    # metacharacters are handed to ffmpeg verbatim.
    cmd = [args.ffmpeg_bin, '-y', '-i', args.input, '-ss', f'{part_time * process_idx}']
    if process_idx != num_process - 1:
        # The last worker has no end bound and runs to the end of the video.
        cmd += ['-to', f'{part_time * (process_idx + 1)}']
    cmd += ['-async', '1', out_path]
    print(' '.join(cmd))
    subprocess.call(cmd)
    return out_path
55
+
56
+
57
class Reader:
    """Sequential frame source for a video file, a single image, or a folder of images.

    The input kind is decided from the MIME type guessed for ``args.input``:
    ``video/*`` is decoded through an ffmpeg subprocess that pipes raw ``bgr24``
    frames; ``image/*`` is a single file; anything with no guessable type is
    treated as a folder whose entries are split evenly between workers.
    """

    def __init__(self, args, total_workers=1, worker_idx=0):
        """Open the input and record its resolution and frame count.

        Args:
            args: Parsed options; reads ``input``, ``ffmpeg_bin`` and (later) ``fps``.
            total_workers: Number of workers sharing the input.
            worker_idx: Zero-based index of this worker.
        """
        self.args = args
        input_type = mimetypes.guess_type(args.input)[0]
        self.input_type = 'folder' if input_type is None else input_type
        self.paths = []  # for image&folder type
        self.audio = None
        self.input_fps = None
        if self.input_type.startswith('video'):
            video_path = get_sub_video(args, total_workers, worker_idx)
            self.stream_reader = (
                ffmpeg.input(video_path).output('pipe:', format='rawvideo', pix_fmt='bgr24',
                                                loglevel='error').run_async(
                                                    pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
            meta = get_video_meta_info(video_path)
            self.width = meta['width']
            self.height = meta['height']
            self.input_fps = meta['fps']
            self.audio = meta['audio']
            self.nb_frames = meta['nb_frames']

        else:
            if self.input_type.startswith('image'):
                self.paths = [args.input]
            else:
                paths = sorted(glob.glob(os.path.join(args.input, '*')))
                tot_frames = len(paths)
                # ceil(tot_frames / total_workers): every worker gets a contiguous slice
                num_frame_per_worker = tot_frames // total_workers + (1 if tot_frames % total_workers else 0)
                self.paths = paths[num_frame_per_worker * worker_idx:num_frame_per_worker * (worker_idx + 1)]

            self.nb_frames = len(self.paths)
            assert self.nb_frames > 0, 'empty folder'
            from PIL import Image
            # Only the size is needed; close the file handle as soon as it is read.
            with Image.open(self.paths[0]) as tmp_img:
                self.width, self.height = tmp_img.size
        self.idx = 0

    def get_resolution(self):
        """Return the input frame size as ``(height, width)``."""
        return self.height, self.width

    def get_fps(self):
        """Return the user-requested fps, else the input's fps, else 24."""
        if self.args.fps is not None:
            return self.args.fps
        elif self.input_fps is not None:
            return self.input_fps
        return 24

    def get_audio(self):
        """Return the input's audio stream, or None when there is none."""
        return self.audio

    def __len__(self):
        return self.nb_frames

    def get_frame_from_stream(self):
        """Read the next raw frame from the ffmpeg pipe; None at end of stream."""
        frame_size = self.width * self.height * 3  # 3 bytes for one pixel
        img_bytes = self.stream_reader.stdout.read(frame_size)
        if len(img_bytes) < frame_size:
            # Empty read means EOF; a short read is a truncated trailing frame that
            # cannot be reshaped, so it also ends the stream.
            return None
        img = np.frombuffer(img_bytes, np.uint8).reshape([self.height, self.width, 3])
        return img

    def get_frame_from_list(self):
        """Load the next image from ``self.paths``; None once the list is exhausted."""
        if self.idx >= self.nb_frames:
            return None
        img = cv2.imread(self.paths[self.idx])
        self.idx += 1
        return img

    def get_frame(self):
        """Return the next frame from whichever source this reader wraps."""
        if self.input_type.startswith('video'):
            return self.get_frame_from_stream()
        else:
            return self.get_frame_from_list()

    def close(self):
        """Shut down the ffmpeg decoder process, if one was started."""
        if self.input_type.startswith('video'):
            self.stream_reader.stdin.close()
            self.stream_reader.wait()
135
+
136
+
137
class Writer:
    """Encodes frames into a video file through an ffmpeg subprocess.

    Frames are piped to ffmpeg as raw ``bgr24`` data and encoded with libx264
    (``yuv420p``); when an audio stream is supplied it is copied into the output.
    """

    def __init__(self, args, audio, height, width, video_save_path, fps):
        """Start the encoder process.

        Args:
            args: Parsed options; reads ``outscale`` and ``ffmpeg_bin``.
            audio: Audio stream to mux into the output, or None.
            height: Height of the input frames, before upscaling.
            width: Width of the input frames, before upscaling.
            video_save_path: Destination file.
            fps: Frame rate of the output video.
        """
        out_width = int(width * args.outscale)
        out_height = int(height * args.outscale)
        if out_height > 2160:
            print('You are generating video that is larger than 4K, which will be very slow due to IO speed.',
                  'We highly recommend to decrease the outscale(aka, -s).')

        # The raw-frame input is the same with or without audio; only the output node differs.
        raw_frames = ffmpeg.input(
            'pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{out_width}x{out_height}', framerate=fps)
        if audio is None:
            encoded = raw_frames.output(video_save_path, pix_fmt='yuv420p', vcodec='libx264', loglevel='error')
        else:
            encoded = raw_frames.output(
                audio,
                video_save_path,
                pix_fmt='yuv420p',
                vcodec='libx264',
                loglevel='error',
                acodec='copy')
        self.stream_writer = encoded.overwrite_output().run_async(
            pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin)

    def write_frame(self, frame):
        """Send one frame to the encoder as uint8 bytes."""
        payload = frame.astype(np.uint8).tobytes()
        self.stream_writer.stdin.write(payload)

    def close(self):
        """Close the encoder's input pipe and wait for ffmpeg to exit."""
        encoder = self.stream_writer
        encoder.stdin.close()
        encoder.wait()
171
+
172
+
173
def inference_video(args, video_save_path, device=None, total_workers=1, worker_idx=0):
    """Upsample one video (or one worker's share of it) and write the result.

    Args:
        args: Parsed options. ``model_name`` is normalised in place (a trailing
            ``.pth`` is dropped) and ``face_enhance`` is switched off in place
            for anime models.
        video_save_path: Output file for this worker.
        device: Device forwarded to ``RealESRGANer``; None leaves the choice to it.
        total_workers: Number of workers the input is split across.
        worker_idx: Zero-based index of this worker.

    Raises:
        ValueError: if ``args.model_name`` is not a known model.
    """
    # ---------------------- determine models according to model names ---------------------- #
    args.model_name = args.model_name.split('.pth')[0]
    if args.model_name == 'RealESRGAN_x4plus':  # x4 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
    elif args.model_name == 'RealESRNet_x4plus':  # x4 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
    elif args.model_name == 'RealESRGAN_x4plus_anime_6B':  # x4 RRDBNet model with 6 blocks
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
    elif args.model_name == 'RealESRGAN_x2plus':  # x2 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        netscale = 2
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
    elif args.model_name == 'realesr-animevideov3':  # x4 VGG-style model (XS size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth']
    elif args.model_name == 'realesr-general-x4v3':  # x4 VGG-style model (S size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        netscale = 4
        file_url = [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
        ]
    else:
        # Without this branch an unknown name only fails later with an unbound-variable error.
        raise ValueError(f'Unknown model name: {args.model_name}')

    # ---------------------- determine model paths ---------------------- #
    model_path = os.path.join('weights', args.model_name + '.pth')
    if not os.path.isfile(model_path):
        root_dir = os.path.dirname(os.path.abspath(__file__))
        for url in file_url:
            # model_path is overwritten on each iteration; the last URL's file is the one used
            model_path = load_file_from_url(
                url=url, model_dir=os.path.join(root_dir, 'weights'), progress=True, file_name=None)

    # use dni to control the denoise strength
    dni_weight = None
    if args.model_name == 'realesr-general-x4v3' and args.denoise_strength != 1:
        wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
        model_path = [model_path, wdn_model_path]
        dni_weight = [args.denoise_strength, 1 - args.denoise_strength]

    # restorer
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=args.tile,
        tile_pad=args.tile_pad,
        pre_pad=args.pre_pad,
        half=not args.fp32,
        device=device,
    )

    if 'anime' in args.model_name and args.face_enhance:
        print('face_enhance is not supported in anime models, we turned this option off for you. '
              'if you insist on turning it on, please manually comment the relevant lines of code.')
        args.face_enhance = False

    if args.face_enhance:  # Use GFPGAN for face enhancement
        from gfpgan import GFPGANer
        face_enhancer = GFPGANer(
            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth',
            upscale=args.outscale,
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=upsampler)  # TODO support custom device
    else:
        face_enhancer = None

    reader = Reader(args, total_workers, worker_idx)
    audio = reader.get_audio()
    height, width = reader.get_resolution()
    fps = reader.get_fps()
    writer = Writer(args, audio, height, width, video_save_path, fps)

    # torch.cuda.synchronize raises when CUDA is unavailable, so only call it on GPU hosts.
    use_cuda = torch.cuda.is_available()

    pbar = tqdm(total=len(reader), unit='frame', desc='inference')
    while True:
        img = reader.get_frame()
        if img is None:
            break

        try:
            if args.face_enhance:
                _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
            else:
                output, _ = upsampler.enhance(img, outscale=args.outscale)
        except RuntimeError as error:
            # The failed frame is dropped and processing continues with the next one.
            print('Error', error)
            print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
        else:
            writer.write_frame(output)

        if use_cuda:
            torch.cuda.synchronize(device)
        pbar.update(1)
    pbar.close()

    reader.close()
    writer.close()
277
+
278
+
279
def run(args):
    """Run inference on ``args.input``, fanning out over GPUs when possible.

    Sets ``args.video_name`` from the input's base name. With
    ``extract_frame_first`` the video is first dumped to PNG frames and
    ``args.input`` is redirected to that folder. When more than one worker is
    available the input is processed in parallel and the per-worker videos are
    concatenated into ``<output>/<video_name>_<suffix>.mp4``.

    Args:
        args: Parsed options; reads ``input``, ``output``, ``suffix``,
            ``extract_frame_first``, ``num_process_per_gpu`` and ``ffmpeg_bin``.
    """
    args.video_name = osp.splitext(os.path.basename(args.input))[0]
    video_save_path = osp.join(args.output, f'{args.video_name}_{args.suffix}.mp4')

    if args.extract_frame_first:
        tmp_frames_folder = osp.join(args.output, f'{args.video_name}_inp_tmp_frames')
        os.makedirs(tmp_frames_folder, exist_ok=True)
        # Argument list (no shell) so paths with spaces work; honours --ffmpeg_bin like the other calls.
        subprocess.call([
            args.ffmpeg_bin, '-i', args.input, '-qscale:v', '1', '-qmin', '1', '-qmax', '1', '-vsync', '0',
            f'{tmp_frames_folder}/frame%08d.png'
        ])
        args.input = tmp_frames_folder

    num_gpus = torch.cuda.device_count()
    num_process = num_gpus * args.num_process_per_gpu
    if num_process <= 1:
        # Single worker, or no GPU at all (num_process == 0): run in this process.
        inference_video(args, video_save_path)
        return

    def report_failure(error):
        # apply_async otherwise discards worker exceptions silently.
        print(f'Worker failed: {error!r}')

    ctx = torch.multiprocessing.get_context('spawn')
    pool = ctx.Pool(num_process)
    os.makedirs(osp.join(args.output, f'{args.video_name}_out_tmp_videos'), exist_ok=True)
    pbar = tqdm(total=num_process, unit='sub_video', desc='inference')
    for i in range(num_process):
        sub_video_save_path = osp.join(args.output, f'{args.video_name}_out_tmp_videos', f'{i:03d}.mp4')
        pool.apply_async(
            inference_video,
            args=(args, sub_video_save_path, torch.device(i % num_gpus), num_process, i),
            callback=lambda arg: pbar.update(1),
            error_callback=report_failure)
    pool.close()
    pool.join()
    pbar.close()

    # combine sub videos
    # prepare vidlist.txt (entries are relative to the list file's own directory)
    vidlist_path = f'{args.output}/{args.video_name}_vidlist.txt'
    with open(vidlist_path, 'w') as f:
        for i in range(num_process):
            f.write(f'file \'{args.video_name}_out_tmp_videos/{i:03d}.mp4\'\n')

    # '-y' overwrites an existing output, matching the single-process path instead of prompting.
    cmd = [
        args.ffmpeg_bin, '-y', '-f', 'concat', '-safe', '0', '-i', vidlist_path, '-c', 'copy', f'{video_save_path}'
    ]
    print(' '.join(cmd))
    subprocess.call(cmd)
    shutil.rmtree(osp.join(args.output, f'{args.video_name}_out_tmp_videos'))
    if osp.exists(osp.join(args.output, f'{args.video_name}_inp_tmp_videos')):
        shutil.rmtree(osp.join(args.output, f'{args.video_name}_inp_tmp_videos'))
    os.remove(vidlist_path)
324
+
325
+
326
def main():
    """Inference demo for Real-ESRGAN.
    It mainly for restoring anime videos.

    Parses the command line, converts ``.flv`` inputs to ``.mp4`` first, runs
    the inference, and removes the temporary frame folder when
    ``--extract_frame_first`` was used.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', type=str, default='inputs', help='Input video, image or folder')
    parser.add_argument(
        '-n',
        '--model_name',
        type=str,
        default='realesr-animevideov3',
        help=('Model names: realesr-animevideov3 | RealESRGAN_x4plus_anime_6B | RealESRGAN_x4plus | RealESRNet_x4plus |'
              ' RealESRGAN_x2plus | realesr-general-x4v3. '
              'Default: realesr-animevideov3'))
    parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
    parser.add_argument(
        '-dn',
        '--denoise_strength',
        type=float,
        default=0.5,
        help=('Denoise strength. 0 for weak denoise (keep noise), 1 for strong denoise ability. '
              'Only used for the realesr-general-x4v3 model'))
    parser.add_argument('-s', '--outscale', type=float, default=4, help='The final upsampling scale of the image')
    parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored video')
    parser.add_argument('-t', '--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
    parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
    parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
    parser.add_argument('--face_enhance', action='store_true', help='Use GFPGAN to enhance face')
    parser.add_argument(
        '--fp32', action='store_true', help='Use fp32 precision during inference. Default: fp16 (half precision).')
    parser.add_argument('--fps', type=float, default=None, help='FPS of the output video')
    parser.add_argument('--ffmpeg_bin', type=str, default='ffmpeg', help='The path to ffmpeg')
    parser.add_argument('--extract_frame_first', action='store_true')
    parser.add_argument('--num_process_per_gpu', type=int, default=4)

    parser.add_argument(
        '--alpha_upsampler',
        type=str,
        default='realesrgan',
        help='The upsampler for the alpha channels. Options: realesrgan | bicubic')
    parser.add_argument(
        '--ext',
        type=str,
        default='auto',
        help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
    args = parser.parse_args()

    args.input = args.input.rstrip('/').rstrip('\\')
    os.makedirs(args.output, exist_ok=True)

    mime_type = mimetypes.guess_type(args.input)[0]
    is_video = mime_type is not None and mime_type.startswith('video')

    if is_video and args.input.endswith('.flv'):
        # Swap only the extension; str.replace would also rewrite '.flv' inside directory names.
        mp4_path = osp.splitext(args.input)[0] + '.mp4'
        # Argument list (no shell) so paths with spaces work; honours --ffmpeg_bin.
        subprocess.call([args.ffmpeg_bin, '-i', args.input, '-codec', 'copy', mp4_path])
        args.input = mp4_path

    if args.extract_frame_first and not is_video:
        # Frame extraction only makes sense for video inputs.
        args.extract_frame_first = False

    run(args)

    if args.extract_frame_first:
        tmp_frames_folder = osp.join(args.output, f'{args.video_name}_inp_tmp_frames')
        shutil.rmtree(tmp_frames_folder)


if __name__ == '__main__':
    main()