|
|
import os
|
|
|
import cv2
|
|
|
import glob
|
|
|
import time
|
|
|
import math
|
|
|
import argparse
|
|
|
import numpy as np
|
|
|
import axengine as axe
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
def from_numpy(x):
    """Return *x* as a numpy array, leaving existing ndarrays untouched."""
    if isinstance(x, np.ndarray):
        return x
    return np.array(x)
|
|
|
|
|
|
class VideoTester():
    """Tile-based video super-resolution runner for an axengine model.

    The model expects fixed-size inputs of ``tile + 2 * tile_pad`` pixels per
    side (128 with the defaults); each frame is padded, split into tiles,
    upscaled tile by tile, stitched back together, and written to an output
    video.
    """

    def __init__(self, scale, tile=108, tile_pad=10, model=None, source=None,
                 output='results'):
        """
        Args:
            scale (float): upsampling factor applied to each frame.
            tile (int): side length of each processing tile, in pixels.
            tile_pad (int): context padding added around each tile.
            model (str): path to the axengine model file.
            source (str): path to the input video file.
            output (str): directory where the result video is written
                (defaults to 'results', matching the previous hard-coded
                behavior).
        """
        self.scale = scale
        self.tile = tile
        self.tile_pad = tile_pad
        self.output = output
        self.session = axe.InferenceSession(model)
        self.output_names = [x.name for x in self.session.get_outputs()]
        self.input_name = self.session.get_inputs()[0].name
        self.dir_demo = source
        self.filename, _ = os.path.splitext(os.path.basename(self.dir_demo))

    def pre_process(self, img):
        """Pad a HWC uint8 frame and convert it to a NCHW float32 tensor.

        The frame (BGR, as read by OpenCV) is zero-padded on the bottom/right
        so both sides are multiples of ``self.tile``, then zero-padded by
        ``self.tile_pad`` on every side so each extracted tile carries border
        context. Channels are reversed (BGR -> RGB) and values scaled to
        [0, 1].

        Returns:
            np.ndarray: tensor of shape (1, 3, H', W'), dtype float32.
        """
        h, w = img.shape[0:2]

        # Bottom/right padding so height and width divide evenly into tiles.
        pad_h = (self.tile - h % self.tile) % self.tile
        pad_w = (self.tile - w % self.tile) % self.tile
        img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), 'constant')

        # Context border shared by adjacent tiles (hides tile seams).
        img = np.pad(img, ((self.tile_pad, self.tile_pad),
                           (self.tile_pad, self.tile_pad), (0, 0)), 'constant')

        # BGR -> RGB, [0, 255] -> [0, 1], HWC -> NCHW with batch dim.
        img = (img[..., [2, 1, 0]] / 255).astype(np.float32)
        return np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)

    def tile_process(self, img, origin_shape, imgname=None):
        """It will first crop input images to tiles, and then process each tile.

        Finally, all the processed tiles are merged into one image.

        Args:
            img (np.ndarray): NCHW float32 tensor from :meth:`pre_process`.
            origin_shape (tuple): (width, height) of the unpadded frame.
            imgname (str): unused; kept for interface compatibility.

        Returns:
            np.ndarray: HWC float32 image (channels reversed back to BGR),
            cropped to the scaled original frame size.
        """
        batch, channel, height, width = img.shape
        output_height = int(round(height * self.scale))
        output_width = int(round(width * self.scale))
        output_shape = (batch, channel, output_height, output_width)
        origin_w, origin_h = origin_shape[0:2]

        # Accumulator for the stitched frame; failed tiles stay zero (black).
        output = np.zeros(output_shape)
        tiles_x = math.floor(width / self.tile)
        tiles_y = math.floor(height / self.tile)

        # The model output includes the scaled context border; crop it off.
        start_tile = int(round(self.tile_pad * self.scale))
        end_tile = int(round(self.tile * self.scale)) + start_tile

        for y in range(tiles_y):
            for x in range(tiles_x):
                ofs_x = x * self.tile
                ofs_y = y * self.tile

                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile, height)

                # Slice spans tile + 2*tile_pad pixels: pre_process added a
                # tile_pad border, so indexing from the unshifted offset to
                # end + 2*tile_pad covers the tile plus context on both sides.
                input_tile = img[:, :, input_start_y:(input_end_y + 2 * self.tile_pad),
                                 input_start_x:(input_end_x + 2 * self.tile_pad)]

                try:
                    output_tile = self.session.run(self.output_names, {self.input_name: input_tile})
                except RuntimeError as error:
                    # Best effort: report and skip this tile (leaving it
                    # black) instead of falling through to an undefined
                    # `output_tile`, which previously raised NameError.
                    print('Error', error)
                    continue

                output_start_x = int(round(input_start_x * self.scale))
                output_end_x = int(round(input_end_x * self.scale))
                output_start_y = int(round(input_start_y * self.scale))
                output_end_y = int(round(input_end_y * self.scale))

                output[:, :, output_start_y:output_end_y,
                       output_start_x:output_end_x] = output_tile[0][:, :, start_tile:end_tile,
                                                                     start_tile:end_tile]

        # Crop the tile-alignment padding and convert NCHW -> HWC, reversing
        # channel order back for cv2's writer.
        output = output[:, :, :int(round(origin_h * self.scale)),
                        :int(round(origin_w * self.scale))].squeeze(0)
        return np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)).astype(np.float32)

    def test(self):
        """Upscale the source video frame by frame and write the result."""
        vidcap = cv2.VideoCapture(self.dir_demo)
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        vid_width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))
        vid_height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Honor the configured output directory (previously hard-coded to
        # 'results' even though main() exposes --output) and make sure it
        # exists before opening the writer.
        os.makedirs(self.output, exist_ok=True)
        vidwri = cv2.VideoWriter(
            os.path.join(self.output, '{}_x{}.avi'.format(self.filename, self.scale)),
            cv2.VideoWriter_fourcc(*'XVID'),
            vidcap.get(cv2.CAP_PROP_FPS),
            (
                int(self.scale * vid_width),
                int(self.scale * vid_height)
            )
        )

        total_times = 0
        processed = 0
        for _ in tqdm(range(total_frames), ncols=80):
            success, frame = vidcap.read()
            if not success:
                break
            start_time = time.time()

            frame = self.pre_process(frame)
            sr_image = self.tile_process(frame, (vid_width, vid_height), self.filename)

            total_times += time.time() - start_time
            processed += 1

            sr_image = np.clip(sr_image * 255, 0, 255).astype(np.uint8)
            vidwri.write(sr_image)

        print('Total time: {:.3f} seconds for {} frames'.format(total_times, total_frames))
        # Divide by frames actually processed; also avoids ZeroDivisionError
        # on an empty or unreadable video.
        if processed:
            print('Average time: {:.3f} seconds for each frame'.format(total_times / processed))

        vidcap.release()
        vidwri.release()
|
|
|
|
|
|
def main():
    """Inference video for Real-ESRGAN.

    Parses CLI arguments, validates the tile configuration against the
    model's fixed 128-pixel input, and runs :class:`VideoTester` on the
    input video.

    Raises:
        ValueError: if ``tile + 2 * tile_pad != 128`` or ``--input`` is not
            an existing file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', type=str, default='inputs', help='Input video or folder')
    parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
    parser.add_argument('-s', '--scale', type=float, default=2, help='The final upsampling scale of the video, [Option:2, 4]')
    parser.add_argument('-m', '--model', type=str, default=None, help='Model path. you need to specify it [Options: ]')
    parser.add_argument('-t', '--tile', type=int, default=108, help='Tile size, 0 for no tile during testing')
    parser.add_argument('-p', '--tile_pad', type=int, default=10, help='Tile tile_padding, (tile + tile_pad must == 128.)')

    args = parser.parse_args()

    # Validate with a real exception: `assert` is stripped under `python -O`.
    if args.tile + 2 * args.tile_pad != 128:
        raise ValueError('the model input size: 128.')

    if not os.path.isfile(args.input):
        raise ValueError(f'--input {args.input} is not a valid file.')

    os.makedirs(args.output, exist_ok=True)

    t = VideoTester(args.scale, args.tile, args.tile_pad, args.model, args.input)
    t.test()
|
|
|
|
|
|
|
|
|
# Script entry point: run the CLI inference when executed directly.
if __name__ == '__main__':
    main()
|
|
|
|