jounery-d committed on
Commit
0987da9
·
verified ·
1 Parent(s): 1e051ab

Upload run_video.py

Browse files
Files changed (1) hide show
  1. run_video.py +172 -0
run_video.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import glob
4
+ import time
5
+ import math
6
+ import argparse
7
+ import numpy as np
8
+ import axengine as axe
9
+ from tqdm import tqdm
10
+
11
def from_numpy(x):
    """Return ``x`` as a NumPy array.

    An ``np.ndarray`` (or subclass) instance is handed back untouched; any
    other value is converted with ``np.array``.
    """
    if isinstance(x, np.ndarray):
        return x
    return np.array(x)
13
+
14
class VideoTester():
    """Tile-based video upscaler driven by an ``axengine`` inference session.

    Every frame is zero-padded so both sides are multiples of ``tile``,
    surrounded by a ``tile_pad`` border, cut into overlapping crops of
    ``tile + 2 * tile_pad`` pixels, run through the model crop by crop and
    stitched back into one upscaled frame.
    """

    def __init__(self, scale, tile=108, tile_pad=10, model=None, source=None,
                 output_dir='results'):
        """Create the inference session and remember the run settings.

        Args:
            scale: Upscaling factor applied by the model.
            tile: Side length of the useful (non-overlapping) part of a crop.
            tile_pad: Border added on every side of a crop.
            model: Model path handed to ``axe.InferenceSession``.
            source: Path of the input video.
            output_dir: Folder the result video is written to. Defaults to
                ``'results'``, the location that used to be hard-coded.
        """
        self.scale = scale
        self.tile = tile
        self.tile_pad = tile_pad
        self.session = axe.InferenceSession(model)
        self.output_names = [x.name for x in self.session.get_outputs()]
        self.input_name = self.session.get_inputs()[0].name
        self.dir_demo = source
        self.output_dir = output_dir
        self.filename, _ = os.path.splitext(os.path.basename(self.dir_demo))

    def pre_process(self, img):
        """Pad a frame and convert it to a float32 ``(1, C, H, W)`` array.

        Args:
            img: ``(H, W, 3)`` frame; channels are reversed, so a BGR frame
                (as read by OpenCV in ``test``) becomes RGB.

        Returns:
            Array of shape ``(1, 3, H', W')`` scaled by 1/255, where ``H'``
            and ``W'`` are the sides rounded up to a multiple of ``tile``
            plus ``2 * tile_pad``.
        """
        h, w = img.shape[0:2]

        # ``-n % tile`` is the distance to the next multiple of ``tile``
        # (0 when ``n`` is already a multiple).
        extra_h = -h % self.tile
        extra_w = -w % self.tile

        # Divisibility padding (bottom/right) and the boundary border (all
        # sides) are both constant zeros, so one np.pad call covers both.
        border = self.tile_pad
        img = np.pad(
            img,
            ((border, extra_h + border), (border, extra_w + border), (0, 0)),
            'constant',
        )

        # to CHW-Batch format
        img = (img[..., [2, 1, 0]] / 255).astype(np.float32)
        img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)

        return img

    def tile_process(self, img, origin_shape, imgname=None):
        """Upscale a pre-processed frame tile by tile and merge the results.

        Args:
            img: ``(batch, channel, height, width)`` array as produced by
                ``pre_process`` (sides are ``k * tile + 2 * tile_pad``).
            origin_shape: ``(width, height)`` of the frame before padding.
            imgname: Unused; kept for interface compatibility.

        Returns:
            float32 ``(H * scale, W * scale, 3)`` array with the channel
            order reversed again. A tile whose inference raises
            ``RuntimeError`` is reported and left black.
        """
        batch, channel, height, width = img.shape
        tile = self.tile
        pad = self.tile_pad
        output_height = int(round(height * self.scale))
        output_width = int(round(width * self.scale))
        output_shape = (batch, channel, output_height, output_width)
        origin_w, origin_h = origin_shape[0:2]

        # start with black image; float32 matches the dtype returned below
        output = np.zeros(output_shape, dtype=np.float32)

        # The border is not tileable area: leave it out of the tile count,
        # otherwise a spurious, truncated extra tile appears as soon as
        # 2 * tile_pad >= tile.
        tiles_x = (width - 2 * pad) // tile
        tiles_y = (height - 2 * pad) // tile

        # Region of an upscaled crop that corresponds to its un-padded core.
        start_tile = int(round(pad * self.scale))
        end_tile = int(round(tile * self.scale)) + start_tile

        # loop over all tiles
        for y in range(tiles_y):
            for x in range(tiles_x):
                # input tile area on total image
                input_start_x = x * tile
                input_end_x = input_start_x + tile
                input_start_y = y * tile
                input_end_y = input_start_y + tile

                # The frame carries a ``pad`` border on every side, so this
                # slice is the tile plus ``pad`` pixels of context all round.
                input_tile = img[:, :,
                                 input_start_y:(input_end_y + 2 * pad),
                                 input_start_x:(input_end_x + 2 * pad)]

                # upscale tile
                try:
                    output_tile = self.session.run(
                        self.output_names, {self.input_name: input_tile})
                except RuntimeError as error:
                    print('Error', error)
                    # No result for this tile: keep it black instead of
                    # reusing a stale (or undefined) ``output_tile``.
                    continue

                # output tile area on total image
                output_start_x = int(round(input_start_x * self.scale))
                output_end_x = int(round(input_end_x * self.scale))
                output_start_y = int(round(input_start_y * self.scale))
                output_end_y = int(round(input_end_y * self.scale))

                output[:, :, output_start_y:output_end_y,
                       output_start_x:output_end_x] = \
                    output_tile[0][:, :, start_tile:end_tile, start_tile:end_tile]

        # remove extra tile_padding parts
        output = output[:, :,
                        :int(round(origin_h * self.scale)),
                        :int(round(origin_w * self.scale))].squeeze(0)
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)).astype(np.float32)

        return output

    def test(self):
        """Upscale the source video and write it as an XVID ``.avi`` file.

        The result is stored in ``self.output_dir`` (created if missing) as
        ``<source name>_x<scale>.avi``. At most the number of frames reported
        by the capture is read. Timing covers pre-processing and inference
        only, and is reported for the frames actually processed.

        Raises:
            IOError: if the source video cannot be opened.
        """
        vidcap = cv2.VideoCapture(self.dir_demo)
        if not vidcap.isOpened():
            vidcap.release()
            raise IOError('Cannot open video: {}'.format(self.dir_demo))

        vidwri = None
        try:
            total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
            vid_width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))
            vid_height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            # Same rounding as tile_process so the writer's frame size always
            # equals the size of the frames handed to it.
            out_size = (
                int(round(self.scale * vid_width)),
                int(round(self.scale * vid_height)),
            )

            # The writer cannot create folders itself.
            os.makedirs(self.output_dir, exist_ok=True)
            vidwri = cv2.VideoWriter(
                os.path.join(self.output_dir, ('{}_x{}.avi'.format(self.filename, self.scale))),
                cv2.VideoWriter_fourcc(*'XVID'),
                vidcap.get(cv2.CAP_PROP_FPS),
                out_size,
            )

            total_times = 0
            processed = 0
            tqdm_test = tqdm(range(total_frames), ncols=80)
            for _ in tqdm_test:
                success, frame = vidcap.read()
                if not success:
                    break
                start_time = time.time()

                frame = self.pre_process(frame)
                sr_image = self.tile_process(frame, (vid_width, vid_height), self.filename)

                end_time = time.time()
                total_times += end_time - start_time

                sr_image = np.clip(sr_image * 255, 0, 255).astype(np.uint8)
                vidwri.write(sr_image)
                processed += 1

            print('Total time: {:.3f} seconds for {} frames'.format(total_times, processed))
            # Guard against an empty or unreadable stream (no frames at all).
            if processed:
                print('Average time: {:.3f} seconds for each frame'.format(total_times / processed))
        finally:
            # Release both handles even when a frame fails mid-way.
            vidcap.release()
            if vidwri is not None:
                vidwri.release()
142
+
143
def main():
    """Inference video for Real-ESRGAN.

    Parses the command line, validates the tile geometry, the input file and
    the model path, creates the output folder and runs ``VideoTester`` on the
    input video.

    Raises:
        ValueError: if ``tile + 2 * tile_pad`` is not 128, if ``--input`` is
            not an existing file, or if ``--model`` was not given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', type=str, default='inputs', help='Input video or folder')
    parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
    parser.add_argument('-s', '--scale', type=float, default=2, help='The final upsampling scale of the video, [Option:2, 4]')
    parser.add_argument('-m', '--model', type=str, default=None, help='Model path. you need to specify it [Options: ]')
    parser.add_argument('-t', '--tile', type=int, default=108, help='Tile size, 0 for no tile during testing')
    # Help text now states the constraint that is actually enforced below.
    parser.add_argument('-p', '--tile_pad', type=int, default=10, help='Tile tile_padding, (tile + 2 * tile_pad must == 128.)')

    args = parser.parse_args()

    # shape check -- raised explicitly because ``assert`` statements are
    # stripped when Python runs with -O.
    if (args.tile + 2 * args.tile_pad) != 128:
        raise ValueError('the model input size: 128.')

    # input
    if not os.path.isfile(args.input):
        raise ValueError(f'--input {args.input} is not a valid file.')

    # model -- fail early with a clear message instead of handing ``None``
    # to the inference session.
    if args.model is None:
        raise ValueError('--model is required: pass the path of the model file to load.')

    # output
    os.makedirs(args.output, exist_ok=True)

    # test
    t = VideoTester(args.scale, args.tile, args.tile_pad, args.model, args.input)
    t.test()
169
+
170
+
171
# Script entry point: run the command-line interface only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()