# TwinLiteNet8 / predict_onnx.py
# Uploaded by: WEN0256
# Initial release: TwinLiteNet8 (0.44M params, 7-class orchard semantic seg, edge-deployment ready)
# Commit: f5cc6c0 (verified)
"""TwinLiteNet8 ONNX inference — for edge deployment / cross-platform.
Runs entirely via ONNX Runtime (no PyTorch needed at deploy time).
Use CPUExecutionProvider for CPU, CUDAExecutionProvider for GPU,
TensorrtExecutionProvider for TensorRT-accelerated runs on Jetson.
If --provider is omitted, the first available of TensorRT, CUDA, CPU is used.
Usage:
python predict_onnx.py input.jpg --onnx twinlite8.onnx
python predict_onnx.py --dir frames/ --out out/ --onnx twinlite8.onnx --provider CUDAExecutionProvider
"""
import argparse
from pathlib import Path
import cv2
import numpy as np
import onnxruntime as ort
# Class labels, indexed by the model's output channel (the model emits one
# logit channel per entry; the final "background" channel is never trained).
NAMES = [
    "tree",
    "ground",
    "person",
    "sky",
    "road",
    "mountain",
    "building",
    "background",
]

# Overlay colour per class, one row per entry in NAMES. These rows are blended
# directly onto images loaded with cv2.imread, so each triple is in BGR order.
PALETTE = np.array(
    [
        (60, 220, 60),     # tree
        (40, 100, 160),    # ground
        (40, 40, 230),     # person
        (230, 200, 60),    # sky
        (140, 140, 140),   # road
        (180, 60, 180),    # mountain
        (50, 220, 220),    # building
        (100, 100, 100),   # background
    ],
    dtype=np.uint8,
)

# Fixed network input resolution; every frame is resized to this before inference.
TRAIN_W = 640
TRAIN_H = 360
def predict(sess, bgr_img):
    """Segment one image with the ONNX model and return a per-pixel class map.

    Args:
        sess: an ``onnxruntime.InferenceSession`` loaded with the TwinLiteNet8
            export.
        bgr_img: image in OpenCV's BGR channel order (e.g. from ``cv2.imread``);
            assumed HxWx3 uint8 — TODO confirm for other callers.

    Returns:
        uint8 array of shape ``bgr_img.shape[:2]`` holding class indices into
        ``NAMES``. The last ("background") channel is never predicted.
    """
    H, W = bgr_img.shape[:2]
    inp = cv2.resize(bgr_img, (TRAIN_W, TRAIN_H))
    rgb = cv2.cvtColor(inp, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    # HWC -> NCHW. transpose() returns a strided view, so hand ONNX Runtime an
    # explicitly C-contiguous float32 buffer.
    x = np.ascontiguousarray(rgb.transpose(2, 0, 1)[None], dtype=np.float32)

    # Prefer the conventional "input" name, but fall back to the model's first
    # declared input so re-exports with a different name still work.
    input_names = [i.name for i in sess.get_inputs()]
    input_name = "input" if "input" in input_names else input_names[0]
    logits = sess.run(None, {input_name: x})[0]  # (1, 8, TRAIN_H, TRAIN_W)

    # v2: the background channel (index 7) was never trained, so take the
    # argmax over the trained channels only instead of writing a -1e9 sentinel
    # into the network output.
    pred_small = logits[:, :7].argmax(1)[0].astype(np.uint8)  # at training res

    if (H, W) != (TRAIN_H, TRAIN_W):
        # Nearest-neighbour keeps the mask made of valid class ids only.
        return cv2.resize(pred_small, (W, H), interpolation=cv2.INTER_NEAREST)
    return pred_small
def _pick_providers(requested, available):
    """Return the execution-provider list to pass to ``ort.InferenceSession``.

    An explicit ``requested`` provider is used as given, with a warning when
    this onnxruntime build does not list it. Otherwise the first available of
    TensorRT -> CUDA -> CPU is chosen; CPU is the fallback so the result is
    never empty.
    """
    if requested:
        if requested not in available:
            print(f"warning: {requested} not in available providers {available}; "
                  "ONNX Runtime may fall back or fail")
        return [requested]
    for p in ("TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"):
        if p in available:
            return [p]
    return ["CPUExecutionProvider"]


def _collect_paths(dir_arg, input_arg):
    """Return image paths: sorted images found in ``dir_arg``, then ``input_arg``."""
    paths = []
    if dir_arg:
        paths = sorted(
            p for p in Path(dir_arg).iterdir()
            if p.suffix.lower() in {".jpg", ".jpeg", ".png", ".bmp"}
        )
    if input_arg:
        paths.append(Path(input_arg))
    return paths


def main():
    """CLI entry point: segment a single image and/or a directory of images.

    For every readable image, writes ``<stem>_pred.png`` (raw class-index mask)
    and ``<stem>_overlay.jpg`` (class palette blended onto the input) into
    ``--out``, then prints the most frequent class and its pixel share.
    Unreadable files are reported and skipped.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("input", nargs="?")
    ap.add_argument("--dir")
    ap.add_argument("--out", default=".")
    ap.add_argument("--onnx", default="twinlite8.onnx")
    ap.add_argument("--provider", default=None,
                    help="ONNX provider: CPUExecutionProvider | CUDAExecutionProvider | TensorrtExecutionProvider")
    args = ap.parse_args()
    if not args.input and not args.dir:
        ap.print_help()
        return

    available = ort.get_available_providers()
    providers = _pick_providers(args.provider, available)
    print(f"available providers: {available}")
    print(f"using: {providers}")
    sess = ort.InferenceSession(args.onnx, providers=providers)
    print(f"actual provider: {sess.get_providers()}")

    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)

    for p in _collect_paths(args.dir, args.input):
        img = cv2.imread(str(p))
        if img is None:
            # Best-effort batch run: report the failure instead of silently skipping.
            print(f"  skip (unreadable): {p}")
            continue
        mask = predict(sess, img)
        cv2.imwrite(str(out_dir / f"{p.stem}_pred.png"), mask)
        overlay = cv2.addWeighted(img, 0.55, PALETTE[mask], 0.45, 0)
        cv2.imwrite(str(out_dir / f"{p.stem}_overlay.jpg"), overlay)
        counts = np.bincount(mask.ravel(), minlength=len(NAMES))
        top = int(counts.argmax())
        print(f" {p.name:<50} top: {NAMES[top]} ({100*counts[top]/counts.sum():.1f}%)")
    print(f"\noutputs -> {out_dir.resolve()}")


if __name__ == "__main__":
    main()