File size: 1,212 Bytes
fafd9e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import argparse
import cv2
import os
import torch
from PIL import Image
import numpy as np
import axengine as axe
from datasets import VOCSegmentation, Cityscapes, cityscapes
def parse_args() -> argparse.Namespace:
    """Parse the command line for a single-image inference run.

    Both options are mandatory string paths; argparse exits with an error
    message if either one is missing.

    Returns:
        Namespace carrying ``img`` (input image path) and ``model``
        (axmodel path).
    """
    cli = argparse.ArgumentParser()
    # (flag, help text) pairs; every option here is a required string.
    option_specs = (
        ("--img", "Path to input image."),
        ("--model", "Path to axmodel model."),
    )
    for flag, description in option_specs:
        cli.add_argument(flag, type=str, required=True, help=description)
    return cli.parse_args()
def infer(
    img: str,
    model: str,
    viz: bool = False,
    input_size: tuple = (513, 513),
    output_path: str = "output-ax.png",
):
    """Run an axmodel segmentation model on one image and save a colorized mask.

    The image is read with OpenCV and converted from BGR to RGB. It is resized
    to ``input_size`` and given a leading batch axis, then fed to the model
    under the input name ``"input"``. The first model output is reduced with
    an argmax over axis 1. The resulting per-pixel labels are colorized with
    ``VOCSegmentation.decode_target`` and written to ``output_path``.

    Args:
        img: Path to the input image.
        model: Path to the axmodel file.
        viz: Currently unused; kept so existing callers keep working.
        input_size: ``(width, height)`` passed to ``cv2.resize``. Defaults to
            the previously hard-coded ``(513, 513)``.
        output_path: Where the colorized mask is saved. Defaults to the
            previously hard-coded ``"output-ax.png"``.

    Returns:
        The colorized mask as a ``PIL.Image.Image`` (also saved to disk).

    Raises:
        FileNotFoundError: If ``img`` cannot be read by ``cv2.imread``.
    """
    img_raw = cv2.imread(img)
    if img_raw is None:
        # cv2.imread reports failure by returning None rather than raising;
        # without this check cvtColor fails with an opaque assertion.
        raise FileNotFoundError(f"Could not read image: {img}")
    image = cv2.cvtColor(img_raw, cv2.COLOR_BGR2RGB)
    # cv2.resize expects dsize as (width, height).
    image = cv2.resize(image, tuple(input_size))
    image = image[None]  # add leading batch axis -> (1, H, W, 3)

    session = axe.InferenceSession(model)
    pred = session.run(None, {"input": image})[0]
    # NOTE(review): assumes the output is (1, num_classes, H, W) scores, so
    # the argmax over axis 1 picks the class index per pixel — confirm.
    pred = torch.from_numpy(pred)
    mask = pred.max(1)[1].cpu().numpy()[0]  # (H, W) class indices

    colorized = VOCSegmentation.decode_target(mask).astype("uint8")
    result = Image.fromarray(colorized)
    result.save(output_path)
    return result
if __name__ == "__main__":
    # Script entry point: forward every parsed CLI option to infer() by name.
    infer(**vars(parse_args()))
|