File size: 4,548 Bytes
a12c7cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/env python3
"""
Clean batch-independence validator for the fixed det_500m model.

Runs:
  A  : ONNXRuntime, B=1
  B  : onnx2torch,  B=1
  Ci : onnx2torch,  B=5, frame i (i=1..5)

Compares each of the 5 frames (C1..C5) against A and B for all 9 output heads.
If batch-independent, every Ci should equal both A and B (within float noise),
and all Ci should equal each other.
"""
import argparse
import cv2
import numpy as np
import torch
import onnx2torch
import onnxruntime as ort

# Target size handed to both cv2.resize and cv2.dnn.blobFromImage below.
DET_SIZE = (640, 640)

# Both paths are mandatory; the help strings make `--help` self-explanatory.
parser = argparse.ArgumentParser(
    description='Check that a detection model produces the same outputs '
                'at batch size 1 and inside a batch of 5.',
)
parser.add_argument('--model', required=True,
                    help='path to the ONNX model file to validate')
parser.add_argument('--image', required=True,
                    help='path to the image used as test input')
args = parser.parse_args()

# ── Load models ──────────────────────────────────────────────────────────────
# Reference backend: ONNXRuntime restricted to the CPU execution provider.
print('Loading ONNXRuntime session ...')
sess = ort.InferenceSession(
    args.model,
    providers=['CPUExecutionProvider'],
)
# Only the first declared input is fed later, so keep just its name.
inp_name = sess.get_inputs()[0].name

# Backend under test: the same model file converted to a torch module and
# switched to eval mode (eval() returns the module itself).
print('Loading onnx2torch model ...')
pt_model = onnx2torch.convert(args.model).eval()

# ── Preprocess image ─────────────────────────────────────────────────────────
img = cv2.imread(args.image)
# cv2.imread reports failure by returning None rather than raising. Check it
# explicitly: an `assert` would be stripped under `python -O`, letting the
# script fail later inside cv2.resize with a far less helpful message.
if img is None:
    raise SystemExit(f'Could not read image: {args.image}')
print(f'Image: {args.image}  shape={img.shape}')

# Build the network input: resize to DET_SIZE, then let blobFromImage apply
# the 127.5 mean, the 1/128 scale factor and the R<->B channel swap.
blob = cv2.dnn.blobFromImage(
    cv2.resize(img, DET_SIZE),
    1.0 / 128.0, DET_SIZE,
    (127.5, 127.5, 127.5), swapRB=True,
)
# Torch view of the same blob, used as the B=1 input for onnx2torch.
t = torch.from_numpy(blob)

# ── Run A, B, C(B=5) ─────────────────────────────────────────────────────────
print('\nRunning A (ORT B=1), B (torch B=1), C (torch B=5)...')

# A: ONNXRuntime on the single-image blob -> list of np arrays.
ort_out = sess.run(None, {inp_name: blob})

# B and C: onnx2torch without autograd bookkeeping. C sees the very same
# image stacked five times along dim 0, so every frame should match B.
with torch.no_grad():
    pt_b1 = pt_model(t)                              # B: tuple of [1, anchors, K]
    pt_b5 = pt_model(torch.cat((t,) * 5, dim=0))     # C: tuple of [5, anchors, K]

print('\nB=5 output shapes:')
for head_no, head_out in enumerate(pt_b5):
    print(f'  out[{head_no}]: {list(head_out.shape)}')

# ── Per-head comparison: A, B, C1..C5 ────────────────────────────────────────
print('\n' + '=' * 100)
print('Per-head comparison (max abs diff)')
print('=' * 100)
# Take the frame count from the batched output itself so the table columns
# always match the batch that was actually run (5 with the current script).
n_frames = pt_b5[0].shape[0] if len(pt_b5) else 0
avc_cols = ' '.join(f'{"AvsC" + str(k + 1):>10}' for k in range(n_frames))
bvc_cols = ' '.join(f'{"BvsC" + str(k + 1):>10}' for k in range(n_frames))
header = f'{"head":<6} {"AvsB":>10} ' + avc_cols + ' ' + bvc_cols
print(header)
print('-' * len(header))

# Largest absolute element-wise difference still treated as float noise.
THRESHOLD = 1e-4
# Running verdict; any failing comparison flips it to False.
all_pass = True

for idx, (a, b, c) in enumerate(zip(ort_out, pt_b1, pt_b5)):
    a_np = a.reshape(-1)                             # ORT B=1 output, flattened
    b_np = b.detach().numpy().reshape(-1)            # torch B=1 output, flattened
    # torch batched output: one flattened array per frame (iterates dim 0).
    frames = [frame.reshape(-1) for frame in c.detach().numpy()]

    avb = np.abs(a_np - b_np).max()
    avc = [np.abs(a_np - f).max() for f in frames]
    bvc = [np.abs(b_np - f).max() for f in frames]

    # A NaN diff compares False against THRESHOLD and therefore fails the head.
    head_pass = avb < THRESHOLD and all(d < THRESHOLD for d in avc + bvc)
    all_pass = all_pass and head_pass

    row = f'out[{idx}] {avb:>10.6f} ' + ' '.join(f'{d:>10.6f}' for d in avc) \
          + ' ' + ' '.join(f'{d:>10.6f}' for d in bvc)
    if not head_pass:
        row += '  FAIL'
    print(row)

# zip() stops at the shortest sequence, so a backend returning fewer heads
# would leave some heads unchecked. Treat that as a failure instead of
# silently reporting PASS on a partial comparison.
head_counts = (len(ort_out), len(pt_b1), len(pt_b5))
if len(set(head_counts)) != 1:
    print(f'Head count mismatch: ORT={head_counts[0]} '
          f'torch B=1={head_counts[1]} torch B=5={head_counts[2]}  FAIL')
    all_pass = False

# ── Cross-frame check (Ci vs Cj) ─────────────────────────────────────────────
print('\n' + '=' * 60)
print('Cross-frame check: max diff between any pair Ci, Cj')
print('=' * 60)
print(f'{"head":<6} {"max_pairwise_diff":>20}')
print('-' * 30)
for idx, c in enumerate(pt_b5):
    c_np = c.detach().numpy()
    # One row per frame, whatever the batch size is.
    flat = c_np.reshape(c_np.shape[0], -1)
    # For each element, the largest |Ci - Cj| over all frame pairs equals
    # (max over frames) - (min over frames); the overall maximum of that is
    # the max pairwise diff. Unlike a `d > best` scan, NaN propagates here,
    # so non-finite outputs fail the check instead of being skipped.
    max_pairwise = (flat.max(axis=0) - flat.min(axis=0)).max()
    frames_pass = max_pairwise < THRESHOLD
    # A cross-frame failure must count towards the overall verdict too.
    all_pass = all_pass and frames_pass
    status = '' if frames_pass else '  FAIL'
    print(f'out[{idx}] {max_pairwise:>20.6f}{status}')

print('\n' + '=' * 60)
print(f'OVERALL: {"PASS" if all_pass else "FAIL"} (threshold {THRESHOLD})')
print('=' * 60)