jounery-d committed
Commit 50628b2 · verified · 1 Parent(s): e4dbceb

first commit
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+model_convert/axmodel/edsr_baseline_x2_1.axmodel filter=lfs diff=lfs merge=lfs -text
+model_convert/axmodel/espcn_x2_T9.axmodel filter=lfs diff=lfs merge=lfs -text
+video/1.png filter=lfs diff=lfs merge=lfs -text
+video/2.png filter=lfs diff=lfs merge=lfs -text
+video/test_1920x1080.mp4 filter=lfs diff=lfs merge=lfs -text
model_convert/README.md ADDED
@@ -0,0 +1,54 @@
# Model Conversion

## Exporting the Model (ONNX)
To export the EDSR ONNX model, refer to: https://github.com/sanghyun-son/EDSR-PyTorch/blob/master/src/main.py

Adding the following code to main.py exports the ONNX model:
```
model = model.to('cpu')
target_onnx_file = './edsr_baseline_x2_1.onnx'
dummy_input = torch.randn(1, 3, 1080, 1920)
idx_scale = 0
torch.onnx.export(model,
                  (dummy_input, idx_scale),
                  target_onnx_file,
                  export_params=True,
                  opset_version=11,
                  do_constant_folding=True,
                  dynamic_axes={},
                  )

print(f"Export model onnx to {target_onnx_file} finished")
```
This fixes the ONNX input shape to 1x3x1080x1920.

## Converting the Dynamic ONNX Model to a Static One
```
onnxsim edsr_baseline_x2_1.onnx edsr_baseline_x2_1_sim.onnx --overwrite-input-shape=1,3,1080,1920
```
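
To confirm the simplified model really carries the static shape, it can be opened with onnxruntime (the same package `python/run_onnx.py` uses). A minimal sketch, assuming the file sits in the current directory:
```
import onnxruntime as ort

# Open the simplified model on CPU and inspect its single input.
session = ort.InferenceSession('edsr_baseline_x2_1_sim.onnx', providers=['CPUExecutionProvider'])
model_input = session.get_inputs()[0]
print(model_input.name, model_input.shape)  # expect the static shape [1, 3, 1080, 1920]
```
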
## Converting the Model (ONNX -> Axera)
Use the `Pulsar2` model conversion tool to convert the ONNX model into `.axmodel`, the model file format that runs on the Axera NPU. This usually involves the following two steps:

- Generate a PTQ quantization calibration dataset suited to the model
- Use the `Pulsar2 build` command set to convert the model (PTQ quantization and compilation); for more detailed instructions, see the [AXera Pulsar2 toolchain manual](https://pulsar2-docs.readthedocs.io/zh-cn/latest/index.html)

### Quantization Dataset
Prepare a number of calibration images and pack them into Image.zip, as in the sketch below.

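A minimal sketch for producing Image.zip with Python's `zipfile`; the source directory `./calib_images` and the choice of 10 images (matching `calibration_size` in the config) are assumptions:
```
import zipfile
from pathlib import Path

# Pack the first 10 calibration images into Image.zip at the archive root
# ('./calib_images' is a placeholder directory).
images = sorted(Path('./calib_images').glob('*.png'))[:10]
with zipfile.ZipFile('Image.zip', 'w') as zf:
    for image_path in images:
        zf.write(image_path, arcname=image_path.name)
```
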
### Model Conversion

#### Edit the Config File

Check the `calibration_dataset` field in the config file (`build_config_edsr.json` in this repo) and change the configured path to wherever the calibration dataset from the previous step is stored, e.g. `"calibration_dataset": "Image.zip"`.

#### Pulsar2 build

A reference command:

```
pulsar2 build --input edsr_baseline_x2_1.onnx --config ./build_config_edsr.json --output_dir ./output --output_name edsr_baseline_x2_1.axmodel --target_hardware AX650 --compiler.check 0
```

The parameters can also be written into the JSON config, in which case simply run:

```
pulsar2 build --config ./build_config_edsr.json
```
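
After a successful build, the resulting `.axmodel` can be smoke-tested on the device with `axengine`, mirroring how `python/run_axmodel.py` opens the model. A minimal sketch; the model path and the expected output shape (x2 super resolution of a 1080p frame) are assumptions:
```
import numpy as np
import axengine as axe

# Open the compiled model the same way run_axmodel.py does.
session = axe.InferenceSession('./output/edsr_baseline_x2_1.axmodel', 'AxEngineExecutionProvider')
input_name = session.get_inputs()[0].name
output_names = [o.name for o in session.get_outputs()]

# Run one dummy 1080p frame through the network.
dummy = np.zeros((1, 3, 1080, 1920), dtype=np.float32)
outputs = session.run(output_names, {input_name: dummy})
print([o.shape for o in outputs])  # expect (1, 3, 2160, 3840) for an x2 model
```
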
model_convert/axmodel/edsr_baseline_x2_1.axmodel ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f20a4c7e058d68316e1324808e5ec67d6f1aa5dc1a95291fff82be0decf319f0
size 9129542
model_convert/axmodel/espcn_x2_T9.axmodel ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:73878c7caf32c58a0b2942721bc9296f9a01b02a548d3ea58120ddfc957a4a8a
size 1111469
model_convert/build_config_edsr.json ADDED
@@ -0,0 +1,48 @@
{
  "input": "./edsr_baseline_x2_1.onnx",
  "output_dir": "./output",
  "output_name": "edsr_baseline_x2_1.axmodel",
  "work_dir": "",
  "model_type": "ONNX",
  "target_hardware": "AX650",
  "npu_mode": "NPU3",
  "onnx_opt": {
    "disable_onnx_optimization": false,
    "model_check": false
  },
  "quant": {
    "input_configs": [
      {
        "tensor_name": "DEFAULT",
        "calibration_dataset": "Image.zip",
        "calibration_format": "Image",
        "calibration_size": 10,
        "calibration_mean": [0, 0, 0],
        "calibration_std": [1.0, 1.0, 1.0]
      }
    ],
    "calibration_method": "MinMax",
    "precision_analysis": true,
    "precision_analysis_method": "EndToEnd",
    "precision_analysis_mode": "Reference"
  },
  "input_processors": [
    {
      "tensor_name": "DEFAULT",
      "tensor_format": "AutoColorSpace",
      "src_format": "AutoColorSpace",
      "src_dtype": "FP32",
      "csc_mode": "FullRange",
      "csc_mat": [1.164, 0, 1.596, -222.912, 1.164, -0.392, -0.813, 135.616, 1.164, 2.017, 0, -276.8]
    }
  ],
  "output_processors": [
    {
      "tensor_name": "DEFAULT"
    }
  ],
  "compiler": {
    "check": 0
  }
}
model_convert/build_config_espcn.json ADDED
@@ -0,0 +1,46 @@
{
  "input": "./espcn_x2_T9.onnx",
  "output_dir": "./output",
  "output_name": "espcn_x2_T9.axmodel",
  "work_dir": "",
  "model_type": "ONNX",
  "target_hardware": "AX650",
  "npu_mode": "NPU3",
  "onnx_opt": {
    "disable_onnx_optimization": false,
    "model_check": false
  },
  "quant": {
    "input_configs": [
      {
        "tensor_name": "DEFAULT",
        "calibration_dataset": "./npy.zip",
        "calibration_format": "Numpy",
        "calibration_size": 10,
        "calibration_mean": [0],
        "calibration_std": [1.0]
      }
    ],
    "calibration_method": "MinMax",
    "precision_analysis": true,
    "precision_analysis_method": "EndToEnd",
    "precision_analysis_mode": "Reference"
  },
  "input_processors": [
    {
      "tensor_name": "DEFAULT",
      "tensor_format": "GRAY",
      "src_format": "GRAY",
      "src_dtype": "FP32"
    }
  ],
  "output_processors": [
    {
      "tensor_name": "DEFAULT"
    }
  ],
  "compiler": {
    "check": 0
  }
}
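
The ESPCN config consumes a Numpy calibration set (`./npy.zip`). A sketch for producing it from sample frames, reusing `imgproc.preprocess_one_frame` from `python/imgproc.py` to get the normalized Y channel; the source directory and the archive layout (whether Pulsar2 wants the full `(1, 1, H, W)` input tensor or plain 2-D arrays) are assumptions to verify against the toolchain manual:
```
import zipfile
from pathlib import Path

import cv2
import numpy as np

import imgproc  # python/imgproc.py in this repo

# Save the preprocessed Y channel of each frame as .npy and zip them up.
with zipfile.ZipFile('npy.zip', 'w') as zf:
    for i, path in enumerate(sorted(Path('./calib_images').glob('*.png'))[:10]):
        y_image, _, _ = imgproc.preprocess_one_frame(cv2.imread(str(path)))  # (1, 1, H, W) float32
        npy_path = f'{i}.npy'
        np.save(npy_path, y_image)
        zf.write(npy_path)
```
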
model_convert/onnx/edsr_baseline_x2_1.onnx ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b98049c5a5122cd394641159f8689e7d01e4ca3c4ed937b98b577712b0906099
size 5492581
model_convert/onnx/espcn_x2_T9.onnx ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:be85e904e16bce24222cda0e72d1c168c99b8785aeae0029d4d5b94ebf771bdf
size 86307
python/common.py ADDED
@@ -0,0 +1,81 @@
import random

import numpy as np
import skimage.color as sc

import torch


def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False):
    # Crop aligned patches: args[0] is the LR image, the rest are the HR targets.
    ih, iw = args[0].shape[:2]

    if not input_large:
        p = scale if multi else 1
        tp = p * patch_size
        ip = tp // scale
    else:
        tp = patch_size
        ip = patch_size

    ix = random.randrange(0, iw - ip + 1)
    iy = random.randrange(0, ih - ip + 1)

    if not input_large:
        tx, ty = scale * ix, scale * iy
    else:
        tx, ty = ix, iy

    ret = [
        args[0][iy:iy + ip, ix:ix + ip, :],
        *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]
    ]

    return ret


def set_channel(*args, n_channels=3):
    def _set_channel(img):
        if img.ndim == 2:
            img = np.expand_dims(img, axis=2)

        c = img.shape[2]
        if n_channels == 1 and c == 3:
            # Keep only the Y channel of the YCbCr conversion.
            img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
        elif n_channels == 3 and c == 1:
            img = np.concatenate([img] * n_channels, 2)

        return img

    return [_set_channel(a) for a in args]


def np_prepare(*args, rgb_range=255):
    def _np_prepare(img):
        # HWC -> NCHW float32 batch, rescaled to the model's rgb_range.
        img = np.ascontiguousarray(img.transpose((2, 0, 1)))
        img = np.expand_dims(img, axis=0).astype(np.float32)
        img /= 255 / rgb_range
        return img

    return [_np_prepare(a) for a in args]


def np2Tensor(*args, rgb_range=255):
    def _np2Tensor(img):
        np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
        tensor = torch.from_numpy(np_transpose).float()
        tensor.mul_(rgb_range / 255)

        return tensor

    return [_np2Tensor(a) for a in args]


def augment(*args, hflip=True, rot=True):
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        if hflip: img = img[:, ::-1, :]
        if vflip: img = img[::-1, :, :]
        if rot90: img = img.transpose(1, 0, 2)

        return img

    return [_augment(a) for a in args]
python/imgproc.py ADDED
@@ -0,0 +1,926 @@
# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import math
import random
from typing import Any, Tuple, List, Union

import cv2
import numpy as np
import torch
from numpy import ndarray
from torch import Tensor
from torchvision.transforms import functional as F_vision

__all__ = [
    "image_to_tensor", "tensor_to_image",
    "image_resize", "preprocess_one_image",
    "expand_y", "rgb_to_ycbcr", "bgr_to_ycbcr", "ycbcr_to_bgr", "ycbcr_to_rgb",
    "rgb_to_ycbcr_torch", "bgr_to_ycbcr_torch",
    "center_crop", "random_crop", "random_rotate", "random_vertically_flip", "random_horizontally_flip",
    "center_crop_torch", "random_crop_torch", "random_rotate_torch", "random_vertically_flip_torch",
    "random_horizontally_flip_torch",
]


# Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py`
def _cubic(x: Any) -> Any:
    """Implementation of `cubic` function in Matlab under Python language.

    Args:
        x: Element vector.

    Returns:
        Bicubic interpolation

    """
    absx = torch.abs(x)
    absx2 = absx ** 2
    absx3 = absx ** 3
    return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (
            -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (
               ((absx > 1) * (absx <= 2)).type_as(absx))


# Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py`
def _calculate_weights_indices(in_length: int,
                               out_length: int,
                               scale: float,
                               kernel_width: int,
                               antialiasing: bool) -> Tuple[Tensor, Tensor, int, int]:
    """Implementation of `calculate_weights_indices` function in Matlab under Python language.

    Args:
        in_length (int): Input length.
        out_length (int): Output length.
        scale (float): Scale factor.
        kernel_width (int): Kernel width.
        antialiasing (bool): Whether to apply antialiasing when down-sampling.
            Caution: Bicubic down-sampling in PIL uses antialiasing by default.

    Returns:
        weights, indices, sym_len_s, sym_len_e

    """
    if (scale < 1) and antialiasing:
        # Use a modified kernel (larger kernel width) to simultaneously
        # interpolate and antialias.
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5 + scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation? Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    p = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
        out_length, p)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices

    # apply cubic kernel
    if (scale < 1) and antialiasing:
        weights = scale * _cubic(distance_to_center * scale)
    else:
        weights = _cubic(distance_to_center)

    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, p)

    # If a column in weights is all zero, get rid of it. Only consider the
    # first and last column.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, p - 2)
        weights = weights.narrow(1, 1, p - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, p - 2)
        weights = weights.narrow(1, 0, p - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)


def image_to_tensor(image: ndarray, range_norm: bool, half: bool) -> Tensor:
    """Convert the image data type to the Tensor (CHW) data type supported by PyTorch

    Args:
        image (np.ndarray): The image data read by ``OpenCV.imread``, the data range is [0, 255] or [0, 1]
        range_norm (bool): Scale [0, 1] data to between [-1, 1]
        half (bool): Whether to convert torch.float32 similarly to torch.half type

    Returns:
        tensor (Tensor): Data types supported by PyTorch

    Examples:
        >>> example_image = cv2.imread("lr_image.bmp")
        >>> example_tensor = image_to_tensor(example_image, range_norm=True, half=False)

    """
    # Convert image data type to Tensor data type
    tensor = F_vision.to_tensor(image)

    # Scale the image data from [0, 1] to [-1, 1]
    if range_norm:
        tensor = tensor.mul(2.0).sub(1.0)

    # Convert torch.float32 image data type to torch.half image data type
    if half:
        tensor = tensor.half()

    return tensor


def tensor_to_image(tensor: Tensor, range_norm: bool, half: bool) -> Any:
    """Convert the Tensor (NCHW) data type supported by PyTorch to the np.ndarray (HWC) image data type

    Args:
        tensor (Tensor): Data types supported by PyTorch (NCHW), the data range is [0, 1]
        range_norm (bool): Scale [-1, 1] data to between [0, 1]
        half (bool): Whether to convert torch.float32 similarly to torch.half type.

    Returns:
        image (np.ndarray): Data types supported by PIL or OpenCV

    Examples:
        >>> example_image = cv2.imread("lr_image.bmp")
        >>> example_tensor = image_to_tensor(example_image, range_norm=False, half=False)
        >>> example_image = tensor_to_image(example_tensor.unsqueeze_(0), range_norm=False, half=False)

    """
    if range_norm:
        tensor = tensor.add(1.0).div(2.0)
    if half:
        tensor = tensor.half()

    image = tensor.squeeze_(0).permute(1, 2, 0).mul_(255).clamp_(0, 255).cpu().numpy().astype("uint8")

    return image


def array_to_image(array: ndarray) -> Any:
    """Convert an np.ndarray (NCHW) with data range [0, 1] to the np.ndarray (HWC) image data type

    Args:
        array (np.ndarray): Model output of shape (1, C, H, W), the data range is [0, 1]

    Returns:
        image (np.ndarray): uint8 HWC image data supported by PIL or OpenCV

    """
    image = np.clip(np.transpose(np.squeeze(array, axis=0), (1, 2, 0)) * 255, 0, 255).astype(np.uint8)

    return image


def preprocess_one_image(image_path: str, device: torch.device) -> Tuple[Tensor, ndarray, ndarray]:
    image = cv2.imread(image_path).astype(np.float32) / 255.0

    # BGR to YCbCr
    ycbcr_image = bgr_to_ycbcr(image, only_use_y_channel=False)

    # Split YCbCr image data
    y_image, cb_image, cr_image = cv2.split(ycbcr_image)

    # Convert image data to pytorch format data
    y_tensor = image_to_tensor(y_image, False, False).unsqueeze_(0)

    # Transfer tensor channel image format data to CUDA device
    y_tensor = y_tensor.to(device=device, non_blocking=True)

    return y_tensor, cb_image, cr_image


def preprocess_one_frame(image: ndarray) -> Tuple[ndarray, ndarray, ndarray]:
    image = image.astype(np.float32) / 255.0

    # BGR to YCbCr
    ycbcr_image = bgr_to_ycbcr(image, only_use_y_channel=False)

    # Split YCbCr image data
    y_image, cb_image, cr_image = cv2.split(ycbcr_image)

    # Expand the Y channel to NCHW format data of shape (1, 1, H, W)
    y_image = y_image[np.newaxis, np.newaxis, ...]

    return y_image, cb_image, cr_image


# Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py`
def image_resize(image: Any, scale_factor: float, antialiasing: bool = True) -> Any:
    """Implementation of `imresize` function in Matlab under Python language.

    Args:
        image: The input image.
        scale_factor (float): Scale factor. The same scale applies for both height and width.
        antialiasing (bool): Whether to apply antialiasing when down-sampling.
            Caution: Bicubic down-sampling in `PIL` uses antialiasing by default. Default: ``True``.

    Returns:
        out_2 (np.ndarray): Output image with shape (c, h, w), [0, 1] range, w/o round

    """
    squeeze_flag = False
    if type(image).__module__ == np.__name__:  # numpy type
        numpy_type = True
        if image.ndim == 2:
            image = image[:, :, None]
            squeeze_flag = True
        image = torch.from_numpy(image.transpose(2, 0, 1)).float()
    else:
        numpy_type = False
        if image.ndim == 2:
            image = image.unsqueeze(0)
            squeeze_flag = True

    in_c, in_h, in_w = image.size()
    out_h, out_w = math.ceil(in_h * scale_factor), math.ceil(in_w * scale_factor)
    kernel_width = 4

    # get weights and indices
    weights_h, indices_h, sym_len_hs, sym_len_he = _calculate_weights_indices(in_h, out_h, scale_factor, kernel_width,
                                                                              antialiasing)
    weights_w, indices_w, sym_len_ws, sym_len_we = _calculate_weights_indices(in_w, out_w, scale_factor, kernel_width,
                                                                              antialiasing)
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
    img_aug.narrow(1, sym_len_hs, in_h).copy_(image)

    sym_patch = image[:, :sym_len_hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)

    sym_patch = image[:, -sym_len_he:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(in_c, out_h, in_w)
    kernel_width = weights_h.size(1)
    for i in range(out_h):
        idx = int(indices_h[i][0])
        for j in range(in_c):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
    out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_we:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(in_c, out_h, out_w)
    kernel_width = weights_w.size(1)
    for i in range(out_w):
        idx = int(indices_w[i][0])
        for j in range(in_c):
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])

    if squeeze_flag:
        out_2 = out_2.squeeze(0)
    if numpy_type:
        out_2 = out_2.numpy()
        if not squeeze_flag:
            out_2 = out_2.transpose(1, 2, 0)

    return out_2


def expand_y(image: np.ndarray) -> np.ndarray:
    """Convert BGR channel to YCbCr format,
    and expand Y channel data in YCbCr, from HW to HWC

    Args:
        image (np.ndarray): BGR image data

    Returns:
        y_image (np.ndarray): Y-channel image data in HWC form

    """
    # Normalize image data to [0, 1]
    image = image.astype(np.float32) / 255.

    # Convert BGR to YCbCr, and extract only Y channel
    y_image = bgr_to_ycbcr(image, only_use_y_channel=True)

    # Expand Y channel
    y_image = y_image[..., None]

    # Normalize the image data to [0, 255]
    y_image = y_image.astype(np.float64) * 255.0

    return y_image


def rgb_to_ycbcr(image: np.ndarray, only_use_y_channel: bool) -> np.ndarray:
    """Implementation of rgb2ycbcr function in Matlab under Python language

    Args:
        image (np.ndarray): Image input in RGB format.
        only_use_y_channel (bool): Extract Y channel separately

    Returns:
        image (np.ndarray): YCbCr image array data

    """
    if only_use_y_channel:
        image = np.dot(image, [65.481, 128.553, 24.966]) + 16.0
    else:
        image = np.matmul(image, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [
            16, 128, 128]

    image /= 255.
    image = image.astype(np.float32)

    return image


def bgr_to_ycbcr(image: np.ndarray, only_use_y_channel: bool) -> np.ndarray:
    """Implementation of bgr2ycbcr function in Matlab under Python language.

    Args:
        image (np.ndarray): Image input in BGR format
        only_use_y_channel (bool): Extract Y channel separately

    Returns:
        image (np.ndarray): YCbCr image array data

    """
    if only_use_y_channel:
        image = np.dot(image, [24.966, 128.553, 65.481]) + 16.0
    else:
        image = np.matmul(image, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [
            16, 128, 128]

    image /= 255.
    image = image.astype(np.float32)

    return image


def ycbcr_to_rgb(image: np.ndarray) -> np.ndarray:
    """Implementation of ycbcr2rgb function in Matlab under Python language.

    Args:
        image (np.ndarray): Image input in YCbCr format.

    Returns:
        image (np.ndarray): RGB image array data

    """
    image_dtype = image.dtype
    image *= 255.

    image = np.matmul(image, [[0.00456621, 0.00456621, 0.00456621],
                              [0, -0.00153632, 0.00791071],
                              [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]

    image /= 255.
    image = image.astype(image_dtype)

    return image


def ycbcr_to_bgr(image: np.ndarray) -> np.ndarray:
    """Implementation of ycbcr2bgr function in Matlab under Python language.

    Args:
        image (np.ndarray): Image input in YCbCr format.

    Returns:
        image (np.ndarray): BGR image array data

    """
    image_dtype = image.dtype
    image *= 255.

    image = np.matmul(image, [[0.00456621, 0.00456621, 0.00456621],
                              [0.00791071, -0.00153632, 0],
                              [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921]

    image /= 255.
    image = image.astype(image_dtype)

    return image


def rgb_to_ycbcr_torch(tensor: Tensor, only_use_y_channel: bool) -> Tensor:
    """Implementation of rgb2ycbcr function in Matlab under PyTorch

    References from: `https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion`

    Args:
        tensor (Tensor): Image data in PyTorch format
        only_use_y_channel (bool): Extract only Y channel

    Returns:
        tensor (Tensor): YCbCr image data in PyTorch format

    """
    if only_use_y_channel:
        weight = Tensor([[65.481], [128.553], [24.966]]).to(tensor)
        tensor = torch.matmul(tensor.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0
    else:
        weight = Tensor([[65.481, -37.797, 112.0],
                         [128.553, -74.203, -93.786],
                         [24.966, 112.0, -18.214]]).to(tensor)
        bias = Tensor([16, 128, 128]).view(1, 3, 1, 1).to(tensor)
        tensor = torch.matmul(tensor.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias

    tensor /= 255.

    return tensor


def bgr_to_ycbcr_torch(tensor: Tensor, only_use_y_channel: bool) -> Tensor:
    """Implementation of bgr2ycbcr function in Matlab under PyTorch

    References from: `https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion`

    Args:
        tensor (Tensor): Image data in PyTorch format
        only_use_y_channel (bool): Extract only Y channel

    Returns:
        tensor (Tensor): YCbCr image data in PyTorch format

    """
    if only_use_y_channel:
        weight = Tensor([[24.966], [128.553], [65.481]]).to(tensor)
        tensor = torch.matmul(tensor.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0
    else:
        weight = Tensor([[24.966, 112.0, -18.214],
                         [128.553, -74.203, -93.786],
                         [65.481, -37.797, 112.0]]).to(tensor)
        bias = Tensor([16, 128, 128]).view(1, 3, 1, 1).to(tensor)
        tensor = torch.matmul(tensor.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias

    tensor /= 255.

    return tensor


def center_crop(image: np.ndarray, image_size: int) -> np.ndarray:
    """Crop small image patches from one image center area.

    Args:
        image (np.ndarray): The input image for `OpenCV.imread`.
        image_size (int): The size of the captured image area.

    Returns:
        patch_image (np.ndarray): Small patch image

    """
    image_height, image_width = image.shape[:2]

    # Just need to find the top and left coordinates of the image
    top = (image_height - image_size) // 2
    left = (image_width - image_size) // 2

    # Crop image patch
    patch_image = image[top:top + image_size, left:left + image_size, ...]

    return patch_image


def random_crop(image: np.ndarray, image_size: int) -> np.ndarray:
    """Crop small image patches from one image.

    Args:
        image (np.ndarray): The input image for `OpenCV.imread`.
        image_size (int): The size of the captured image area.

    Returns:
        patch_image (np.ndarray): Small patch image

    """
    image_height, image_width = image.shape[:2]

    # Just need to find the top and left coordinates of the image
    top = random.randint(0, image_height - image_size)
    left = random.randint(0, image_width - image_size)

    # Crop image patch
    patch_image = image[top:top + image_size, left:left + image_size, ...]

    return patch_image


def random_rotate(image,
                  angles: list,
                  center: Tuple[int, int] = None,
                  scale_factor: float = 1.0) -> np.ndarray:
    """Rotate an image by a random angle

    Args:
        image (np.ndarray): Image read with OpenCV
        angles (list): Rotation angle range
        center (optional, tuple[int, int]): High resolution image selection center point. Default: ``None``
        scale_factor (optional, float): scaling factor. Default: 1.0

    Returns:
        rotated_image (np.ndarray): image after rotation

    """
    image_height, image_width = image.shape[:2]

    if center is None:
        center = (image_width // 2, image_height // 2)

    # Randomly select a specific angle
    angle = random.choice(angles)
    matrix = cv2.getRotationMatrix2D(center, angle, scale_factor)
    rotated_image = cv2.warpAffine(image, matrix, (image_width, image_height))

    return rotated_image


def random_horizontally_flip(image: np.ndarray, p: float = 0.5) -> np.ndarray:
    """Flip the image left-right randomly

    Args:
        image (np.ndarray): Image read with OpenCV
        p (optional, float): Horizontally flip probability. Default: 0.5

    Returns:
        horizontally_flip_image (np.ndarray): image after horizontal flip

    """
    if random.random() < p:
        horizontally_flip_image = cv2.flip(image, 1)
    else:
        horizontally_flip_image = image

    return horizontally_flip_image


def random_vertically_flip(image: np.ndarray, p: float = 0.5) -> np.ndarray:
    """Flip the image upside down randomly

    Args:
        image (np.ndarray): Image read with OpenCV
        p (optional, float): Vertically flip probability. Default: 0.5

    Returns:
        vertically_flip_image (np.ndarray): image after vertical flip

    """
    if random.random() < p:
        vertically_flip_image = cv2.flip(image, 0)
    else:
        vertically_flip_image = image

    return vertically_flip_image


def center_crop_torch(
        gt_images: Union[ndarray, Tensor, List[ndarray], List[Tensor]],
        lr_images: Union[ndarray, Tensor, List[ndarray], List[Tensor]],
        gt_patch_size: int,
        upscale_factor: int,
) -> Union[
    Tuple[ndarray, ndarray],
    Tuple[Tensor, Tensor],
    Tuple[List[ndarray], List[ndarray]],
    Tuple[List[Tensor], List[Tensor]]
]:
    if not isinstance(gt_images, list):
        gt_images = [gt_images]
    if not isinstance(lr_images, list):
        lr_images = [lr_images]

    # Detect input image data type
    input_type = "Tensor" if torch.is_tensor(lr_images[0]) else "Numpy"

    if input_type == "Tensor":
        lr_image_height, lr_image_width = lr_images[0].size()[-2:]
    else:
        lr_image_height, lr_image_width = lr_images[0].shape[0:2]

    # Compute low-resolution image patch size
    lr_patch_size = gt_patch_size // upscale_factor

    # Calculate the start indices of the crop
    lr_top = (lr_image_height - lr_patch_size) // 2
    lr_left = (lr_image_width - lr_patch_size) // 2

    # Crop lr image patch
    if input_type == "Tensor":
        lr_images = [lr_image[
                     :,
                     :,
                     lr_top:lr_top + lr_patch_size,
                     lr_left:lr_left + lr_patch_size] for lr_image in lr_images]
    else:
        lr_images = [lr_image[
                     lr_top:lr_top + lr_patch_size,
                     lr_left:lr_left + lr_patch_size,
                     ...] for lr_image in lr_images]

    # Crop gt image patch
    gt_top, gt_left = int(lr_top * upscale_factor), int(lr_left * upscale_factor)

    if input_type == "Tensor":
        gt_images = [v[
                     :,
                     :,
                     gt_top:gt_top + gt_patch_size,
                     gt_left:gt_left + gt_patch_size] for v in gt_images]
    else:
        gt_images = [v[
                     gt_top:gt_top + gt_patch_size,
                     gt_left:gt_left + gt_patch_size,
                     ...] for v in gt_images]

    # When image number is 1
    if len(gt_images) == 1:
        gt_images = gt_images[0]
    if len(lr_images) == 1:
        lr_images = lr_images[0]

    return gt_images, lr_images


def random_crop_torch(
        gt_images: Union[ndarray, Tensor, List[ndarray], List[Tensor]],
        lr_images: Union[ndarray, Tensor, List[ndarray], List[Tensor]],
        gt_patch_size: int,
        upscale_factor: int,
) -> Union[
    Tuple[ndarray, ndarray],
    Tuple[Tensor, Tensor],
    Tuple[List[ndarray], List[ndarray]],
    Tuple[List[Tensor], List[Tensor]]
]:
    if not isinstance(gt_images, list):
        gt_images = [gt_images]
    if not isinstance(lr_images, list):
        lr_images = [lr_images]

    # Detect input image data type
    input_type = "Tensor" if torch.is_tensor(lr_images[0]) else "Numpy"

    if input_type == "Tensor":
        lr_image_height, lr_image_width = lr_images[0].size()[-2:]
    else:
        lr_image_height, lr_image_width = lr_images[0].shape[0:2]

    # Compute low-resolution image patch size
    lr_patch_size = gt_patch_size // upscale_factor

    # Just need to find the top and left coordinates of the image
    lr_top = random.randint(0, lr_image_height - lr_patch_size)
    lr_left = random.randint(0, lr_image_width - lr_patch_size)

    # Crop lr image patch
    if input_type == "Tensor":
        lr_images = [lr_image[
                     :,
                     :,
                     lr_top:lr_top + lr_patch_size,
                     lr_left:lr_left + lr_patch_size] for lr_image in lr_images]
    else:
        lr_images = [lr_image[
                     lr_top:lr_top + lr_patch_size,
                     lr_left:lr_left + lr_patch_size,
                     ...] for lr_image in lr_images]

    # Crop gt image patch
    gt_top, gt_left = int(lr_top * upscale_factor), int(lr_left * upscale_factor)

    if input_type == "Tensor":
        gt_images = [v[
                     :,
                     :,
                     gt_top:gt_top + gt_patch_size,
                     gt_left:gt_left + gt_patch_size] for v in gt_images]
    else:
        gt_images = [v[
                     gt_top:gt_top + gt_patch_size,
                     gt_left:gt_left + gt_patch_size,
                     ...] for v in gt_images]

    # When image number is 1
    if len(gt_images) == 1:
        gt_images = gt_images[0]
    if len(lr_images) == 1:
        lr_images = lr_images[0]

    return gt_images, lr_images


def random_rotate_torch(
        gt_images: Union[ndarray, Tensor, List[ndarray], List[Tensor]],
        lr_images: Union[ndarray, Tensor, List[ndarray], List[Tensor]],
        upscale_factor: int,
        angles: list,
        gt_center: tuple = None,
        lr_center: tuple = None,
        rotate_scale_factor: float = 1.0
) -> Union[
    Tuple[ndarray, ndarray],
    Tuple[Tensor, Tensor],
    Tuple[List[ndarray], List[ndarray]],
    Tuple[List[Tensor], List[Tensor]]
]:
    # Randomly select a specific angle
    angle = random.choice(angles)

    if not isinstance(gt_images, list):
        gt_images = [gt_images]
    if not isinstance(lr_images, list):
        lr_images = [lr_images]

    # Detect input image data type
    input_type = "Tensor" if torch.is_tensor(lr_images[0]) else "Numpy"

    if input_type == "Tensor":
        lr_image_height, lr_image_width = lr_images[0].size()[-2:]
    else:
        lr_image_height, lr_image_width = lr_images[0].shape[0:2]

    # Rotate LR image
    if lr_center is None:
        lr_center = [lr_image_width // 2, lr_image_height // 2]

    lr_matrix = cv2.getRotationMatrix2D(lr_center, angle, rotate_scale_factor)

    if input_type == "Tensor":
        lr_images = [F_vision.rotate(lr_image, angle, center=lr_center) for lr_image in lr_images]
    else:
        lr_images = [cv2.warpAffine(lr_image, lr_matrix, (lr_image_width, lr_image_height)) for lr_image in lr_images]

    # Rotate GT image
    gt_image_width = int(lr_image_width * upscale_factor)
    gt_image_height = int(lr_image_height * upscale_factor)

    if gt_center is None:
        gt_center = [gt_image_width // 2, gt_image_height // 2]

    gt_matrix = cv2.getRotationMatrix2D(gt_center, angle, rotate_scale_factor)

    if input_type == "Tensor":
        gt_images = [F_vision.rotate(gt_image, angle, center=gt_center) for gt_image in gt_images]
    else:
        gt_images = [cv2.warpAffine(gt_image, gt_matrix, (gt_image_width, gt_image_height)) for gt_image in gt_images]

    # When image number is 1
    if len(gt_images) == 1:
        gt_images = gt_images[0]
    if len(lr_images) == 1:
        lr_images = lr_images[0]

    return gt_images, lr_images


def random_horizontally_flip_torch(
        gt_images: Union[ndarray, Tensor, List[ndarray], List[Tensor]],
        lr_images: Union[ndarray, Tensor, List[ndarray], List[Tensor]],
        p: float = 0.5
) -> Union[
    Tuple[ndarray, ndarray],
    Tuple[Tensor, Tensor],
    Tuple[List[ndarray], List[ndarray]],
    Tuple[List[Tensor], List[Tensor]]
]:
    # Get horizontal flip probability
    flip_prob = random.random()

    if not isinstance(gt_images, list):
        gt_images = [gt_images]
    if not isinstance(lr_images, list):
        lr_images = [lr_images]

    # Detect input image data type
    input_type = "Tensor" if torch.is_tensor(lr_images[0]) else "Numpy"

    if flip_prob > p:
        if input_type == "Tensor":
            lr_images = [F_vision.hflip(lr_image) for lr_image in lr_images]
            gt_images = [F_vision.hflip(gt_image) for gt_image in gt_images]
        else:
            lr_images = [cv2.flip(lr_image, 1) for lr_image in lr_images]
            gt_images = [cv2.flip(gt_image, 1) for gt_image in gt_images]

    # When image number is 1
    if len(gt_images) == 1:
        gt_images = gt_images[0]
    if len(lr_images) == 1:
        lr_images = lr_images[0]

    return gt_images, lr_images


def random_vertically_flip_torch(
        gt_images: Union[ndarray, Tensor, List[ndarray], List[Tensor]],
        lr_images: Union[ndarray, Tensor, List[ndarray], List[Tensor]],
        p: float = 0.5
) -> Union[
    Tuple[ndarray, ndarray],
    Tuple[Tensor, Tensor],
    Tuple[List[ndarray], List[ndarray]],
    Tuple[List[Tensor], List[Tensor]]
]:
    # Get vertical flip probability
    flip_prob = random.random()

    if not isinstance(gt_images, list):
        gt_images = [gt_images]
    if not isinstance(lr_images, list):
        lr_images = [lr_images]

    # Detect input image data type
    input_type = "Tensor" if torch.is_tensor(lr_images[0]) else "Numpy"

    if flip_prob > p:
        if input_type == "Tensor":
            lr_images = [F_vision.vflip(lr_image) for lr_image in lr_images]
            gt_images = [F_vision.vflip(gt_image) for gt_image in gt_images]
        else:
            lr_images = [cv2.flip(lr_image, 0) for lr_image in lr_images]
            gt_images = [cv2.flip(gt_image, 0) for gt_image in gt_images]

    # When image number is 1
    if len(gt_images) == 1:
        gt_images = gt_images[0]
    if len(lr_images) == 1:
        lr_images = lr_images[0]

    return gt_images, lr_images
python/run_axmodel.py ADDED
@@ -0,0 +1,126 @@
import os
import cv2
import time
import torch
import argparse
import numpy as np
from tqdm import tqdm

import common
import imgproc
import axengine as axe

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="edsr_baseline_x2_1.axmodel", help="axmodel model path")
parser.add_argument('--scale', nargs='+', type=int, default=[2], help='super resolution scale')
parser.add_argument("--dir_demo", type=str, default='../video/test_1920x1080.mp4', help="demo video path")
parser.add_argument('--rgb_range', type=int, default=255, help='maximum value of RGB')


def quantize(img, rgb_range):
    pixel_range = 255 / rgb_range
    return np.round(np.clip(img * pixel_range, 0, 255)) / pixel_range


def from_numpy(x):
    return x if isinstance(x, np.ndarray) else np.array(x)


class VideoTester():
    def __init__(self, scale, my_model, dir_demo, rgb_range=255, cuda=True, arch='EDSR'):
        # Note: the cuda flag is unused here; inference runs on the Axera NPU.
        self.scale = scale
        self.rgb_range = rgb_range
        self.session = axe.InferenceSession(my_model, 'AxEngineExecutionProvider')
        self.output_names = [x.name for x in self.session.get_outputs()]
        self.input_name = self.session.get_inputs()[0].name
        self.dir_demo = dir_demo
        self.filename, _ = os.path.splitext(os.path.basename(dir_demo))
        self.arch = arch

    def test(self):
        torch.set_grad_enabled(False)
        if not os.path.exists('experiment'):
            os.makedirs('experiment')
        for idx_scale, scale in enumerate(self.scale):
            vidcap = cv2.VideoCapture(self.dir_demo)
            total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))

            vidwri = cv2.VideoWriter(
                os.path.join('experiment', ('{}_x{}.avi'.format(self.filename, scale))),
                cv2.VideoWriter_fourcc(*'XVID'),
                vidcap.get(cv2.CAP_PROP_FPS),
                (
                    int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                    int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                )
            )

            total_times = 0
            tqdm_test = tqdm(range(total_frames), ncols=80)

            if self.arch == 'EDSR':
                for _ in tqdm_test:
                    success, lr = vidcap.read()
                    if not success: break
                    start_time = time.time()
                    # EDSR consumes the full 3-channel frame as an NCHW float32 batch.
                    lr_y_image, = common.set_channel(lr, n_channels=3)
                    lr_y_image, = common.np_prepare(lr_y_image, rgb_range=self.rgb_range)

                    sr = self.session.run(self.output_names, {self.input_name: lr_y_image})
                    end_time = time.time()
                    total_times += end_time - start_time

                    if isinstance(sr, (list, tuple)):
                        sr = from_numpy(sr[0]) if len(sr) == 1 else [from_numpy(x) for x in sr]
                    else:
                        sr = from_numpy(sr)

                    sr = quantize(sr, self.rgb_range).squeeze(0)
                    normalized = sr * 255 / self.rgb_range
                    ndarr = normalized.transpose(1, 2, 0).astype(np.uint8)
                    vidwri.write(ndarr)

            elif self.arch == 'ESPCN':
                for _ in tqdm_test:
                    success, lr = vidcap.read()
                    if not success: break
                    start_time = time.time()

                    # ESPCN runs on the Y channel only; Cb/Cr are upscaled with bicubic.
                    lr_y_image, lr_cb_image, lr_cr_image = imgproc.preprocess_one_frame(lr)
                    bic_cb_image = cv2.resize(lr_cb_image,
                                              (int(lr_cb_image.shape[1] * scale),
                                               int(lr_cb_image.shape[0] * scale)),
                                              interpolation=cv2.INTER_CUBIC)
                    bic_cr_image = cv2.resize(lr_cr_image,
                                              (int(lr_cr_image.shape[1] * scale),
                                               int(lr_cr_image.shape[0] * scale)),
                                              interpolation=cv2.INTER_CUBIC)

                    sr = self.session.run(self.output_names, {self.input_name: lr_y_image})
                    end_time = time.time()
                    total_times += end_time - start_time

                    if isinstance(sr, (list, tuple)):
                        sr = from_numpy(sr[0]) if len(sr) == 1 else [from_numpy(x) for x in sr]
                    else:
                        sr = from_numpy(sr)

                    ndarr = imgproc.array_to_image(sr)
                    sr_y_image = ndarr.astype(np.float32) / 255.0
                    sr_ycbcr_image = cv2.merge([sr_y_image[:, :, 0], bic_cb_image, bic_cr_image])
                    sr_image = imgproc.ycbcr_to_bgr(sr_ycbcr_image)
                    sr_image = np.clip(sr_image * 255.0, 0, 255).astype(np.uint8)
                    vidwri.write(sr_image)

            print('Total time: {:.3f} seconds for {} frames'.format(total_times, total_frames))
            print('Average time: {:.3f} seconds for each frame'.format(total_times / total_frames))

            vidcap.release()
            vidwri.release()

        torch.set_grad_enabled(True)


def main():
    args = parser.parse_args()
    t = VideoTester(args.scale, args.model, args.dir_demo, rgb_range=args.rgb_range, arch='EDSR')
    t.test()


if __name__ == '__main__':
    main()
python/run_onnx.py ADDED
@@ -0,0 +1,129 @@
import os
import cv2
import time
import torch
import argparse
import numpy as np
from tqdm import tqdm

import common
import imgproc
import onnxruntime as ort

torch.manual_seed(1)

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="edsr_baseline_x2_1.onnx", help="onnx model path")
parser.add_argument('--scale', nargs='+', type=int, default=[2], help='super resolution scale')
parser.add_argument("--dir_demo", type=str, default='../video/test_1920x1080.mp4', help="demo video path")
parser.add_argument('--rgb_range', type=int, default=255, help='maximum value of RGB')


def quantize(img, rgb_range):
    pixel_range = 255 / rgb_range
    return np.round(np.clip(img * pixel_range, 0, 255)) / pixel_range


def from_numpy(x):
    return x if isinstance(x, np.ndarray) else np.array(x)


class VideoTester():
    def __init__(self, scale, my_model, dir_demo, rgb_range=255, cuda=True, arch='EDSR'):
        self.scale = scale
        self.rgb_range = rgb_range
        self.providers = ['CUDAExecutionProvider'] if cuda else ['CPUExecutionProvider']
        self.session = ort.InferenceSession(my_model, providers=self.providers)
        self.output_names = [x.name for x in self.session.get_outputs()]
        self.input_name = self.session.get_inputs()[0].name
        self.dir_demo = dir_demo
        self.filename, _ = os.path.splitext(os.path.basename(dir_demo))
        self.arch = arch

    def test(self):
        torch.set_grad_enabled(False)
        if not os.path.exists('experiment'):
            os.makedirs('experiment')
        for idx_scale, scale in enumerate(self.scale):
            vidcap = cv2.VideoCapture(self.dir_demo)
            total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))

            vidwri = cv2.VideoWriter(
                os.path.join('experiment', ('{}_x{}.avi'.format(self.filename, scale))),
                cv2.VideoWriter_fourcc(*'XVID'),
                vidcap.get(cv2.CAP_PROP_FPS),
                (
                    int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                    int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                )
            )

            total_times = 0
            tqdm_test = tqdm(range(total_frames), ncols=80)

            if self.arch == 'EDSR':
                for _ in tqdm_test:
                    success, lr = vidcap.read()
                    if not success: break
                    start_time = time.time()
                    # EDSR consumes the full 3-channel frame as an NCHW float32 batch.
                    lr_y_image, = common.set_channel(lr, n_channels=3)
                    lr_y_image, = common.np_prepare(lr_y_image, rgb_range=self.rgb_range)

                    sr = self.session.run(self.output_names, {self.input_name: lr_y_image})
                    end_time = time.time()
                    total_times += end_time - start_time

                    if isinstance(sr, (list, tuple)):
                        sr = from_numpy(sr[0]) if len(sr) == 1 else [from_numpy(x) for x in sr]
                    else:
                        sr = from_numpy(sr)

                    sr = quantize(sr, self.rgb_range).squeeze(0)
                    normalized = sr * 255 / self.rgb_range
                    ndarr = normalized.transpose(1, 2, 0).astype(np.uint8)
                    vidwri.write(ndarr)

            elif self.arch == 'ESPCN':
                for _ in tqdm_test:
                    success, lr = vidcap.read()
                    if not success: break
                    start_time = time.time()

                    # ESPCN runs on the Y channel only; Cb/Cr are upscaled with bicubic.
                    lr_y_image, lr_cb_image, lr_cr_image = imgproc.preprocess_one_frame(lr)
                    bic_cb_image = cv2.resize(lr_cb_image,
                                              (int(lr_cb_image.shape[1] * scale),
                                               int(lr_cb_image.shape[0] * scale)),
                                              interpolation=cv2.INTER_CUBIC)
                    bic_cr_image = cv2.resize(lr_cr_image,
                                              (int(lr_cr_image.shape[1] * scale),
                                               int(lr_cr_image.shape[0] * scale)),
                                              interpolation=cv2.INTER_CUBIC)

                    sr = self.session.run(self.output_names, {self.input_name: lr_y_image})
                    end_time = time.time()
                    total_times += end_time - start_time

                    if isinstance(sr, (list, tuple)):
                        sr = from_numpy(sr[0]) if len(sr) == 1 else [from_numpy(x) for x in sr]
                    else:
                        sr = from_numpy(sr)

                    ndarr = imgproc.array_to_image(sr)
                    sr_y_image = ndarr.astype(np.float32) / 255.0
                    sr_ycbcr_image = cv2.merge([sr_y_image[:, :, 0], bic_cb_image, bic_cr_image])
                    sr_image = imgproc.ycbcr_to_bgr(sr_ycbcr_image)
                    sr_image = np.clip(sr_image * 255.0, 0, 255).astype(np.uint8)
                    vidwri.write(sr_image)

            print('Total time: {:.3f} seconds for {} frames'.format(total_times, total_frames))
            print('Average time: {:.3f} seconds for each frame'.format(total_times / total_frames))

            vidcap.release()
            vidwri.release()

        torch.set_grad_enabled(True)


def main():
    args = parser.parse_args()
    t = VideoTester(args.scale, args.model, args.dir_demo, rgb_range=args.rgb_range, arch='EDSR')
    t.test()


if __name__ == '__main__':
    main()
video/1.png ADDED

Git LFS Details

  • SHA256: cf83de909a57a46ff437b6e393f498232dfbf06185ab3686bd5f075a4670b827
  • Pointer size: 132 Bytes
  • Size of remote file: 1.63 MB
video/2.png ADDED

Git LFS Details

  • SHA256: 8e7ef61a91df68b100329aebe347ee83f689b2574d89e194fa534f078baf0d73
  • Pointer size: 132 Bytes
  • Size of remote file: 1.6 MB
video/test_1920x1080.mp4 ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e1033173661f71a07bf453a05ea5c9cdffbdedb68f990d0148c194bf5d3955b9
size 5930471