| | import argparse |
| | import sys |
| | import cv2 |
| | import numpy as np |
| | import time |
| |
|
def preprocess_image(im: np.ndarray, model_input_size: list) -> np.ndarray:
    """Bilinearly resize an image and return it as a batched NCHW uint8 array.

    Args:
        im: Image laid out as (H, W) or (H, W, C).
        model_input_size: Two-element ``[height, width]`` of the output grid.

    Returns:
        Array of shape (1, C, height, width) and dtype uint8. Interpolated
        values are truncated (not rounded) by the final cast.
    """
    # A 2-D input has no channel axis yet; give it a trailing one.
    if im.ndim < 3:
        im = im[:, :, np.newaxis]

    # HWC -> CHW in float32, then prepend the batch axis.
    batch = im.transpose(2, 0, 1).astype(np.float32)[np.newaxis, ...]
    _, _, src_h, src_w = batch.shape
    dst_h, dst_w = model_input_size

    # Sampling positions run from 0 to size - 1 inclusive, so the output
    # corners land exactly on the input corners.
    grid_x, grid_y = np.meshgrid(
        np.linspace(0, src_w - 1, dst_w),
        np.linspace(0, src_h - 1, dst_h),
    )

    # Integer neighbours around every sampling position, clamped at the
    # right/bottom border.
    left = np.floor(grid_x).astype(np.int32)
    right = np.minimum(left + 1, src_w - 1)
    upper = np.floor(grid_y).astype(np.int32)
    lower = np.minimum(upper + 1, src_h - 1)

    # Fractional offsets weight the far neighbour; their complements weight
    # the near one.
    frac_x = grid_x - left
    frac_y = grid_y - upper
    near_x = 1 - frac_x
    near_y = 1 - frac_y

    # Blend all channels at once: indexing (C, H, W) with two 2-D index
    # arrays yields (C, dst_h, dst_w).
    planes = batch[0]
    top_row = near_x * planes[:, upper, left] + frac_x * planes[:, upper, right]
    bottom_row = near_x * planes[:, lower, left] + frac_x * planes[:, lower, right]
    blended = (near_y * top_row + frac_y * bottom_row).astype(np.float32)

    # No value scaling is applied; the pixels are only cast back to uint8.
    return blended[np.newaxis, ...].astype(np.uint8)
| |
|
def postprocess_image(result: np.ndarray, im_size: list) -> np.ndarray:
    """Resize a model output to ``im_size`` and min-max scale it to uint8.

    Args:
        result: Model output with a leading batch axis of length 1, i.e.
            shape (1, C, H, W).
        im_size: Two-element ``(height, width)`` of the desired output.

    Returns:
        uint8 array of shape (height, width) when C == 1, otherwise
        (height, width, C). Values are min-max normalised over the whole
        resized array to 0..255 and truncated by the cast.

    Raises:
        ValueError: If the leading axis of ``result`` is not of length 1
            (raised by ``np.squeeze``).
    """
    result_np = np.squeeze(result, axis=0)
    _, H_ori, W_ori = result_np.shape
    H_target, W_target = im_size

    # Sampling positions run from 0 to size - 1 inclusive, so the output
    # corners land exactly on the input corners.
    x_target = np.linspace(0, W_ori - 1, W_target)
    y_target = np.linspace(0, H_ori - 1, H_target)
    xx_target, yy_target = np.meshgrid(x_target, y_target)

    # Integer neighbours of each sampling position, clamped at the border.
    x0 = np.floor(xx_target).astype(np.int32)
    x1 = np.minimum(x0 + 1, W_ori - 1)
    y0 = np.floor(yy_target).astype(np.int32)
    y1 = np.minimum(y0 + 1, H_ori - 1)

    # wx0 / wy0 weight the far (x1 / y1) neighbour; wx1 / wy1 the near one.
    wx0 = xx_target - x0
    wx1 = 1 - wx0
    wy0 = yy_target - y0
    wy1 = 1 - wy0

    # Blend every channel in one shot: indexing (C, H, W) with two 2-D index
    # arrays yields (C, H_target, W_target).
    top = wx1 * result_np[:, y0, x0] + wx0 * result_np[:, y0, x1]
    bottom = wx1 * result_np[:, y1, x0] + wx0 * result_np[:, y1, x1]
    result_interp = (wy1 * top + wy0 * bottom).astype(np.float32)

    ma = np.max(result_interp)
    mi = np.min(result_interp)

    # The epsilon keeps a constant prediction (ma == mi) from dividing by zero.
    result_norm = (result_interp - mi) / (ma - mi + 1e-8)
    im_array = np.transpose(result_norm * 255, (1, 2, 0)).astype(np.uint8)

    # Drop only the channel axis. A blanket np.squeeze would also remove a
    # height or width of 1 and hand back an array of the wrong rank.
    if im_array.shape[-1] == 1:
        im_array = im_array[:, :, 0]
    return im_array
| |
|
def inference(img_path,
              model_path,
              save_path):
    """Run background removal on one image and save it with the mask as alpha.

    The image is resized to a fixed 1024 x 1024 model input, run through the
    ``.axmodel`` session, and the predicted mask is resized back to the image
    size and attached as the alpha channel of the BGR pixels.

    Args:
        img_path: Path of the image to process.
        model_path: Path of the compiled ``.axmodel`` file.
        save_path: Output path. For ``.jpg`` / ``.jpeg`` the alpha channel is
            discarded before writing.

    Raises:
        ValueError: If ``model_path`` is not an ``.axmodel`` file.
        RuntimeError: If the loaded model reports no inputs.
        FileNotFoundError: If the image cannot be read.
        OSError: If the result cannot be written to ``save_path``.
    """
    # Only the axengine runtime is supported. Fail early with a clear error
    # instead of hitting an unbound ``ort`` name further down.
    if not model_path.lower().endswith(".axmodel"):
        raise ValueError(f"不支持的模型格式:{model_path},仅支持.axmodel模型")
    import axengine as ort

    session = ort.InferenceSession(model_path)
    input_name = None
    for inp_meta in session.get_inputs():
        input_shape = inp_meta.shape[2:]
        input_name = inp_meta.name
        print(f"输入名称:{input_name},输入尺寸:{input_shape}")
    if input_name is None:
        raise RuntimeError(f"模型没有任何输入节点:{model_path}")

    # Fixed resolution fed to the model; the shape reported by the session
    # above is only printed, not used.
    model_input_size = [1024, 1024]
    orig_im_bgr = cv2.imread(img_path)
    if orig_im_bgr is None:
        raise FileNotFoundError(f"无法读取图片文件:{img_path},请检查路径是否正确或图片是否损坏")
    orig_im = cv2.cvtColor(orig_im_bgr, cv2.COLOR_BGR2RGB)
    orig_im_size = orig_im.shape[0:2]
    image = preprocess_image(orig_im, model_input_size)

    # perf_counter is a monotonic clock meant for measuring intervals.
    t1 = time.perf_counter()
    result = session.run(None, {input_name: image})
    t2 = time.perf_counter()
    print(f"推理时间:{(t2-t1)*1000:.2f} ms")

    mask = postprocess_image(result[0], orig_im_size)

    # Reuse the already decoded 8-bit BGR pixels for the colour channels.
    # Re-reading with IMREAD_UNCHANGED skips EXIF orientation and keeps the
    # file's native depth and channel count, so rotated JPEGs, grayscale and
    # 16-bit images would no longer match the mask computed above.
    b, g, r = cv2.split(orig_im_bgr)
    no_bg_image = cv2.merge((b, g, r, mask))

    if save_path.lower().endswith(('.jpg', '.jpeg')):
        saved = cv2.imwrite(save_path, cv2.cvtColor(no_bg_image, cv2.COLOR_BGRA2BGR))
        message = f"JPG格式不支持透明通道,已丢弃Alpha通道,结果保存至:{save_path}"
    else:
        saved = cv2.imwrite(save_path, no_bg_image)
        message = f"推理完成,结果已保存至:{save_path}"

    # cv2.imwrite reports failure through its return value, not an exception.
    if not saved:
        raise OSError(f"保存结果失败:{save_path}")
    print(message)
| |
|
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the background-removal example.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``.

    Returns:
        Namespace with ``model``, ``img`` and ``save_path`` attributes.

    Raises:
        SystemExit: Raised by argparse when ``--model`` or ``--img`` is
            missing, or on ``-h`` / ``--help``.
    """
    parser = argparse.ArgumentParser(description="ax rmbg exsample")
    # Both paths are mandatory: without them the run would only fail later
    # with an unrelated error on a None value.
    parser.add_argument("--model", "-m", type=str, required=True, help="compiled.axmodel path")
    parser.add_argument("--img", "-i", type=str, required=True, help="img path")
    parser.add_argument("--save_path", type=str, default="./result.png", help="save result path (png)")

    return parser.parse_args(argv)
| |
|
if __name__ == "__main__":
    args = parse_args()

    # Echo the invocation and the parsed options before starting the run.
    banner = [
        f"Command: {' '.join(sys.argv)}",
        "Parameters:",
        f" --model: {args.model}",
        f" --img_path: {args.img}",
        f" --save_path: {args.save_path}",
    ]
    print("\n".join(banner))

    inference(args.img, args.model, args.save_path)