File size: 5,419 Bytes
b3357d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import argparse
import sys
import cv2  
import numpy as np
import time

def preprocess_image(im: np.ndarray, model_input_size: list) -> np.ndarray:
    """Bilinearly resize an image to the model input size as an NCHW uint8 batch.

    Args:
        im: Input image, shape (H, W, C) or (H, W) for grayscale.
        model_input_size: Target size as [H_target, W_target].

    Returns:
        np.ndarray of shape (1, C, H_target, W_target), dtype uint8 —
        raw 0-255 pixel values (no mean/std normalization; the quantized
        model consumes uint8 directly).
    """
    # Promote grayscale (H, W) to (H, W, 1) so a channel axis always exists.
    if len(im.shape) < 3:
        im = im[:, :, np.newaxis]
    # HWC -> CHW, then prepend the batch axis.
    im_np = np.transpose(im, (2, 0, 1)).astype(np.float32)
    im_np = np.expand_dims(im_np, axis=0)

    _, C, H_ori, W_ori = im_np.shape
    H_target, W_target = model_input_size

    # Source-image sample coordinate for every target pixel.
    x_target = np.linspace(0, W_ori - 1, W_target)
    y_target = np.linspace(0, H_ori - 1, H_target)
    xx_target, yy_target = np.meshgrid(x_target, y_target)

    # Integer neighbours, clamped at the right/bottom border.
    x0 = np.floor(xx_target).astype(np.int32)
    x1 = np.minimum(x0 + 1, W_ori - 1)
    y0 = np.floor(yy_target).astype(np.int32)
    y1 = np.minimum(y0 + 1, H_ori - 1)

    # Fractional offsets are the bilinear weights.
    wx0 = xx_target - x0
    wx1 = 1 - wx0
    wy0 = yy_target - y0
    wy1 = 1 - wy0

    # Vectorized bilinear blend over all channels at once (the original
    # looped per channel; advanced indexing broadcasts over C directly).
    data = im_np[0]  # (C, H_ori, W_ori)
    top = wx1 * data[:, y0, x0] + wx0 * data[:, y0, x1]
    bottom = wx1 * data[:, y1, x0] + wx0 * data[:, y1, x1]
    # Round through float32 first to match the original's float32 buffer
    # semantics, then truncate to uint8. (Dropped the original no-op "/ 1.0".)
    im_interp = (wy1 * top + wy0 * bottom).astype(np.float32)[np.newaxis, ...]

    return im_interp.astype(np.uint8)

def postprocess_image(result: np.ndarray, im_size: list) -> np.ndarray:
    """Resize a (1, C, H, W) model output to ``im_size`` and scale to uint8.

    Bilinearly interpolates the raw output to the target (H, W), min-max
    normalizes it to the 0-255 range, and returns it as an HWC uint8 array
    (squeezed to (H, W) when C == 1) — suitable for use as an alpha mask.
    """
    arr = np.squeeze(result, axis=0)  # (C, H_src, W_src)
    num_ch, src_h, src_w = arr.shape
    dst_h, dst_w = im_size

    # Source-space coordinates of every destination pixel.
    cols = np.linspace(0, src_w - 1, dst_w)
    rows = np.linspace(0, src_h - 1, dst_h)
    grid_x, grid_y = np.meshgrid(cols, rows)

    # Neighbouring source pixels, clamped at the far edges.
    left = np.floor(grid_x).astype(np.int32)
    right = np.minimum(left + 1, src_w - 1)
    top = np.floor(grid_y).astype(np.int32)
    bottom = np.minimum(top + 1, src_h - 1)

    # Fractional distances toward the right/bottom neighbours.
    fx = grid_x - left
    fy = grid_y - top

    interp = np.zeros((num_ch, dst_h, dst_w), dtype=np.float32)
    for ch in range(num_ch):
        plane = arr[ch, :, :]
        upper = (1 - fx) * plane[top, left] + fx * plane[top, right]
        lower = (1 - fx) * plane[bottom, left] + fx * plane[bottom, right]
        interp[ch, :, :] = (1 - fy) * upper + fy * lower

    hi = np.max(interp)
    lo = np.min(interp)

    # Tiny epsilon guards against division by zero on a constant output.
    scaled = (interp - lo) / (hi - lo + 1e-8) * 255
    hwc = np.transpose(scaled, (1, 2, 0)).astype(np.uint8)
    return np.squeeze(hwc)

def inference(img_path,
              model_path,
              save_path):
    """Remove the background of an image with an .axmodel matting network.

    Runs the model on ``img_path``, uses the predicted matte as the alpha
    channel, and writes a BGRA result to ``save_path`` (alpha is dropped
    for JPEG outputs, which cannot carry transparency).

    Args:
        img_path: Path of the input image.
        model_path: Path of the compiled .axmodel file.
        save_path: Output path (PNG recommended to keep the alpha channel).

    Raises:
        ValueError: If the model is not an .axmodel, or the image has an
            unsupported channel count.
        FileNotFoundError: If the image cannot be read.
    """
    # Original code imported conditionally and left `ort` unbound otherwise,
    # which surfaced as a confusing NameError; fail fast with a clear message.
    if model_path.endswith(".axmodel"):
        import axengine as ort
    else:
        raise ValueError(f"Unsupported model format: {model_path!r}; only .axmodel is supported")

    session = ort.InferenceSession(model_path)
    input_name = None
    model_input_size = [1024, 1024]  # fallback when the model reports a dynamic shape
    for inp_meta in session.get_inputs():
        input_shape = inp_meta.shape[2:]
        input_name = inp_meta.name
        print(f"输入名称:{input_name},输入尺寸:{input_shape}")
        # Prefer the static H/W reported by the model over the hard-coded default.
        if len(input_shape) == 2 and all(isinstance(d, int) and d > 0 for d in input_shape):
            model_input_size = list(input_shape)

    orig_im_bgr = cv2.imread(img_path)
    if orig_im_bgr is None:
        raise FileNotFoundError(f"无法读取图片文件:{img_path},请检查路径是否正确或图片是否损坏")
    orig_im = cv2.cvtColor(orig_im_bgr, cv2.COLOR_BGR2RGB)  # to RGB (H, W, 3)
    orig_im_size = orig_im.shape[0:2]
    image = preprocess_image(orig_im, model_input_size)

    t1 = time.time()
    result = session.run(None, {input_name: image})
    t2 = time.time()
    print(f"推理时间:{(t2-t1)*1000:.2f} ms")

    # Single-channel matte (H, W), values 0-255 — becomes the alpha channel.
    result_image = postprocess_image(result[0], orig_im_size)
    orig_im_unchanged = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)  # keep all channels (BGR/BGRA)
    mask = result_image
    if orig_im_unchanged.ndim == 2:
        # Grayscale source: promote to BGR so the alpha merge below works
        # (previously shape[-1] was the width, yielding a nonsense error).
        orig_im_unchanged = cv2.cvtColor(orig_im_unchanged, cv2.COLOR_GRAY2BGR)
    if orig_im_unchanged.shape[-1] not in (3, 4):
        raise ValueError(f"不支持的图片通道数:{orig_im_unchanged.shape[-1]},仅支持3通道(BGR)或4通道(BGRA)")
    # Drop any pre-existing alpha and attach the predicted matte as BGRA alpha.
    b, g, r = cv2.split(orig_im_unchanged)[:3]
    no_bg_image = cv2.merge((b, g, r, mask))

    if save_path.lower().endswith(('.jpg', '.jpeg')):
        cv2.imwrite(save_path, cv2.cvtColor(no_bg_image, cv2.COLOR_BGRA2BGR))
        print(f"JPG格式不支持透明通道,已丢弃Alpha通道,结果保存至:{save_path}")
    else:
        cv2.imwrite(save_path, no_bg_image)
        print(f"推理完成,结果已保存至:{save_path}")

def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the background-removal demo.

    Returns:
        argparse.Namespace with ``model``, ``img`` and ``save_path``
        attributes. ``--model`` and ``--img`` are required; omitting them
        now fails at parse time instead of crashing later inside inference.
    """
    parser = argparse.ArgumentParser(description="ax rmbg example")
    parser.add_argument("--model", "-m", type=str, required=True, help="compiled.axmodel path")
    parser.add_argument("--img", "-i", type=str, required=True, help="img path")
    parser.add_argument("--save_path", type=str, default="./result.png",
                        help="save result path (png)")
    return parser.parse_args()

if __name__ == "__main__":
    # Echo the exact invocation and parsed parameters before running,
    # so the console log is self-describing.
    cli = parse_args()

    print(f"Command: {' '.join(sys.argv)}")
    print("Parameters:")
    for label, value in (("model", cli.model),
                         ("img_path", cli.img),
                         ("save_path", cli.save_path)):
        print(f"  --{label}: {value}")

    inference(cli.img, cli.model, cli.save_path)