import os
import glob
import time
import math
import argparse

import cv2
import numpy as np
from tqdm import tqdm

import axengine as axe


def from_numpy(x):
    """Return *x* unchanged if it is already an ndarray, else wrap it in one."""
    return x if isinstance(x, np.ndarray) else np.array(x)


class VideoTester:
    """Tile-based video super-resolution runner for an axengine model.

    The model takes fixed-size inputs of ``tile + 2 * tile_pad`` pixels per
    side (128 for the shipped models). Each frame is padded, split into
    overlapping tiles, upscaled tile by tile, and stitched back together.
    """

    def __init__(self, scale, tile=108, tile_pad=10, model=None, source=None,
                 output='results'):
        """Load the model and probe the input video.

        Args:
            scale: upsampling factor of the model output.
            tile: tile size in input pixels.
            tile_pad: halo padding per side; ``tile + 2 * tile_pad`` must
                match the model's fixed input size.
            model: path to the compiled axengine model.
            source: path to the input video file.
            output: directory the upscaled video is written into
                (defaults to 'results' for backward compatibility).
        """
        self.scale = scale
        self.tile = tile
        self.tile_pad = tile_pad
        self.dir_out = output
        self.session = axe.InferenceSession(model)
        self.output_names = [out.name for out in self.session.get_outputs()]
        self.input_name = self.session.get_inputs()[0].name
        self.dir_demo = source
        self.filename, _ = os.path.splitext(os.path.basename(self.dir_demo))

    def pre_process(self, img):
        """Pad a BGR uint8 frame and convert it to a 1xCxHxW float32 tensor.

        Padding happens twice: bottom/right zero-padding so height and width
        become multiples of ``self.tile``, then a ``tile_pad`` zero border on
        all sides so every tile crop carries halo context for the model.
        """
        h, w = img.shape[:2]
        pad_h = (self.tile - h % self.tile) % self.tile
        pad_w = (self.tile - w % self.tile) % self.tile
        img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), 'constant')
        # Frame-level halo; tile crops below rely on this shifted origin.
        img = np.pad(img,
                     ((self.tile_pad, self.tile_pad),
                      (self.tile_pad, self.tile_pad),
                      (0, 0)), 'constant')
        # BGR -> RGB, [0, 255] -> [0, 1], HWC -> 1xCxHxW.
        img = (img[..., [2, 1, 0]] / 255).astype(np.float32)
        return np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)

    def tile_process(self, img, origin_shape, imgname=None):
        """Upscale a padded 1xCxHxW tensor tile by tile.

        Args:
            img: output of :meth:`pre_process`.
            origin_shape: (width, height) of the original frame, used to crop
                the alignment padding away after upscaling.
            imgname: optional label, kept for logging/debugging.

        Returns:
            HxWxC float32 BGR image (values in [0, 1]) at ``scale`` times the
            original resolution.

        Raises:
            RuntimeError: if inference fails on any tile.
        """
        batch, channel, height, width = img.shape
        output_height = int(round(height * self.scale))
        output_width = int(round(width * self.scale))
        # Start from a black canvas and paste upscaled tiles into it.
        output = np.zeros((batch, channel, output_height, output_width))

        origin_w, origin_h = origin_shape[:2]
        tiles_x = math.floor(width / self.tile)
        tiles_y = math.floor(height / self.tile)

        # Window of each model output that excludes the upscaled halo.
        start_tile = int(round(self.tile_pad * self.scale))
        end_tile = int(round(self.tile * self.scale)) + start_tile

        for y in range(tiles_y):
            for x in range(tiles_x):
                # Input tile area on the tile-aligned (pre-halo) image.
                input_start_x = x * self.tile
                input_start_y = y * self.tile
                input_end_x = min(input_start_x + self.tile, width)
                input_end_y = min(input_start_y + self.tile, height)

                # Widened by 2*tile_pad: the frame-level halo padding shifts
                # coordinates, so this crop covers tile + halo on each side.
                input_tile = img[:, :,
                                 input_start_y:input_end_y + 2 * self.tile_pad,
                                 input_start_x:input_end_x + 2 * self.tile_pad]

                try:
                    output_tile = self.session.run(
                        self.output_names, {self.input_name: input_tile})
                except RuntimeError as error:
                    # Re-raise instead of continuing: the original code fell
                    # through and used an undefined/stale output_tile, which
                    # raised NameError or silently corrupted the frame.
                    raise RuntimeError(
                        'inference failed on tile ({}, {})'.format(x, y)
                    ) from error

                # Output tile area on the full upscaled image.
                output_start_x = int(round(input_start_x * self.scale))
                output_end_x = int(round(input_end_x * self.scale))
                output_start_y = int(round(input_start_y * self.scale))
                output_end_y = int(round(input_end_y * self.scale))
                output[:, :,
                       output_start_y:output_end_y,
                       output_start_x:output_end_x] = output_tile[0][
                           :, :, start_tile:end_tile, start_tile:end_tile]

        # Crop away tile-alignment padding, drop the batch dim,
        # then RGB -> BGR and CHW -> HWC for cv2.
        output = output[:, :,
                        :int(round(origin_h * self.scale)),
                        :int(round(origin_w * self.scale))].squeeze(0)
        return np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)).astype(np.float32)

    def test(self):
        """Upscale the source video frame by frame and write an XVID .avi."""
        vidcap = cv2.VideoCapture(self.dir_demo)
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        vid_width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))
        vid_height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        vidwri = cv2.VideoWriter(
            os.path.join(self.dir_out,
                         '{}_x{}.avi'.format(self.filename, self.scale)),
            cv2.VideoWriter_fourcc(*'XVID'),
            vidcap.get(cv2.CAP_PROP_FPS),
            (int(self.scale * vid_width), int(self.scale * vid_height)))

        total_times = 0.0
        processed = 0  # frames actually decoded; metadata may over-report
        try:
            for _ in tqdm(range(total_frames), ncols=80):
                success, frame = vidcap.read()
                if not success:
                    break
                start_time = time.time()
                frame = self.pre_process(frame)
                sr_image = self.tile_process(
                    frame, (vid_width, vid_height), self.filename)
                total_times += time.time() - start_time
                processed += 1
                sr_image = np.clip(sr_image * 255, 0, 255).astype(np.uint8)
                vidwri.write(sr_image)
        finally:
            # Release even on error so the partial output file is flushed.
            vidcap.release()
            vidwri.release()

        if processed:
            print('Total time: {:.3f} seconds for {} frames'.format(
                total_times, processed))
            print('Average time: {:.3f} seconds for each frame'.format(
                total_times / processed))
        else:
            print('No frames decoded from {}'.format(self.dir_demo))


def main():
    """Command-line entry point: inference video for Real-ESRGAN."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', type=str, default='inputs',
                        help='Input video or folder')
    parser.add_argument('-o', '--output', type=str, default='results',
                        help='Output folder')
    parser.add_argument('-s', '--scale', type=float, default=2,
                        help='The final upsampling scale of the video, [Option:2, 4]')
    parser.add_argument('-m', '--model', type=str, default=None,
                        help='Model path. you need to specify it [Options: ]')
    parser.add_argument('-t', '--tile', type=int, default=108,
                        help='Tile size, 0 for no tile during testing')
    parser.add_argument('-p', '--tile_pad', type=int, default=10,
                        help='Tile tile_padding, (tile + tile_pad must == 128.)')
    args = parser.parse_args()

    # Shape check: the compiled model has a fixed 128x128 input.
    # (parser.error instead of assert -- asserts vanish under `python -O`.)
    if args.tile + 2 * args.tile_pad != 128:
        parser.error('the model input size: 128.')

    # Input must be a single video file.
    if not os.path.isfile(args.input):
        raise ValueError(f'--input {args.input} is not a valid file.')

    # Output directory is created here and passed on, so the writer targets
    # the directory the user asked for (previously 'results' was hardcoded).
    os.makedirs(args.output, exist_ok=True)

    t = VideoTester(args.scale, args.tile, args.tile_pad,
                    args.model, args.input, args.output)
    t.test()


if __name__ == '__main__':
    main()