#!/usr/bin/env python3
"""
Validate that ONNXRuntime (not onnx2torch) gives batch-independent results
on the fixed model.

Runs:
  A : ORT B=1
  C : ORT B=5 (same image x5)

Checks per head:
  - Output shape has batch dim outermost: [5, anchors, K]
  - A vs C[i] max diff < 1e-4 for all 5 frames
  - Cross-frame Ci vs Cj max diff < 1e-4 for all pairs

The process exits with status 0 when every head passes and 1 otherwise.
"""

import argparse
import itertools
import sys
from typing import List, NamedTuple

import numpy as np

DET_SIZE = (640, 640)
THRESHOLD = 1e-4
BATCH = 5  # number of identical frames stacked into the batched run


class HeadResult(NamedTuple):
    """Outcome of comparing one output head between the B=1 and batched runs."""

    shape_ok: bool             # batched output is 3-D with the batch dim first
    frame_diffs: List[float]   # max |A - C[i]| per frame (NaN if not comparable)
    max_pairwise: float        # max |C[i] - C[j]| over all frame pairs
    passed: bool               # shape_ok and every diff below the threshold


def _max_abs_diff(x, y):
    """Return max |x - y| as a Python float (0.0 for empty arrays)."""
    # initial=0.0 keeps empty arrays from raising; |.| >= 0 so results for
    # non-empty inputs are unaffected.
    return float(np.abs(x - y).max(initial=0.0))


def compare_head(single, batched, batch=BATCH, threshold=THRESHOLD):
    """Compare one head's B=1 output against its batched output.

    Args:
        single: Output array from the B=1 run.
        batched: Output array from the run with ``batch`` copies of the input.
        batch: Expected size of the leading (batch) dimension of ``batched``.
        threshold: Exclusive upper bound on every max-abs difference.

    Returns:
        A ``HeadResult``. When the batched output cannot be split into
        ``batch`` frames that each match ``single`` in element count, the
        differences are reported as NaN and the head fails, rather than
        raising an indexing/broadcasting error.
    """
    single = np.asarray(single)
    batched = np.asarray(batched)

    # Shape check: must be 3-D with leading dim == batch.
    shape_ok = batched.ndim == 3 and batched.shape[0] == batch

    # Diffs are only meaningful if indexing the first axis yields `batch`
    # frames, each with as many elements as the single-frame output.
    comparable = (
        batched.ndim >= 1
        and batched.shape[0] == batch
        and batched.size == batch * single.size
    )
    if not comparable:
        nan = float('nan')
        return HeadResult(shape_ok, [nan] * batch, nan, False)

    reference = single.reshape(-1)
    frames = [batched[i].reshape(-1) for i in range(batch)]

    frame_diffs = [_max_abs_diff(reference, frame) for frame in frames]
    pair_diffs = [
        _max_abs_diff(x, y) for x, y in itertools.combinations(frames, 2)
    ]
    # np.max propagates NaN, so NaN outputs show up instead of reading as 0.
    max_pairwise = float(np.max(pair_diffs, initial=0.0))

    passed = bool(
        shape_ok
        and all(d < threshold for d in frame_diffs)
        and max_pairwise < threshold
    )
    return HeadResult(shape_ok, frame_diffs, max_pairwise, passed)


def format_row(idx, result):
    """Render one table row for head ``idx`` from its ``HeadResult``."""
    # Report the real shape status; a literal "OK" here would hide failures.
    shape_label = 'OK' if result.shape_ok else 'BAD'
    row = (f'out[{idx}] {shape_label:>10} '
           + ' '.join(f'{d:>10.6f}' for d in result.frame_diffs)
           + f' {result.max_pairwise:>12.6f}')
    if not result.passed:
        row += ' FAIL'
    return row


def load_blob(image_path):
    """Read ``image_path`` and return the normalised network input blob.

    Raises:
        SystemExit: if OpenCV cannot read the image.
    """
    # Imported here so the NumPy-only comparison helpers remain importable
    # on machines without OpenCV installed.
    import cv2

    img = cv2.imread(image_path)
    if img is None:
        # Explicit check instead of `assert`, which is stripped under -O.
        raise SystemExit(f'Could not read: {image_path}')

    return cv2.dnn.blobFromImage(
        cv2.resize(img, DET_SIZE),
        1.0 / 128.0,
        DET_SIZE,
        (127.5, 127.5, 127.5),
        swapRB=True,
    )  # (1, 3, 640, 640)


def main(argv=None):
    """Run the B=1 vs batched comparison and print the report.

    Args:
        argv: Argument list to parse; defaults to ``sys.argv[1:]``.

    Returns:
        0 if every head passes, 1 otherwise.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', required=True)
    parser.add_argument('--image', required=True)
    args = parser.parse_args(argv)

    # Imported here for the same reason as cv2 in load_blob().
    import onnxruntime as ort

    sess = ort.InferenceSession(args.model, providers=['CPUExecutionProvider'])
    inp_name = sess.get_inputs()[0].name

    blob = load_blob(args.image)

    out_b1 = sess.run(None, {inp_name: blob})
    out_bn = sess.run(None, {inp_name: np.repeat(blob, BATCH, axis=0)})

    print(f'ORT B={BATCH} output shapes:')
    for i, o in enumerate(out_bn):
        print(f' out[{i}]: {list(o.shape)}')

    print(f'\n{"head":<6} {"shape_ok":>10} '
          + ' '.join(f'{"AvsC" + str(i + 1):>10}' for i in range(BATCH))
          + f' {"cross_frame":>12}')
    print('-' * 90)

    all_pass = True
    for idx, (a, c) in enumerate(zip(out_b1, out_bn)):
        result = compare_head(a, c, BATCH, THRESHOLD)
        all_pass = all_pass and result.passed
        print(format_row(idx, result))

    print('\n' + '=' * 50)
    print(f'OVERALL: {"PASS" if all_pass else "FAIL"} (threshold {THRESHOLD})')
    print('=' * 50)

    return 0 if all_pass else 1


if __name__ == '__main__':
    sys.exit(main())