# DeepLabv3Plus / infer.py
# lihongjie
# first commit
# fafd9e7
import argparse
import cv2
import os
import torch
from PIL import Image
import numpy as np
import axengine as axe
from datasets import VOCSegmentation, Cityscapes, cityscapes
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    Both options are mandatory string flags.

    Returns:
        A namespace whose ``img`` and ``model`` attributes hold the values
        passed on the command line.
    """
    cli = argparse.ArgumentParser()
    # (flag, help text) pairs, registered in this order so --help output is stable.
    option_specs = (
        ("--img", "Path to input image."),
        ("--model", "Path to axmodel model."),
    )
    for flag, description in option_specs:
        cli.add_argument(flag, type=str, required=True, help=description)
    namespace = cli.parse_args()
    return namespace
def infer(img: str, model: str, viz: bool = False, output: str = "output-ax.png"):
    """Segment one image with an AXera ``.axmodel`` and save a colorized mask.

    The image is read with OpenCV, converted BGR -> RGB, resized to 513x513,
    and given a leading batch axis. It is then fed to the model under the
    input name ``"input"``. The per-pixel argmax over axis 1 of the first
    output is colorized with ``VOCSegmentation.decode_target``.

    Args:
        img: Path to the input image.
        model: Path to the compiled ``.axmodel`` file.
        viz: Accepted for interface compatibility; currently unused.
        output: Where to write the colorized mask (default ``"output-ax.png"``).

    Returns:
        The colorized mask as a ``PIL.Image.Image`` at the 513x513 network
        resolution (it is not resized back to the input image's size).

    Raises:
        FileNotFoundError: If ``img`` cannot be read as an image.
    """
    img_raw = cv2.imread(img)
    if img_raw is None:
        # cv2.imread signals failure by returning None instead of raising;
        # without this check the error would surface later inside cvtColor.
        raise FileNotFoundError(f"Could not read image: {img}")
    image = cv2.cvtColor(img_raw, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (513, 513))
    # Add a leading batch axis: (H, W, C) -> (1, H, W, C). No normalization
    # is applied here; presumably the compiled model handles it (TODO confirm).
    image = image[None]

    session = axe.InferenceSession(model)
    pred = session.run(None, {"input": image})[0]
    # Argmax over axis 1 (assumed to be the class axis of (N, C, H, W) logits —
    # confirm against the model), then take batch element 0 -> (H, W).
    # np.argmax returns the first maximum on ties, matching torch.max(dim)[1].
    pred = np.argmax(pred, axis=1)[0]

    colorized_preds = VOCSegmentation.decode_target(pred).astype("uint8")
    colorized = Image.fromarray(colorized_preds)
    colorized.save(output)
    return colorized
if __name__ == "__main__":
    # Forward every parsed CLI option (img, model) to infer() as keyword arguments.
    infer(**vars(parse_args()))