"""TwinLiteNet8 ONNX inference — for edge deployment / cross-platform. Runs entirely via ONNX Runtime (no PyTorch needed at deploy time). Use CPUExecutionProvider for CPU, CUDAExecutionProvider for GPU, TensorRTExecutionProvider for TensorRT-accelerated runs on Jetson. Usage: python predict_onnx.py input.jpg --onnx twinlite8.onnx python predict_onnx.py --dir frames/ --out out/ --onnx twinlite8.onnx --provider CUDAExecutionProvider """ import argparse from pathlib import Path import cv2 import numpy as np import onnxruntime as ort NAMES = ["tree","ground","person","sky","road","mountain","building","background"] PALETTE = np.array([ [60,220,60],[40,100,160],[40,40,230],[230,200,60], [140,140,140],[180,60,180],[50,220,220],[100,100,100], ], dtype=np.uint8) TRAIN_W, TRAIN_H = 640, 360 def predict(sess, bgr_img): H, W = bgr_img.shape[:2] inp = cv2.resize(bgr_img, (TRAIN_W, TRAIN_H)) rgb = cv2.cvtColor(inp, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0 x = rgb.transpose(2, 0, 1)[None].astype(np.float32) # (1,3,H,W) logits = sess.run(None, {"input": x})[0] # (1,8,H,W) logits[:, 7, :, :] = -1e9 # v2: bg channel never trained pred_small = logits.argmax(1)[0].astype(np.uint8) # at training res if (H, W) != (TRAIN_H, TRAIN_W): return cv2.resize(pred_small, (W, H), interpolation=cv2.INTER_NEAREST) return pred_small def main(): ap = argparse.ArgumentParser() ap.add_argument("input", nargs="?") ap.add_argument("--dir") ap.add_argument("--out", default=".") ap.add_argument("--onnx", default="twinlite8.onnx") ap.add_argument("--provider", default=None, help="ONNX provider: CPUExecutionProvider | CUDAExecutionProvider | TensorrtExecutionProvider") args = ap.parse_args() if not args.input and not args.dir: ap.print_help(); return available = ort.get_available_providers() if args.provider: providers = [args.provider] else: # Auto-pick best for p in ["TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"]: if p in available: providers = [p]; break print(f"available providers: {available}") print(f"using: {providers}") sess = ort.InferenceSession(args.onnx, providers=providers) print(f"actual provider: {sess.get_providers()}") out_dir = Path(args.out); out_dir.mkdir(parents=True, exist_ok=True) paths = [] if args.dir: paths = sorted(p for p in Path(args.dir).iterdir() if p.suffix.lower() in {".jpg",".jpeg",".png",".bmp"}) if args.input: paths.append(Path(args.input)) for p in paths: img = cv2.imread(str(p)) if img is None: continue mask = predict(sess, img) cv2.imwrite(str(out_dir / f"{p.stem}_pred.png"), mask) overlay = cv2.addWeighted(img, 0.55, PALETTE[mask], 0.45, 0) cv2.imwrite(str(out_dir / f"{p.stem}_overlay.jpg"), overlay) counts = np.bincount(mask.flatten(), minlength=8) top = counts.argmax() print(f" {p.name:<50} top: {NAMES[top]} ({100*counts[top]/counts.sum():.1f}%)") print(f"\noutputs -> {out_dir.resolve()}") if __name__ == "__main__": main()