| """TwinLiteNet8 ONNX inference — for edge deployment / cross-platform. |
| |
| Runs entirely via ONNX Runtime (no PyTorch needed at deploy time). |
| Use CPUExecutionProvider for CPU, CUDAExecutionProvider for GPU, |
| TensorrtExecutionProvider for TensorRT-accelerated runs on Jetson. |
| |
| Usage: |
| python predict_onnx.py input.jpg --onnx twinlite8.onnx |
| python predict_onnx.py --dir frames/ --out out/ --onnx twinlite8.onnx --provider CUDAExecutionProvider |
| """ |
| import argparse |
| from pathlib import Path |
| import cv2 |
| import numpy as np |
| import onnxruntime as ort |
|
|
# Class names, indexed by the class id stored in the predicted mask.
NAMES = [
    "tree",
    "ground",
    "person",
    "sky",
    "road",
    "mountain",
    "building",
    "background",
]

# One colour triple per class id (row i colours class i).  The overlay is
# blended directly onto images loaded with cv2.imread, so OpenCV reads each
# triple in B, G, R order.
PALETTE = np.asarray(
    [
        [60, 220, 60],      # tree
        [40, 100, 160],     # ground
        [40, 40, 230],      # person
        [230, 200, 60],     # sky
        [140, 140, 140],    # road
        [180, 60, 180],     # mountain
        [50, 220, 220],     # building
        [100, 100, 100],    # background
    ],
    dtype=np.uint8,
)

# Width and height every image is resized to before being fed to the model.
TRAIN_W = 640
TRAIN_H = 360
|
|
|
|
def predict(sess, bgr_img):
    """Segment a single BGR image and return its per-pixel class-id mask.

    The image is resized to ``TRAIN_W`` x ``TRAIN_H``, converted to RGB,
    divided by 255 and passed to ``sess.run`` as a float32 array with a
    leading batch axis and channels first, under the input name ``"input"``.
    The first output is treated as per-class scores with the class axis at
    position 1; channel 7 is overwritten with a large negative value before
    the argmax, so class id 7 is never selected.

    Args:
        sess: Object exposing ``run(output_names, feed_dict)``; presumably an
            ``onnxruntime.InferenceSession`` (that is what ``main`` passes).
        bgr_img: Image array whose first two axes are height and width.

    Returns:
        A ``uint8`` array of class ids with the same height and width as
        ``bgr_img`` (nearest-neighbour upscaled when the sizes differ).
    """
    orig_h, orig_w = bgr_img.shape[:2]

    # Pre-process: network-sized, RGB, float32, channels-first, batch of one.
    resized = cv2.resize(bgr_img, (TRAIN_W, TRAIN_H))
    rgb01 = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    tensor = np.transpose(rgb01, (2, 0, 1))[None].astype(np.float32)

    outputs = sess.run(None, {"input": tensor})
    scores = outputs[0]

    # Suppress channel 7 so the argmax can never pick it.
    scores[:, 7, :, :] = -1e9
    class_map = scores.argmax(1)[0].astype(np.uint8)

    if (orig_h, orig_w) == (TRAIN_H, TRAIN_W):
        return class_map
    # Nearest-neighbour keeps the values valid class ids when scaling back.
    return cv2.resize(class_map, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
|
|
|
|
def _select_providers(requested, available):
    """Return the execution-provider list to pass to ONNX Runtime."""
    if requested:
        return [requested]
    # Prefer the fastest backend that this onnxruntime build reports.
    for name in ("TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"):
        if name in available:
            return [name]
    # None of the preferred providers is present: use whatever is available
    # instead of leaving the provider list undefined.
    return list(available)


def _collect_image_paths(directory, single):
    """Return image files found in ``directory`` (sorted) followed by ``single``."""
    suffixes = {".jpg", ".jpeg", ".png", ".bmp"}
    paths = []
    if directory:
        paths = sorted(p for p in Path(directory).iterdir() if p.suffix.lower() in suffixes)
    if single:
        paths.append(Path(single))
    return paths


def main():
    """Command-line entry point: segment images and write masks and overlays.

    For every input image this writes ``<stem>_pred.png`` (the class-id mask)
    and ``<stem>_overlay.jpg`` (the image blended with the class colours) to
    the ``--out`` directory, and prints the most frequent class with its
    share of pixels.  Images that OpenCV cannot read are reported and skipped.
    With neither a positional input nor ``--dir`` the help text is printed
    and nothing else happens.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("input", nargs="?")
    ap.add_argument("--dir")
    ap.add_argument("--out", default=".")
    ap.add_argument("--onnx", default="twinlite8.onnx")
    ap.add_argument("--provider", default=None,
                    help="ONNX provider: CPUExecutionProvider | CUDAExecutionProvider | TensorrtExecutionProvider")
    args = ap.parse_args()

    if not args.input and not args.dir:
        ap.print_help()
        return
    # Fail fast with a usage error rather than a traceback from iterdir()
    # after the model has already been loaded.
    if args.dir and not Path(args.dir).is_dir():
        ap.error(f"--dir is not a directory: {args.dir}")

    available = ort.get_available_providers()
    providers = _select_providers(args.provider, available)
    print(f"available providers: {available}")
    print(f"using: {providers}")
    sess = ort.InferenceSession(args.onnx, providers=providers)
    # The session may end up on different providers than requested, so show
    # what it actually reports.
    print(f"actual provider: {sess.get_providers()}")

    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)

    for p in _collect_image_paths(args.dir, args.input):
        img = cv2.imread(str(p))
        if img is None:
            # cv2.imread signals an unreadable/missing file by returning None.
            print(f"  {p.name:<50} skipped: could not read image")
            continue
        mask = predict(sess, img)
        cv2.imwrite(str(out_dir / f"{p.stem}_pred.png"), mask)
        # PALETTE[mask] maps each class id to its colour, giving an image the
        # same size as ``img`` that can be alpha-blended with it.
        overlay = cv2.addWeighted(img, 0.55, PALETTE[mask], 0.45, 0)
        cv2.imwrite(str(out_dir / f"{p.stem}_overlay.jpg"), overlay)
        counts = np.bincount(mask.ravel(), minlength=len(NAMES))
        top = counts.argmax()
        print(f"  {p.name:<50} top: {NAMES[top]} ({100*counts[top]/counts.sum():.1f}%)")

    print(f"\noutputs -> {out_dir.resolve()}")
|
|
|
|
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|