jounery-d committed
Commit cfa21ff · verified · 1 parent: 0e336ee

first commit

.gitattributes CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ model/realesrgan-x2.axmodel filter=lfs diff=lfs merge=lfs -text
+ model/realesrgan-x4.axmodel filter=lfs diff=lfs merge=lfs -text
+ pics/00003.png filter=lfs diff=lfs merge=lfs -text
+ pics/children-alpha.png filter=lfs diff=lfs merge=lfs -text
+ pics/OST_009.png filter=lfs diff=lfs merge=lfs -text
+ pics/tree_alpha_16bit.png filter=lfs diff=lfs merge=lfs -text
+ results/1.png filter=lfs diff=lfs merge=lfs -text
+ results/2.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,120 @@
+ ---
+ license: mit
+ language:
+ - en
+ base_model:
+ - Real-ESRGAN
+ pipeline_tag: image-to-image
+ tags:
+ - Image
+ - SuperResolution
+ ---
+
+ # Real-ESRGAN
+
+ This version of Real-ESRGAN has been converted to run on the Axera NPU using **w8a8** quantization.
+
+ Compatible with Pulsar2 version: 4.2
+
+ ## Conversion tool links
+
+ If you are interested in model conversion, you can export an axmodel yourself using:
+
+ - [the AXera Platform samples repo](https://github.com/AXERA-TECH/ax-samples), which contains a detailed guide
+
+ - [the Pulsar2 documentation on converting ONNX models to axmodel](https://pulsar2-docs.readthedocs.io/en/latest/pulsar2/introduction.html)
+
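+ The quantization settings used for this conversion are in `build_config.json` (included in this repository): w8a8 with MinMax calibration over a 10-sample `npy.zip` dataset in Numpy format. How that archive is built depends on your export pipeline; below is a minimal sketch, assuming the calibration tensors use the same preprocessing as `run_onnx.py` (float32 RGB in [0, 1], NCHW, 128x128 tiles). The folder and file names are placeholders.
+
+ ```python
+ # Hypothetical helper: pack ten preprocessed 128x128 tiles from ./pics into npy.zip
+ # for Pulsar2 calibration. Whether Pulsar2 expects the leading batch dimension should
+ # be checked against the Pulsar2 docs; this sketch keeps it (1x3x128x128).
+ import glob
+ import os
+ import zipfile
+
+ import cv2
+ import numpy as np
+
+ os.makedirs('calib', exist_ok=True)
+ with zipfile.ZipFile('npy.zip', 'w') as zf:
+     for i, path in enumerate(sorted(glob.glob('pics/*'))[:10]):
+         img = cv2.imread(path)                                     # HWC, BGR, uint8
+         if img is None:
+             continue
+         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
+         tile = img[:128, :128]                                     # assumes images are >= 128x128
+         tensor = np.expand_dims(np.transpose(tile, (2, 0, 1)), 0)  # 1x3x128x128
+         npy_path = os.path.join('calib', f'{i}.npy')
+         np.save(npy_path, tensor)
+         zf.write(npy_path, arcname=f'{i}.npy')
+ ```
+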
+ ## Supported platforms
+
+ - AX650
+   - [M4N-Dock (爱芯派Pro)](https://wiki.sipeed.com/hardware/zh/maixIV/m4ndock/m4ndock.html)
+   - [M.2 Accelerator card](https://axcl-docs.readthedocs.io/zh-cn/latest/doc_guide_hardware.html)
+ - AX630C
+   - [爱芯派2](https://axera-pi-2-docs-cn.readthedocs.io/zh-cn/latest/index.html)
+   - [Module-LLM](https://docs.m5stack.com/zh_CN/module/Module-LLM)
+   - [LLM630 Compute Kit](https://docs.m5stack.com/zh_CN/core/LLM630%20Compute%20Kit)
+
+ |Chip|Model|Inference time|
+ |--|--|--|
+ |AX650|realesrgan-x2|15.6 ms|
+ |AX650|realesrgan-x4|62.1 ms|
+
+ ## How to use
+
+ Download all files from this repository to the device.
+
+ ```
+ root@ax650:~/realesrgan# tree
+ .
+ |-- model
+ |   |-- realesrgan-x2.axmodel
+ |   `-- realesrgan-x4.axmodel
+ |-- run_onnx.py
+ |-- run_axmodel.py
+ |-- build_config.json
+ `-- requirements.txt
+ ```
+
+ ### Inference
+
+ Input data: the sample images under `./pics` (e.g. `00003.png`, `OST_009.png`, `children-alpha.png`, `tree_alpha_16bit.png`).
+
+ #### Inference on an AX650 host, such as the M4N-Dock (爱芯派Pro)
+
+ ```
+ root@ax650:~/realesrgan# python3 run_axmodel.py --input ./pics --outscale 2 --model_path ./model/realesrgan-x2.axmodel
+ [INFO] Available providers: ['AxEngineExecutionProvider']
+ Testing 0 00003
+ [INFO] Using provider: AxEngineExecutionProvider
+ [INFO] Chip type: ChipType.MC50
+ [INFO] VNPU type: VNPUType.DISABLED
+ [INFO] Engine version: 2.12.0s
+ [INFO] Model type: 2 (triple core)
+ [INFO] Compiler version: 4.2-dirty 5e72cf06-dirty
+ Testing 1 00017_gray
+ [INFO] Using provider: AxEngineExecutionProvider
+ [INFO] Model type: 2 (triple core)
+ [INFO] Compiler version: 4.2-dirty 5e72cf06-dirty
+ Testing 2 0014
+ [INFO] Using provider: AxEngineExecutionProvider
+ [INFO] Model type: 2 (triple core)
+ [INFO] Compiler version: 4.2-dirty 5e72cf06-dirty
+ Testing 3 0030
+ [INFO] Using provider: AxEngineExecutionProvider
+ [INFO] Model type: 2 (triple core)
+ [INFO] Compiler version: 4.2-dirty 5e72cf06-dirty
+ Testing 4 ADE_val_00000114
+ [INFO] Using provider: AxEngineExecutionProvider
+ [INFO] Model type: 2 (triple core)
+ [INFO] Compiler version: 4.2-dirty 5e72cf06-dirty
+ Testing 5 OST_009
+ [INFO] Using provider: AxEngineExecutionProvider
+ [INFO] Model type: 2 (triple core)
+ [INFO] Compiler version: 4.2-dirty 5e72cf06-dirty
+ Testing 6 children-alpha
+ [INFO] Using provider: AxEngineExecutionProvider
+ [INFO] Model type: 2 (triple core)
+ [INFO] Compiler version: 4.2-dirty 5e72cf06-dirty
+ Testing 7 tree_alpha_16bit
+ Input is a 16-bit image
+ [INFO] Using provider: AxEngineExecutionProvider
+ [INFO] Model type: 2 (triple core)
+ [INFO] Compiler version: 4.2-dirty 5e72cf06-dirty
+ Testing 8 wolf_gray
+ [INFO] Using provider: AxEngineExecutionProvider
+ [INFO] Model type: 2 (triple core)
+ [INFO] Compiler version: 4.2-dirty 5e72cf06-dirty
+
+ ```
+
+ Output: the upscaled images are saved to the `./results` folder (this repository includes `results/1.png` and `results/2.png` as sample outputs).
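+
+ For reference, here is a minimal sketch of the single-tile inference that `run_axmodel.py` performs; the script itself additionally pads and tiles larger images and stitches the upscaled tiles back together. The input path and the 128x128 crop are illustrative only.
+
+ ```python
+ # Minimal single-tile inference sketch mirroring run_axmodel.py (illustrative paths).
+ import cv2
+ import numpy as np
+ import axengine as axe
+
+ session = axe.InferenceSession('./model/realesrgan-x2.axmodel')
+ input_name = session.get_inputs()[0].name
+ output_names = [o.name for o in session.get_outputs()]
+
+ img = cv2.imread('./pics/00003.png').astype(np.float32) / 255.0   # HWC, BGR, [0, 1]
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ tile = img[:128, :128]                                            # assumes the image is >= 128x128
+ x = np.expand_dims(np.transpose(tile, (2, 0, 1)), axis=0)         # 1x3x128x128
+
+ y = session.run(output_names, {input_name: x})[0]                 # 1x3x256x256 for the x2 model
+ out = np.transpose(y.squeeze(0)[[2, 1, 0], :, :], (1, 2, 0))      # CHW RGB -> HWC BGR
+ cv2.imwrite('tile_x2.png', np.clip(out * 255.0, 0, 255).round().astype(np.uint8))
+ ```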
build_config.json ADDED
@@ -0,0 +1,45 @@
+ {
+     "work_dir": "",
+     "model_type": "ONNX",
+     "target_hardware": "AX650",
+     "npu_mode": "NPU3",
+     "onnx_opt": {
+         "disable_onnx_optimization": false,
+         "model_check": false
+     },
+     "quant": {
+         "input_configs": [
+             {
+                 "tensor_name": "DEFAULT",
+                 "calibration_dataset": "npy.zip",
+                 "calibration_format": "Numpy",
+                 "calibration_size": 10,
+                 "calibration_mean": [0, 0, 0],
+                 "calibration_std": [1.0, 1.0, 1.0]
+             }
+         ],
+         "calibration_method": "MinMax",
+         "precision_analysis": true,
+         "precision_analysis_method": "EndToEnd",
+         "precision_analysis_mode": "Reference"
+     },
+     "input_processors": [
+         {
+             "tensor_name": "DEFAULT",
+             "tensor_format": "AutoColorSpace",
+             "tensor_layout": "NCHW",
+             "src_layout": "NCHW",
+             "src_format": "AutoColorSpace",
+             "src_dtype": "FP32"
+         }
+     ],
+     "output_processors": [
+         {
+             "tensor_name": "DEFAULT"
+         }
+     ],
+     "compiler": {
+         "check": 0
+     }
+ }
+
model/realesrgan-x2.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:348f9e4d81072b4865cb1d96143134f0de44d3f2c750805b188a5c42ba5d633e
+ size 19270519
model/realesrgan-x4.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:301a65723e740bbc7082b84d1622ff2555ba732baf6f19373c5b8c9e1e03fb75
+ size 19657802
pics/00003.png ADDED

Git LFS Details

  • SHA256: d37932ae7d3137a0e38f8a90f7e3e16e13353399db6e29dca5a03a350f5fed1b
  • Pointer size: 131 Bytes
  • Size of remote file: 164 kB
pics/00017_gray.png ADDED
pics/0014.jpg ADDED
pics/0030.jpg ADDED
pics/ADE_val_00000114.jpg ADDED
pics/OST_009.png ADDED

Git LFS Details

  • SHA256: 62c8ec34919070f9c6fd3398d7a863b4d214adb4822331e9d507317b683ef46d
  • Pointer size: 131 Bytes
  • Size of remote file: 718 kB
pics/children-alpha.png ADDED

Git LFS Details

  • SHA256: 17323c91483660079e2e95fce438485b8f144bbaee50b2e7b10a9c343c628589
  • Pointer size: 131 Bytes
  • Size of remote file: 275 kB
pics/tree_alpha_16bit.png ADDED

Git LFS Details

  • SHA256: e6af49641c52884f1d5af6f8afcc75fa2ee0c31fb8e60a37e907b62aeb30d660
  • Pointer size: 131 Bytes
  • Size of remote file: 382 kB
pics/wolf_gray.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ numpy
+ opencv-python
+ onnxruntime
results/1.png ADDED

Git LFS Details

  • SHA256: dcd449360c2bde7274b1a2c1005294c15b0c9b06295d8ad7d7ad0f45e78a0af1
  • Pointer size: 131 Bytes
  • Size of remote file: 296 kB
results/2.png ADDED

Git LFS Details

  • SHA256: 757a300fd65736d31d86c2b94d1282bca4f2171a69b074e89ea97f373ae55526
  • Pointer size: 131 Bytes
  • Size of remote file: 254 kB
run_axmodel.py ADDED
@@ -0,0 +1,184 @@
+ import argparse
+ import cv2
+ import glob
+ import os
+ import math
+ import numpy as np
+ import axengine as axe
+
+ def pre_process(img, tile_size=128):
+     """Pre-process (mod pad) the image so that its height and width are divisible by tile_size.
+     """
+     # mod pad for divisible borders
+     pad_h, pad_w = 0, 0
+     h, w = img.shape[0:2]
+
+     if h % tile_size != 0:
+         pad_h = (tile_size - h % tile_size)
+     if w % tile_size != 0:
+         pad_w = (tile_size - w % tile_size)
+     img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), mode='reflect')
+     # HWC -> NCHW
+     img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
+
+     return img
+
+ def tile_process(img, origin_shape, model, scale=2, tile_size=64):
+     """First crop the input image into tiles, then process each tile.
+     Finally, all processed tiles are merged into one image.
+     """
+
+     # check the model path
+     if not os.path.exists(model):
+         raise ValueError(f'Model {model} does not exist.')
+
+     session = axe.InferenceSession(model)
+     input_name = session.get_inputs()[0].name
+     output_names = [x.name for x in session.get_outputs()]
+
+     # tile
+     batch, channel, height, width = img.shape
+     output_height = int(round(height * scale))
+     output_width = int(round(width * scale))
+     output_shape = (batch, channel, output_height, output_width)
+
+     # start with a black image
+     output = np.zeros(output_shape)
+     tiles_x = math.ceil(width / tile_size)
+     tiles_y = math.ceil(height / tile_size)
+
+     # loop over all tiles
+     for y in range(tiles_y):
+         for x in range(tiles_x):
+             # extract tile from input image
+             ofs_x = x * tile_size
+             ofs_y = y * tile_size
+             # input tile area on the total image
+             input_start_x = ofs_x
+             input_end_x = min(ofs_x + tile_size, width)
+             input_start_y = ofs_y
+             input_end_y = min(ofs_y + tile_size, height)
+
+             # input tile dimensions
+             tile_idx = y * tiles_x + x + 1
+             input_tile = img[:, :, input_start_y:input_end_y, input_start_x:input_end_x]
+
+             # upscale tile
+             try:
+                 output_tile = session.run(output_names, {input_name: input_tile})
+             except RuntimeError as error:
+                 print('Error', error)
+             # print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
+
+             # output tile area on the total image
+             output_start_x = int(round(input_start_x * scale))
+             output_end_x = int(round(input_end_x * scale))
+             output_start_y = int(round(input_start_y * scale))
+             output_end_y = int(round(input_end_y * scale))
+             output[:, :, output_start_y:output_end_y, output_start_x:output_end_x] = output_tile[0]
+
+     # remove the extra padded borders
+     origin_h, origin_w = origin_shape[0:2]
+     output = output[:, :, :int(round(origin_h * scale)), :int(round(origin_w * scale))].squeeze(0)
+     # CHW RGB -> HWC BGR
+     output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)).astype(np.float32)
+
+     return output
+
+ def main():
+     """Inference demo for Real-ESRGAN.
+     """
+     parser = argparse.ArgumentParser()
+     parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
+     parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
+     parser.add_argument('-s', '--outscale', type=float, default=2, help='The final upsampling scale of the image [Options: 2, 4]')
+     parser.add_argument(
+         '--model_path', type=str, default=None, help='Model path; you need to specify it')
+     parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
+     parser.add_argument('-t', '--tile', type=int, default=128, help='Tile size, 0 for no tile during testing')
+     parser.add_argument(
+         '--ext',
+         type=str,
+         default='auto',
+         help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
+
+     args = parser.parse_args()
+
+     # input
+     if os.path.isfile(args.input):
+         paths = [args.input]
+     else:
+         paths = sorted(glob.glob(os.path.join(args.input, '*')))
+
+     # output
+     os.makedirs(args.output, exist_ok=True)
+
+     for idx, path in enumerate(paths):
+         imgname, extension = os.path.splitext(os.path.basename(path))
+         print('Testing', idx, imgname)
+         if extension not in ['.jpg', '.jpeg', '.png', '.tif', '.tiff', '.bmp', '.webp']:
+             continue
+
+         img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+         if img is None:
+             print('Error loading image')
+             continue
+         img = img.astype(np.float32)
+         if np.max(img) > 256:  # 16-bit image
+             max_range = 65535
+             print('\tInput is a 16-bit image')
+         else:
+             max_range = 255
+         img = img / max_range
+         if len(img.shape) == 2:  # gray image
+             img_mode = 'L'
+             img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+         elif img.shape[2] == 4:  # RGBA image with alpha channel
+             img_mode = 'RGBA'
+             alpha = img[:, :, 3]
+             img = img[:, :, 0:3]
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+         else:
+             img_mode = 'RGB'
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+         # pre-process
+         origin_shape = img.shape
+         img = pre_process(img, args.tile)
+
+         # tile process
+         try:
+             output_img = tile_process(img, origin_shape, args.model_path, args.outscale, args.tile)
+         except RuntimeError as error:
+             print('Error', error)
+             print('If you encounter an out-of-memory error, try to set --tile to a smaller number.')
+             continue  # skip this image if inference failed
+
+         if img_mode == 'L':
+             output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
+         if img_mode == 'RGBA':
+             h, w = alpha.shape[0:2]
+             output_alpha = cv2.resize(
+                 alpha,
+                 (int(round(w * args.outscale)),
+                  int(round(h * args.outscale))),
+                 interpolation=cv2.INTER_LINEAR
+             )
+             output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
+             output_img[:, :, 3] = output_alpha
+
+         if max_range == 65535:  # 16-bit image
+             output = np.clip((output_img * 65535.0), 0, 65535).astype(np.uint16)
+         else:
+             output = np.clip((output_img * 255.0), 0, 255).round().astype(np.uint8)
+
+         if args.ext == 'auto':
+             extension = extension[1:]
+         else:
+             extension = args.ext
+
+         if args.suffix == '':
+             save_path = os.path.join(args.output, f'{imgname}.{extension}')
+         else:
+             save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{extension}')
+         cv2.imwrite(save_path, output)
+
+ if __name__ == '__main__':
+     main()
run_onnx.py ADDED
@@ -0,0 +1,184 @@
+ import argparse
+ import cv2
+ import glob
+ import os
+ import math
+ import numpy as np
+ import onnxruntime as ort
+
+ def pre_process(img, tile_size=128):
+     """Pre-process (mod pad) the image so that its height and width are divisible by tile_size.
+     """
+     # mod pad for divisible borders
+     pad_h, pad_w = 0, 0
+     h, w = img.shape[0:2]
+
+     if h % tile_size != 0:
+         pad_h = (tile_size - h % tile_size)
+     if w % tile_size != 0:
+         pad_w = (tile_size - w % tile_size)
+     img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), mode='reflect')
+     # HWC -> NCHW
+     img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
+
+     return img
+
+ def tile_process(img, origin_shape, model, scale=2, tile_size=64):
+     """First crop the input image into tiles, then process each tile.
+     Finally, all processed tiles are merged into one image.
+     """
+
+     # check the model path
+     if not os.path.exists(model):
+         raise ValueError(f'Model {model} does not exist.')
+
+     session = ort.InferenceSession(model)
+     input_name = session.get_inputs()[0].name
+     output_names = [x.name for x in session.get_outputs()]
+
+     # tile
+     batch, channel, height, width = img.shape
+     output_height = int(round(height * scale))
+     output_width = int(round(width * scale))
+     output_shape = (batch, channel, output_height, output_width)
+
+     # start with a black image
+     output = np.zeros(output_shape)
+     tiles_x = math.ceil(width / tile_size)
+     tiles_y = math.ceil(height / tile_size)
+
+     # loop over all tiles
+     for y in range(tiles_y):
+         for x in range(tiles_x):
+             # extract tile from input image
+             ofs_x = x * tile_size
+             ofs_y = y * tile_size
+             # input tile area on the total image
+             input_start_x = ofs_x
+             input_end_x = min(ofs_x + tile_size, width)
+             input_start_y = ofs_y
+             input_end_y = min(ofs_y + tile_size, height)
+
+             # input tile dimensions
+             tile_idx = y * tiles_x + x + 1
+             input_tile = img[:, :, input_start_y:input_end_y, input_start_x:input_end_x]
+
+             # upscale tile
+             try:
+                 output_tile = session.run(output_names, {input_name: input_tile})
+             except RuntimeError as error:
+                 print('Error', error)
+             # print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
+
+             # output tile area on the total image
+             output_start_x = int(round(input_start_x * scale))
+             output_end_x = int(round(input_end_x * scale))
+             output_start_y = int(round(input_start_y * scale))
+             output_end_y = int(round(input_end_y * scale))
+             output[:, :, output_start_y:output_end_y, output_start_x:output_end_x] = output_tile[0]
+
+     # remove the extra padded borders
+     origin_h, origin_w = origin_shape[0:2]
+     output = output[:, :, :int(round(origin_h * scale)), :int(round(origin_w * scale))].squeeze(0)
+     # CHW RGB -> HWC BGR
+     output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)).astype(np.float32)
+
+     return output
+
+ def main():
+     """Inference demo for Real-ESRGAN.
+     """
+     parser = argparse.ArgumentParser()
+     parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
+     parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
+     parser.add_argument('-s', '--outscale', type=float, default=2, help='The final upsampling scale of the image [Options: 2, 4]')
+     parser.add_argument(
+         '--model_path', type=str, default=None, help='Model path; you need to specify it')
+     parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
+     parser.add_argument('-t', '--tile', type=int, default=128, help='Tile size, 0 for no tile during testing')
+     parser.add_argument(
+         '--ext',
+         type=str,
+         default='auto',
+         help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
+
+     args = parser.parse_args()
+
+     # input
+     if os.path.isfile(args.input):
+         paths = [args.input]
+     else:
+         paths = sorted(glob.glob(os.path.join(args.input, '*')))
+
+     # output
+     os.makedirs(args.output, exist_ok=True)
+
+     for idx, path in enumerate(paths):
+         imgname, extension = os.path.splitext(os.path.basename(path))
+         print('Testing', idx, imgname)
+         if extension not in ['.jpg', '.jpeg', '.png', '.tif', '.tiff', '.bmp', '.webp']:
+             continue
+
+         img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+         if img is None:
+             print('Error loading image')
+             continue
+         img = img.astype(np.float32)
+         if np.max(img) > 256:  # 16-bit image
+             max_range = 65535
+             print('\tInput is a 16-bit image')
+         else:
+             max_range = 255
+         img = img / max_range
+         if len(img.shape) == 2:  # gray image
+             img_mode = 'L'
+             img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+         elif img.shape[2] == 4:  # RGBA image with alpha channel
+             img_mode = 'RGBA'
+             alpha = img[:, :, 3]
+             img = img[:, :, 0:3]
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+         else:
+             img_mode = 'RGB'
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+         # pre-process
+         origin_shape = img.shape
+         img = pre_process(img, args.tile)
+
+         # tile process
+         try:
+             output_img = tile_process(img, origin_shape, args.model_path, args.outscale, args.tile)
+         except RuntimeError as error:
+             print('Error', error)
+             print('If you encounter an out-of-memory error, try to set --tile to a smaller number.')
+             continue  # skip this image if inference failed
+
+         if img_mode == 'L':
+             output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
+         if img_mode == 'RGBA':
+             h, w = alpha.shape[0:2]
+             output_alpha = cv2.resize(
+                 alpha,
+                 (int(round(w * args.outscale)),
+                  int(round(h * args.outscale))),
+                 interpolation=cv2.INTER_LINEAR
+             )
+             output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
+             output_img[:, :, 3] = output_alpha
+
+         if max_range == 65535:  # 16-bit image
+             output = np.clip((output_img * 65535.0), 0, 65535).astype(np.uint16)
+         else:
+             output = np.clip((output_img * 255.0), 0, 255).round().astype(np.uint8)
+
+         if args.ext == 'auto':
+             extension = extension[1:]
+         else:
+             extension = args.ext
+
+         if args.suffix == '':
+             save_path = os.path.join(args.output, f'{imgname}.{extension}')
+         else:
+             save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{extension}')
+         cv2.imwrite(save_path, output)
+
+ if __name__ == '__main__':
+     main()
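
A quick way to gauge the impact of the w8a8 quantization is to run the same image through both `run_onnx.py` (which needs the original float ONNX model, not included in this repository) and `run_axmodel.py`, then compare the two saved results. A minimal sketch, with placeholder paths for wherever each script wrote its output:

```python
# Hypothetical PSNR check between the ONNX (float) and axmodel (w8a8) outputs.
# Both paths are placeholders; point them at the PNGs written by run_onnx.py and run_axmodel.py.
import cv2
import numpy as np

ref = cv2.imread('results_onnx/00003_out.png').astype(np.float64)
quant = cv2.imread('results_axmodel/00003_out.png').astype(np.float64)

mse = np.mean((ref - quant) ** 2)
psnr = float('inf') if mse == 0 else 10 * np.log10(255.0 ** 2 / mse)
print(f'MSE: {mse:.3f}  PSNR: {psnr:.2f} dB')
```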