#!/usr/bin/env python3
"""
Clean batch-independence validator for the fixed det_500m model.

Runs:
  A  : ONNXRuntime, B=1
  B  : onnx2torch,  B=1
  Ci : onnx2torch,  B=5, frame i  (i=1..5)

Compares each of the 5 frames (C1..C5) against A and B for all 9 output heads.
If batch-independent, every Ci should equal both A and B (within float noise),
and all Ci should equal each other.

The OVERALL verdict covers both the per-head comparison and the cross-frame
check; a head whose outputs contain NaN or have an unexpected shape is
reported as FAIL rather than silently skipped.
"""

import argparse
import itertools

import cv2
import numpy as np
import torch
import onnx2torch
import onnxruntime as ort

DET_SIZE = (640, 640)
BATCH = 5          # number of identical frames stacked for run C
THRESHOLD = 1e-4   # max abs diff still considered "float noise"


def _max_abs_diff(x, y):
    """Return max |x - y| over two arrays compared as flat vectors.

    Returns ``inf`` when the element counts differ (instead of letting NumPy
    broadcast a meaningless comparison) and ``0.0`` when both are empty.
    NaN in either input propagates to the result, so a later
    ``result < THRESHOLD`` test evaluates to False.
    """
    x = np.asarray(x).reshape(-1)
    y = np.asarray(y).reshape(-1)
    if x.size != y.size:
        return float('inf')
    if x.size == 0:
        return 0.0
    return float(np.abs(x - y).max())


def _split_frames(batched, batch):
    """Split a batched output into ``batch`` flat per-frame arrays.

    Returns ``None`` when the leading dimension is not ``batch``; indexing
    such an array by frame would compare unrelated slices.
    """
    arr = np.asarray(batched)
    if arr.ndim == 0 or arr.shape[0] != batch:
        return None
    return [arr[i].reshape(-1) for i in range(batch)]


def _as_heads(outputs):
    """Return model outputs as a sequence of per-head tensors.

    A bare tensor is wrapped in a 1-tuple so that iterating yields heads
    rather than slices along the tensor's first axis.
    """
    if isinstance(outputs, torch.Tensor):
        return (outputs,)
    return outputs


parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True)
parser.add_argument('--image', required=True)
args = parser.parse_args()

# ── Load models ──────────────────────────────────────────────────────────────
print('Loading ONNXRuntime session ...')
sess = ort.InferenceSession(args.model, providers=['CPUExecutionProvider'])
inp_name = sess.get_inputs()[0].name

print('Loading onnx2torch model ...')
pt_model = onnx2torch.convert(args.model)
pt_model.eval()

# ── Preprocess image ─────────────────────────────────────────────────────────
img = cv2.imread(args.image)
if img is None:
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    parser.error(f'Could not read image: {args.image}')
print(f'Image: {args.image}  shape={img.shape}')

blob = cv2.dnn.blobFromImage(
    cv2.resize(img, DET_SIZE),
    1.0 / 128.0,
    DET_SIZE,
    (127.5, 127.5, 127.5),
    swapRB=True,
)
t = torch.from_numpy(blob)

# ── Run A, B, C(B=5) ─────────────────────────────────────────────────────────
print(f'\nRunning A (ORT B=1), B (torch B=1), C (torch B={BATCH})...')
ort_out = sess.run(None, {inp_name: blob})                    # A: list of np arrays
with torch.no_grad():
    pt_b1 = _as_heads(pt_model(t))                            # B: one tensor per head
    pt_b5 = _as_heads(pt_model(torch.cat([t] * BATCH, dim=0)))  # C: one tensor per head

print(f'\nB={BATCH} output shapes:')
for i, c in enumerate(pt_b5):
    print(f'  out[{i}]: {list(c.shape)}')

all_pass = True

# zip() below stops at the shortest sequence, so a head-count mismatch would
# otherwise drop heads from the comparison without any indication.
if not (len(ort_out) == len(pt_b1) == len(pt_b5)):
    print(f'\nWARNING: head count mismatch: A={len(ort_out)} '
          f'B={len(pt_b1)} C={len(pt_b5)}; comparing common heads only  FAIL')
    all_pass = False

# Convert and split the batched outputs once; both report sections reuse them.
b5_frames = [_split_frames(c.detach().numpy(), BATCH) for c in pt_b5]

# ── Per-head comparison: A, B, C1..C5 ────────────────────────────────────────
print('\n' + '=' * 100)
print('Per-head comparison (max abs diff)')
print('=' * 100)
header = (
    f'{"head":<6} {"AvsB":>10} '
    + ' '.join(f'{"AvsC"+str(i+1):>10}' for i in range(BATCH))
    + ' '
    + ' '.join(f'{"BvsC"+str(i+1):>10}' for i in range(BATCH))
)
print(header)
print('-' * len(header))

for idx, (a, b, frames) in enumerate(zip(ort_out, pt_b1, b5_frames)):
    a_np = np.asarray(a).reshape(-1)        # ORT output, flattened
    b_np = b.detach().numpy().reshape(-1)   # torch B=1 output, flattened

    avb = _max_abs_diff(a_np, b_np)
    if frames is None:
        # Leading dim of the batched output is not BATCH: no per-frame view.
        avc = [float('inf')] * BATCH
        bvc = [float('inf')] * BATCH
    else:
        avc = [_max_abs_diff(a_np, f) for f in frames]
        bvc = [_max_abs_diff(b_np, f) for f in frames]

    # `d < THRESHOLD` is False for NaN and inf, so both count as failures.
    head_pass = avb < THRESHOLD and all(d < THRESHOLD for d in avc + bvc)
    all_pass = all_pass and head_pass

    row = (
        f'out[{idx}] {avb:>10.6f} '
        + ' '.join(f'{d:>10.6f}' for d in avc)
        + ' '
        + ' '.join(f'{d:>10.6f}' for d in bvc)
    )
    if not head_pass:
        row += '  FAIL'
    print(row)

# ── Cross-frame check (Ci vs Cj) ─────────────────────────────────────────────
print('\n' + '=' * 60)
print('Cross-frame check: max diff between any pair Ci, Cj')
print('=' * 60)
print(f'{"head":<6} {"max_pairwise_diff":>20}')
print('-' * 30)
for idx, frames in enumerate(b5_frames):
    if frames is None:
        max_pairwise = float('inf')
    else:
        pair_diffs = [
            _max_abs_diff(fi, fj)
            for fi, fj in itertools.combinations(frames, 2)
        ]
        # np.max propagates NaN, unlike a running `if d > best` comparison,
        # which would silently skip a NaN difference.
        max_pairwise = float(np.max(pair_diffs)) if pair_diffs else 0.0

    cross_pass = max_pairwise < THRESHOLD
    # The cross-frame result is part of the overall verdict.
    all_pass = all_pass and cross_pass
    status = '' if cross_pass else '  FAIL'
    print(f'out[{idx}] {max_pairwise:>20.6f}{status}')

print('\n' + '=' * 60)
print(f'OVERALL: {"PASS" if all_pass else "FAIL"}  (threshold {THRESHOLD})')
print('=' * 60)