#!/usr/bin/env python3
"""
Clean batch-independence validator for the fixed det_500m model.
Runs:
A : ONNXRuntime, B=1
B : onnx2torch, B=1
Ci : onnx2torch, B=5, frame i (i=1..5)
Compares each of the 5 frames (C1..C5) against A and B for all 9 output heads.
If batch-independent, every Ci should equal both A and B (within float noise),
and all Ci should equal each other.
"""
import argparse
import cv2
import numpy as np
import torch
import onnx2torch
import onnxruntime as ort
# Square input resolution; used both for the explicit resize and as the blob size.
DET_SIZE = (640, 640)

# Command line: both paths are mandatory, argparse exits with a usage error
# when either one is missing.
parser = argparse.ArgumentParser()
for required_flag in ('--model', '--image'):
    parser.add_argument(required_flag, required=True)
args = parser.parse_args()
# ── Load models ──────────────────────────────────────────────────────────────
# The same ONNX file is loaded twice: once as an ONNXRuntime session (CPU only)
# and once converted to a torch module via onnx2torch.
print('Loading ONNXRuntime session ...')
sess = ort.InferenceSession(args.model, providers=['CPUExecutionProvider'])
model_inputs = sess.get_inputs()
inp_name = model_inputs[0].name  # only the first declared input is fed

print('Loading onnx2torch model ...')
# eval() returns the module itself, so the call can be chained.
pt_model = onnx2torch.convert(args.model).eval()
# ── Preprocess image ─────────────────────────────────────────────────────────
img = cv2.imread(args.image)
if img is None:
    # cv2.imread reports an unreadable/missing file by returning None rather
    # than raising. Check explicitly instead of using `assert`, which is
    # stripped when Python runs with -O.
    raise SystemExit(f'Could not read image: {args.image}')
print(f'Image: {args.image} shape={img.shape}')

# Build the network input: resize to DET_SIZE, then let blobFromImage apply
# mean 127.5 and scale 1/128 per channel and swap the R and B channels.
blob = cv2.dnn.blobFromImage(
    cv2.resize(img, DET_SIZE),
    1.0 / 128.0, DET_SIZE,
    (127.5, 127.5, 127.5), swapRB=True,
)
# Same data as a torch tensor (from_numpy shares memory with `blob`).
t = torch.from_numpy(blob)
# ── Run A, B, C(B=5) ─────────────────────────────────────────────────────────
print('\nRunning A (ORT B=1), B (torch B=1), C (torch B=5)...')

# A: ONNXRuntime on the single frame -> list of np arrays, one per head.
ort_out = sess.run(None, {inp_name: blob})

# C input: the very same frame stacked 5 times along the batch axis.
batch_of_5 = torch.cat((t,) * 5, dim=0)

with torch.no_grad():
    pt_b1 = pt_model(t)             # B: tuple of [1, anchors, K]
    pt_b5 = pt_model(batch_of_5)    # C: tuple of [5, anchors, K]

print('\nB=5 output shapes:')
for head_idx, head_out in enumerate(pt_b5):
    print(f' out[{head_idx}]: {list(head_out.shape)}')
# ── Per-head comparison: A, B, C1..C5 ────────────────────────────────────────
# For every output head, report the max abs diff between the ORT result (A),
# the torch batch-1 result (B) and each frame of the torch batch-5 result (Ci).
NUM_FRAMES = 5      # batch size of the C run
THRESHOLD = 1e-4    # diffs below this are treated as float noise
all_pass = True

print('\n' + '=' * 100)
print('Per-head comparison (max abs diff)')
print('=' * 100)
frame_ids = range(1, NUM_FRAMES + 1)
header = (
    f'{"head":<6} {"AvsB":>10} '
    + ' '.join(f'{"AvsC" + str(k):>10}' for k in frame_ids)
    + ' '
    + ' '.join(f'{"BvsC" + str(k):>10}' for k in frame_ids)
)
print(header)
print('-' * len(header))

# zip() stops at the shortest sequence, so a backend returning fewer heads
# would silently drop the remaining heads from the comparison. Flag it.
if not (len(ort_out) == len(pt_b1) == len(pt_b5)):
    print(f'WARNING: head count mismatch (A={len(ort_out)}, B={len(pt_b1)}, C={len(pt_b5)})')
    all_pass = False

for idx, (a, b, c) in enumerate(zip(ort_out, pt_b1, pt_b5)):
    a_np = a.reshape(-1)                     # ORT output, flattened
    b_np = b.detach().numpy().reshape(-1)    # torch B=1 output, flattened
    c_np = c.detach().numpy()                # torch B=5 output, batch axis first
    per_frame_size = b_np.size               # elements per frame

    # Every frame of C must have exactly as many elements as A and B;
    # otherwise the subtraction below would raise (or broadcast) instead of
    # producing a meaningful diff.
    shapes_ok = (
        a_np.size == per_frame_size
        and c_np.shape[:1] == (NUM_FRAMES,)
        and c_np.size == NUM_FRAMES * per_frame_size
    )
    if not shapes_ok:
        all_pass = False
        print(f'out[{idx}] shape mismatch: A={list(a.shape)} B={list(b.shape)} C={list(c.shape)} FAIL')
        continue

    frames = [c_np[i].reshape(-1) for i in range(NUM_FRAMES)]  # C1..C5, flattened
    avb = np.abs(a_np - b_np).max()
    avc = [np.abs(a_np - f).max() for f in frames]
    bvc = [np.abs(b_np - f).max() for f in frames]

    # A NaN diff compares False against THRESHOLD and therefore counts as FAIL.
    head_pass = avb < THRESHOLD and all(d < THRESHOLD for d in avc + bvc)
    all_pass = all_pass and head_pass

    row = f'out[{idx}] {avb:>10.6f} ' + ' '.join(f'{d:>10.6f}' for d in avc) \
        + ' ' + ' '.join(f'{d:>10.6f}' for d in bvc)
    if not head_pass:
        row += ' FAIL'
    print(row)
# ── Cross-frame check (Ci vs Cj) ─────────────────────────────────────────────
# All frames of the batched run were fed the same image, so for each head the
# outputs of every pair of frames should agree to within THRESHOLD.
print('\n' + '=' * 60)
print('Cross-frame check: max diff between any pair Ci, Cj')
print('=' * 60)
print(f'{"head":<6} {"max_pairwise_diff":>20}')
print('-' * 30)
for idx, c in enumerate(pt_b5):
    c_np = c.detach().numpy()
    per_frame = c_np.reshape(c_np.shape[0], -1)   # one flattened row per frame

    # The largest |Ci - Cj| over all pairs equals, element by element, the
    # max over frames minus the min over frames. Unlike a running
    # `if d > best` comparison, max/min propagate NaN, so a NaN output ends up
    # as FAIL below instead of being skipped.
    spread = per_frame.max(axis=0) - per_frame.min(axis=0)
    max_pairwise = spread.max()

    cross_ok = bool(max_pairwise < THRESHOLD)
    # A cross-frame failure must also fail the overall verdict.
    all_pass = all_pass and cross_ok
    status = '' if cross_ok else ' FAIL'
    print(f'out[{idx}] {max_pairwise:>20.6f}{status}')

print('\n' + '=' * 60)
print(f'OVERALL: {"PASS" if all_pass else "FAIL"} (threshold {THRESHOLD})')
print('=' * 60)