| | import os
|
| | import numpy as np
|
| | import cv2
|
| | from tqdm import tqdm
|
| | import argparse
|
| | from mmseg.apis import init_model, inference_model
|
| |
|
| |
|
def process_single_img(img_path, model, outpath, palette_dict):
    """Run segmentation on one image and save a colour-coded mask.

    Reads ``img_path`` with OpenCV and runs ``inference_model`` on it. Every
    pixel whose predicted class index is a key of ``palette_dict`` is painted
    with that key's colour. The result is written into ``outpath`` under the
    input file's basename.

    Args:
        img_path: Path of the input image.
        model: Segmentation model as returned by ``mmseg.apis.init_model``.
        outpath: Existing directory where the colour mask is written.
        palette_dict: Mapping ``class index -> [B, G, R]`` colour (0-255
            channel values). Pixels whose class is not a key stay black.

    Returns:
        True if the mask was written. False if the input could not be decoded
        as an image; in that case a message is printed and the file is skipped.

    Raises:
        OSError: If OpenCV reports that writing the output file failed.
    """
    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        # cv2.imread signals unreadable / non-image files by returning None
        # instead of raising; skip them rather than crash inside inference.
        # tqdm.write keeps any active progress bar intact.
        tqdm.write(f"Skipping {img_path}: not a readable image")
        return False

    result = inference_model(model, img_bgr)
    # Presumably pred_sem_seg.data is (1, H, W), so [0] is the H x W label
    # map -- TODO confirm against the mmseg version in use.
    pred_mask = result.pred_sem_seg.data[0].cpu().numpy()

    # Allocate as uint8 directly: palette entries are 0-255 colour channels,
    # so no float intermediate and cast are needed.
    pred_mask_bgr = np.zeros((*pred_mask.shape[:2], 3), dtype=np.uint8)
    for idx, color in palette_dict.items():
        pred_mask_bgr[pred_mask == idx] = color

    # NOTE(review): the mask keeps the input's filename and extension. If the
    # inputs are JPEGs, lossy compression will smear the palette colours;
    # consider writing PNG instead.
    save_path = os.path.join(outpath, os.path.basename(img_path))
    if not cv2.imwrite(save_path, pred_mask_bgr):
        raise OSError(f"Failed to write segmentation mask to {save_path}")
    return True
|
| |
|
| |
|
| |
|
def main(args):
    """Segment every image file in ``args.data_folder`` and save colour masks.

    Args:
        args: Parsed CLI namespace providing ``config_file``,
            ``checkpoint_file``, ``device``, ``data_folder`` and ``outpath``.

    Raises:
        ValueError: If ``args.outpath`` is missing or empty.
    """
    # Validate before loading the model, so a missing output directory fails
    # fast with a clear message instead of a TypeError from os.path.
    if not args.outpath:
        raise ValueError("An output directory must be given with -o/--outpath.")

    model = init_model(args.config_file, args.checkpoint_file, device=args.device)

    # (name, colour) pairs; the list position is the class index. Colours are
    # BGR because masks are written with cv2.imwrite, so [0, 0, 255] is red.
    palette = [
        ['background', [0, 0, 0]],
        ['red', [0, 0, 255]],
    ]
    palette_dict = {idx: color for idx, (_, color) in enumerate(palette)}

    # makedirs creates missing parent directories and tolerates an existing
    # directory, avoiding the race between exists() and mkdir().
    os.makedirs(args.outpath, exist_ok=True)

    # Skip subdirectories and other non-files. Sort for a deterministic
    # processing order.
    img_names = sorted(
        name for name in os.listdir(args.data_folder)
        if os.path.isfile(os.path.join(args.data_folder, name))
    )
    for img_name in tqdm(img_names):
        img_path = os.path.join(args.data_folder, img_name)
        process_single_img(img_path, model, args.outpath, palette_dict)
|
| |
|
| |
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Process images for semantic segmentation inference.")
    parser.add_argument('-d', '--data_folder', type=str, required=True, help="Path to the folder containing input images.")
    parser.add_argument('-m', '--config_file', type=str, required=True, help="Path to the model config file.")
    parser.add_argument('-pth', '--checkpoint_file', type=str, required=True, help="Path to the model checkpoint file.")
    # Required: there is no default output directory, and main() cannot run
    # without somewhere to write the masks.
    parser.add_argument('-o', '--outpath', type=str, required=True, help="Path to save the output images.")
    parser.add_argument('--device', type=str, default='cuda:0', help="Device to run the model (e.g., 'cuda:0', 'cpu').")

    args = parser.parse_args()
    main(args)
|
| | |