jounery-d committed on
Commit
fdaf566
·
verified ·
1 Parent(s): 98c15f6

Update run_axmodel.py

Browse files
Files changed (1) hide show
  1. run_axmodel.py +199 -199
run_axmodel.py CHANGED
@@ -1,199 +1,199 @@
1
- import argparse
2
- import cv2
3
- import glob
4
- import os
5
- import math
6
- import numpy as np
7
- import axengine as axe
8
-
9
- def pre_process(img, tile_size=108, tile_pad=10):
10
- """Pre-process, such as pre-pad and mod pad, so that the images can be divisible
11
- """
12
- # mod pad for divisible borders
13
- pad_h, pad_w = 0, 0
14
- h, w = img.shape[0:2]
15
-
16
- if h % tile_size != 0:
17
- pad_h = (tile_size - h % tile_size)
18
- if w % tile_size != 0:
19
- pad_w = (tile_size - w % tile_size)
20
- img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), 'constant') #mode='reflect')
21
-
22
- # boundary pad
23
- img = np.pad(img, ((tile_pad, tile_pad), (tile_pad, tile_pad), (0, 0)), 'constant')
24
-
25
- # to CHW-Batch format
26
- img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
27
-
28
- return img
29
-
30
- def tile_process(img, origin_shape, model, scale=2, tile_size=108, tile_pad=10, imgname=None):
31
- """It will first crop input images to tiles, and then process each tile.
32
- Finally, all the processed tiles are merged into one images.
33
- """
34
-
35
- # determine model paths
36
- if not os.path.exists(model):
37
- raise ValueError(f'Model {model} does not exist.')
38
-
39
- session = axe.InferenceSession(model)
40
- input_name = session.get_inputs()[0].name
41
- output_names = [x.name for x in session.get_outputs()]
42
-
43
- # tile
44
- batch, channel, height, width = img.shape
45
- output_height = int(round(height * scale))
46
- output_width = int(round(width * scale))
47
- output_shape = (batch, channel, output_height, output_width)
48
- origin_h, origin_w = origin_shape[0:2]
49
-
50
- # start with black image
51
- output = np.zeros(output_shape)
52
- tiles_x = math.floor(width / tile_size)
53
- tiles_y = math.floor(height / tile_size)
54
- print(f'Tile {tiles_x} x {tiles_y} for image {imgname}')
55
-
56
- # loop over all tiles
57
- for y in range(tiles_y):
58
- for x in range(tiles_x):
59
- # extract tile from input image
60
- ofs_x = x * tile_size
61
- ofs_y = y * tile_size
62
- # input tile area on total image
63
- input_start_x = ofs_x
64
- input_end_x = min(ofs_x + tile_size, width)
65
- input_start_y = ofs_y
66
- input_end_y = min(ofs_y + tile_size, height)
67
-
68
- # input tile dimensions
69
- input_tile = img[:, :, input_start_y:(input_end_y+2*tile_pad),
70
- input_start_x:(input_end_x+2*tile_pad)]
71
-
72
- # upscale tile
73
- try:
74
- output_tile = session.run(output_names, {input_name: input_tile})
75
- except RuntimeError as error:
76
- print('Error', error)
77
- #print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
78
-
79
- # output tile area on total image
80
- output_start_x = int(round(input_start_x * scale))
81
- output_end_x = int(round(input_end_x * scale))
82
- output_start_y = int(round(input_start_y * scale))
83
- output_end_y = int(round(input_end_y * scale))
84
-
85
- start_tile = int(round(tile_pad * scale))
86
- end_tile = int(round(tile_size * scale)) + start_tile
87
-
88
- output[:, :, output_start_y:output_end_y,
89
- output_start_x:output_end_x] = output_tile[0][:, :, start_tile:end_tile, start_tile:end_tile]
90
-
91
- # remove extra padding parts
92
- output = output[:, :, :int(round(origin_h * scale)), :int(round(origin_w * scale))].squeeze(0)
93
- output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)).astype(np.float32)
94
-
95
- return output
96
-
97
- def main():
98
- """Inference demo for Real-ESRGAN.
99
- """
100
- parser = argparse.ArgumentParser()
101
- parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
102
- parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
103
- parser.add_argument('-s', '--outscale', type=float, default=2, help='The final upsampling scale of the image, [Option:2, 4]')
104
- parser.add_argument(
105
- '--model_path', type=str, default=None, help='Model path. you need to specify it [Options: ]')
106
- parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
107
- parser.add_argument('-t', '--tile', type=int, default=108, help='Tile size, 0 for no tile during testing')
108
- parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding, (tile + tile_pad must == 128.)')
109
- parser.add_argument(
110
- '--ext',
111
- type=str,
112
- default='auto',
113
- help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
114
-
115
- args = parser.parse_args()
116
-
117
- # shape check
118
- assert (args.tile + 2*args.tile_pad) == 128, 'the model input size: 128.'
119
-
120
- # input
121
- if os.path.isfile(args.input):
122
- paths = [args.input]
123
- else:
124
- paths = sorted(glob.glob(os.path.join(args.input, '*')))
125
-
126
- # output
127
- os.makedirs(args.output, exist_ok=True)
128
-
129
- for idx, path in enumerate(paths):
130
- imgname, extension = os.path.splitext(os.path.basename(path))
131
- print('Testing', idx, imgname)
132
- if extension not in ['.jpg', '.jpeg', '.png', '.tif', '.tiff', '.bmp', '.webp']:
133
- continue
134
-
135
- img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
136
- if img is None:
137
- print('Error loading image')
138
- continue
139
- img = img.astype(np.float32)
140
- if np.max(img) > 256: # 16-bit image
141
- max_range = 65535
142
- print('\tInput is a 16-bit image')
143
- else:
144
- max_range = 255
145
- img = img / max_range
146
- if len(img.shape) == 2: # gray image
147
- img_mode = 'L'
148
- img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
149
- elif img.shape[2] == 4: # RGBA image with alpha channel
150
- img_mode = 'RGBA'
151
- alpha = img[:, :, 3]
152
- img = img[:, :, 0:3]
153
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
154
- else:
155
- img_mode = 'RGB'
156
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
157
-
158
- # pre-process
159
- origin_shape = img.shape
160
- img = pre_process(img, args.tile)
161
-
162
- # tile process
163
- try:
164
- output_img = tile_process(img, origin_shape, args.model_path, args.outscale, args.tile, args.tile_pad, imgname)
165
- except RuntimeError as error:
166
- print('Error', error)
167
- print('If you encounter out of memory, try to set --tile with a smaller number.')
168
-
169
- if img_mode == 'L':
170
- output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
171
- if img_mode == 'RGBA':
172
- h, w = alpha.shape[0:2]
173
- output_alpha = cv2.resize(
174
- alpha,
175
- (int(round(w * args.outscale)),
176
- int(round(h * args.outscale))),
177
- interpolation=cv2.INTER_LINEAR
178
- )
179
- output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
180
- output_img[:, :, 3] = output_alpha
181
-
182
- if max_range == 65535: # 16-bit image
183
- output = np.clip((output_img * 65535.0), 0, 65535).astype(np.uint16)
184
- else:
185
- output = np.clip((output_img * 255.0), 0, 255).round().astype(np.uint8)
186
-
187
- if args.ext == 'auto':
188
- extension = extension[1:]
189
- else:
190
- extension = args.ext
191
-
192
- if args.suffix == '':
193
- save_path = os.path.join(args.output, f'{imgname}.{extension}')
194
- else:
195
- save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{extension}')
196
- cv2.imwrite(save_path, output)
197
-
198
- if __name__ == '__main__':
199
- main()
 
1
+ import argparse
2
+ import cv2
3
+ import glob
4
+ import os
5
+ import math
6
+ import numpy as np
7
+ import axengine as axe
8
+
9
def pre_process(img, tile_size=108, tile_pad=10):
    """Zero-pad an HWC image for tiling and convert it to a 1xCxHxW batch.

    The bottom and right edges are padded with zeros until the height and
    width are multiples of ``tile_size``. A zero border of ``tile_pad``
    pixels is then added on all four sides, and the array is transposed to
    channel-first order with a leading batch axis.

    Args:
        img: array of shape (H, W, C).
        tile_size: tile edge length in pixels; must be positive.
        tile_pad: border width added on every side; must be non-negative.

    Returns:
        Array of shape (1, C, H' + 2 * tile_pad, W' + 2 * tile_pad), where
        H' and W' are H and W rounded up to a multiple of ``tile_size``.
        The dtype of ``img`` is preserved.

    Raises:
        ValueError: if ``tile_size`` is not positive, ``tile_pad`` is
            negative, or ``img`` is not 3-dimensional.
    """
    if tile_size <= 0:
        raise ValueError(f'tile_size must be positive, got {tile_size}.')
    if tile_pad < 0:
        raise ValueError(f'tile_pad must be non-negative, got {tile_pad}.')
    if len(img.shape) != 3:
        raise ValueError(f'Expected an HWC image, got shape {img.shape}.')

    h, w = img.shape[0:2]

    # Amount needed to reach the next multiple of tile_size (0 if already one).
    pad_h = -h % tile_size
    pad_w = -w % tile_size

    # Both paddings are constant zeros, so the "mod pad" (bottom/right) and
    # the boundary pad (all sides) are applied in a single np.pad call.
    img = np.pad(
        img,
        ((tile_pad, pad_h + tile_pad), (tile_pad, pad_w + tile_pad), (0, 0)),
        'constant',
    )

    # to CHW-Batch format
    return np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
29
+
30
def tile_process(img, origin_shape, model, scale=2, tile_size=108, tile_pad=10, imgname=None):
    """Upscale a padded image tile by tile and stitch the result together.

    The input is cut into windows of ``tile_size + 2 * tile_pad`` pixels that
    advance by ``tile_size``; each window is run through the model, and the
    central ``tile_size * scale`` region of the model output is copied into
    the output canvas.

    Args:
        img: array of shape (1, C, H, W) as produced by ``pre_process`` with
            the same ``tile_size`` and ``tile_pad`` (H and W include the
            ``tile_pad`` border on every side).
        origin_shape: shape of the image before padding; only the first two
            entries (height, width) are used to crop the result.
        model: path to the model file loaded with ``axe.InferenceSession``.
        scale: upscaling factor of the model.
        tile_size: edge length of the region each tile contributes.
        tile_pad: context border around each tile.
        imgname: image name, used only in log/error messages.

    Returns:
        float32 array of shape (origin_h * scale, origin_w * scale, C) with
        the first three channels in reversed order relative to the model
        output.

    Raises:
        ValueError: if the model path is missing or does not exist, or if
            ``tile_size`` / ``tile_pad`` are out of range.
        RuntimeError: if inference fails on any tile.
    """
    # determine model paths (None is the CLI default and must not reach os.path)
    if model is None or not os.path.exists(model):
        raise ValueError(f'Model {model} does not exist.')
    if tile_size <= 0:
        raise ValueError(f'tile_size must be positive, got {tile_size}.')
    if tile_pad < 0:
        raise ValueError(f'tile_pad must be non-negative, got {tile_pad}.')

    session = axe.InferenceSession(model)
    input_name = session.get_inputs()[0].name
    output_names = [x.name for x in session.get_outputs()]

    batch, channel, height, width = img.shape
    origin_h, origin_w = origin_shape[0:2]

    # start with black image; float32 matches the dtype that is returned
    output = np.zeros(
        (batch, channel, int(round(height * scale)), int(round(width * scale))),
        dtype=np.float32,
    )

    # Count tiles on the area without the tile_pad border, so the count stays
    # correct even when 2 * tile_pad >= tile_size.
    tiles_x = (width - 2 * tile_pad) // tile_size
    tiles_y = (height - 2 * tile_pad) // tile_size
    print(f'Tile {tiles_x} x {tiles_y} for image {imgname}')

    # Loop-invariant geometry.
    window = tile_size + 2 * tile_pad
    start_tile = int(round(tile_pad * scale))
    end_tile = int(round(tile_size * scale)) + start_tile

    # loop over all tiles
    for y in range(tiles_y):
        for x in range(tiles_x):
            # Top-left corner of the window in the padded image. Because of
            # the border, this is also the tile's offset in the unpadded image.
            input_start_x = x * tile_size
            input_start_y = y * tile_size
            input_end_x = input_start_x + tile_size
            input_end_y = input_start_y + tile_size

            input_tile = img[:, :, input_start_y:input_start_y + window,
                             input_start_x:input_start_x + window]

            # upscale tile; a failed tile must abort the image, otherwise a
            # stale or undefined result would be stitched into the output
            try:
                output_tile = session.run(output_names, {input_name: input_tile})
            except RuntimeError as error:
                raise RuntimeError(
                    f'Inference failed on tile ({x}, {y}) of image {imgname}: {error}'
                ) from error

            # output tile area on total image
            output_start_x = int(round(input_start_x * scale))
            output_end_x = int(round(input_end_x * scale))
            output_start_y = int(round(input_start_y * scale))
            output_end_y = int(round(input_end_y * scale))

            # assumes the first model output is (1, C, window * scale,
            # window * scale) -- TODO confirm against the model
            output[:, :, output_start_y:output_end_y,
                   output_start_x:output_end_x] = output_tile[0][:, :, start_tile:end_tile,
                                                                 start_tile:end_tile]

    # remove extra padding parts
    output = output[:, :, :int(round(origin_h * scale)), :int(round(origin_w * scale))].squeeze(0)
    # reverse the channel order and convert CHW -> HWC
    output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)).astype(np.float32)

    return output
96
+
97
def main():
    """Inference demo for Real-ESRGAN.

    Parses the command line, then for every supported image found at
    ``--input`` (a single file or a folder): loads it with OpenCV, normalises
    it to [0, 1], upscales it tile by tile with the model at ``--model_path``
    and writes the result into ``--output``. Images that cannot be loaded or
    whose inference fails are reported and skipped.
    """
    model_input_size = 128
    image_extensions = ('.jpg', '.jpeg', '.png', '.tif', '.tiff', '.bmp', '.webp')

    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
    parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
    parser.add_argument('-s', '--outscale', type=float, default=2, help='The final upsampling scale of the image, [Option:2, 4]')
    parser.add_argument(
        '--model_path', type=str, default=None, help='Model path. you need to specify it [Options: ]')
    parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
    parser.add_argument('-t', '--tile', type=int, default=108, help='Tile size, must be positive')
    parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding, (tile + 2 * tile_pad must == 128.)')
    parser.add_argument(
        '--ext',
        type=str,
        default='auto',
        help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')

    args = parser.parse_args()

    # Argument checks use parser.error rather than assert, because asserts
    # are stripped when Python runs with -O.
    if args.tile <= 0 or args.tile_pad < 0:
        parser.error('--tile must be positive and --tile_pad must be non-negative.')
    if args.tile + 2 * args.tile_pad != model_input_size:
        parser.error(
            f'--tile + 2 * --tile_pad must equal the model input size {model_input_size}, '
            f'got {args.tile + 2 * args.tile_pad}.')
    if not args.model_path or not os.path.exists(args.model_path):
        parser.error(f'Model {args.model_path} does not exist.')

    # input
    if os.path.isfile(args.input):
        paths = [args.input]
    else:
        paths = sorted(glob.glob(os.path.join(args.input, '*')))

    # output
    os.makedirs(args.output, exist_ok=True)

    for idx, path in enumerate(paths):
        imgname, extension = os.path.splitext(os.path.basename(path))
        print('Testing', idx, imgname)
        # case-insensitive so that e.g. '.JPG' is not silently skipped
        if extension.lower() not in image_extensions:
            continue

        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            print('Error loading image')
            continue

        # Check the dtype as well as the value range: a dark 16-bit image can
        # have every sample <= 256 and would be missed by the range test alone.
        is_uint16 = img.dtype == np.uint16
        img = img.astype(np.float32)
        if is_uint16 or np.max(img) > 256:  # 16-bit image
            max_range = 65535
            print('\tInput is a 16-bit image')
        else:
            max_range = 255
        img = img / max_range

        if len(img.shape) == 2:  # gray image
            img_mode = 'L'
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img_mode = 'RGBA'
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        else:
            img_mode = 'RGB'
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # pre-process; tile_pad must match the value given to tile_process
        origin_shape = img.shape
        img = pre_process(img, args.tile, args.tile_pad)

        # tile process
        try:
            output_img = tile_process(img, origin_shape, args.model_path, args.outscale, args.tile, args.tile_pad, imgname)
        except RuntimeError as error:
            print('Error', error)
            print('If you encounter out of memory, try to set --tile with a smaller number.')
            # no valid result for this image, so there is nothing to save
            continue

        if img_mode == 'L':
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
        if img_mode == 'RGBA':
            # the alpha channel is not run through the model; it is resized
            h, w = alpha.shape[0:2]
            output_alpha = cv2.resize(
                alpha,
                (int(round(w * args.outscale)),
                 int(round(h * args.outscale))),
                interpolation=cv2.INTER_LINEAR
            )
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha

        # NOTE: astype drops the fractional part (no rounding) in both branches.
        if max_range == 65535:  # 16-bit image
            output = np.clip((output_img * 65535.0), 0, 65535).astype(np.uint16)
        else:
            output = np.clip((output_img * 255.0), 0, 255).astype(np.uint8)

        if args.ext == 'auto':
            out_extension = extension[1:]
        else:
            out_extension = args.ext

        if args.suffix == '':
            save_path = os.path.join(args.output, f'{imgname}.{out_extension}')
        else:
            save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{out_extension}')
        if not cv2.imwrite(save_path, output):
            print(f'Error saving image {save_path}')


if __name__ == '__main__':
    main()