"""Run an axmodel segmentation model on a single image and save a colorized mask."""

import argparse
import cv2
import os
import torch
from PIL import Image
import numpy as np
import axengine as axe
from datasets import VOCSegmentation, Cityscapes, cityscapes


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for single-image inference."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--img",
        type=str,
        required=True,
        help="Path to input image.",
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Path to axmodel model.",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="output-ax.png",
        help="Path where the colorized prediction is written.",
    )
    return parser.parse_args()


def infer(
    img: str,
    model: str,
    viz: bool = False,
    output: str = "output-ax.png",
    size: int = 513,
) -> None:
    """Segment one image with an axmodel and save the VOC-colorized prediction.

    Args:
        img: Path to the input image (read with OpenCV, BGR order).
        model: Path to the compiled axmodel file.
        viz: Currently unused; kept so existing callers keep working.
        output: Destination path for the colorized mask image.
        size: Side length the image is resized to before inference
            (the image is resized to ``size`` x ``size``).

    Raises:
        FileNotFoundError: If ``img`` cannot be read as an image.
    """
    img_raw = cv2.imread(img)
    if img_raw is None:
        # cv2.imread reports failure by returning None instead of raising;
        # without this check the error surfaces later as an obscure cv2.error.
        raise FileNotFoundError(f"Could not read image: {img}")

    image = cv2.cvtColor(img_raw, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (size, size))
    # Add a leading batch axis: (H, W, 3) -> (1, H, W, 3), still uint8.
    image = image[None]

    session = axe.InferenceSession(model)
    # NOTE(review): assumes the model's input tensor is named "input" — confirm
    # against the exported axmodel.
    pred = session.run(None, {"input": image})[0]

    # Index of the maximum along axis 1 (presumably the class axis of an
    # (N, C, H, W) logits tensor — confirm), then drop the batch axis -> (H, W).
    pred = torch.from_numpy(pred)
    pred = pred.max(1)[1].cpu().numpy()[0]

    colorized = VOCSegmentation.decode_target(pred).astype("uint8")
    Image.fromarray(colorized).save(output)


if __name__ == "__main__":
    args = parse_args()
    infer(**vars(args))