scrfd_640_batched / validate_ort_batch.py
ceyxprime's picture
Upload validate_ort_batch.py with huggingface_hub
254d8b6 verified
#!/usr/bin/env python3
"""
Validate that ONNXRuntime (not onnx2torch) gives batch-independent results
on the fixed model.
Runs:
    A : ORT B=1
    C : ORT B=5 (same image x5)
Checks per head:
    - Output shape has batch dim outermost: [5, anchors, K]
    - A vs each C[i]: max abs diff < 1e-4 for all 5 frames
    - Cross-frame C[i] vs C[j]: max abs diff < 1e-4 for all pairs
"""
import argparse
import cv2
import numpy as np
import onnxruntime as ort
# Network input size, passed to cv2.resize / blobFromImage as (width, height).
DET_SIZE = (640, 640)
# Max absolute element-wise difference tolerated between compared outputs.
THRESHOLD = 1e-4

# Reuse the module docstring as the --help description; the raw formatter
# keeps its line breaks and indentation intact.
parser = argparse.ArgumentParser(
    description=__doc__,
    formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument('--model', required=True,
                    help='path to the ONNX model to validate')
parser.add_argument('--image', required=True,
                    help='path to a test image (any format cv2.imread reads)')
args = parser.parse_args()
# CPU provider only, so results are not affected by GPU kernel selection.
sess = ort.InferenceSession(args.model, providers=['CPUExecutionProvider'])
inp_name = sess.get_inputs()[0].name

img = cv2.imread(args.image)
if img is None:
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    raise SystemExit(f'Could not read: {args.image}')

# blobFromImage computes (pixel - 127.5) * (1 / 128) per channel, converts
# the BGR image from cv2.imread to RGB (swapRB) and lays it out as NCHW.
blob = cv2.dnn.blobFromImage(
    cv2.resize(img, DET_SIZE), 1.0 / 128.0, DET_SIZE,
    (127.5, 127.5, 127.5), swapRB=True,
)  # (1, 3, 640, 640)

# A: single-image run.  C: the same image repeated 5 times along the batch axis.
out_b1 = sess.run(None, {inp_name: blob})
out_b5 = sess.run(None, {inp_name: np.repeat(blob, 5, axis=0)})
# Show the raw output shapes of the batched run, then the results table header.
print('ORT B=5 output shapes:')
for out_idx, tensor in enumerate(out_b5):
    print(f' out[{out_idx}]: {list(tensor.shape)}')

# One "AvsC<k>" column per batch frame (1-based), matching the row layout.
frame_cols = [f'{"AvsC" + str(k):>10}' for k in range(1, 6)]
header = (
    f'\n{"head":<6} {"shape_ok":>10} '
    + ' '.join(frame_cols)
    + f' {"cross_frame":>12}'
)
print(header)
print('-' * 90)
# Compare every output head of the B=1 run (A) with each frame of the B=5
# run (C). All C frames come from the same image, so each must match A and
# every other frame within THRESHOLD.
n_frames = 5  # must equal the batch size fed to the B=5 run
all_pass = True

if len(out_b1) != len(out_b5):
    # zip() below would silently drop the extra heads; flag it instead.
    print(f'Output count mismatch: B=1 has {len(out_b1)}, '
          f'B=5 has {len(out_b5)}  FAIL')
    all_pass = False

for idx, (a, c) in enumerate(zip(out_b1, out_b5)):
    # Per-frame comparison only makes sense when the leading dim is the batch
    # and each frame holds exactly as many values as the B=1 output;
    # otherwise the subtraction below would fail to broadcast.
    comparable = (c.ndim >= 1 and c.shape[0] == n_frames
                  and c.size == n_frames * a.size)
    # Required layout: 3-D with the batch dim outermost, [5, anchors, K].
    shape_ok = comparable and c.ndim == 3
    shape_label = 'OK' if shape_ok else 'BAD'

    if comparable:
        a_flat = a.reshape(-1)
        frames = [c[i].reshape(-1) for i in range(n_frames)]
        # A vs each C frame.
        avc = [np.abs(a_flat - f).max() for f in frames]
        # Worst diff over all unordered frame pairs. np.max propagates NaN,
        # so a NaN difference makes the check fail instead of being skipped.
        max_pairwise = np.max([
            np.abs(frames[i] - frames[j]).max()
            for i in range(n_frames) for j in range(i + 1, n_frames)
        ])
        head_pass = (shape_ok
                     and all(d < THRESHOLD for d in avc)
                     and max_pairwise < THRESHOLD)
        row = (f'out[{idx}] {shape_label:>10} ' +
               ' '.join(f'{d:>10.6f}' for d in avc) +
               f' {max_pairwise:>12.6f}')
    else:
        head_pass = False
        row = (f'out[{idx}] {shape_label:>10} ' +
               ' '.join(f'{"n/a":>10}' for _ in range(n_frames)) +
               f' {"n/a":>12}')

    all_pass = all_pass and head_pass
    if not head_pass:
        row += ' FAIL'
    print(row)
# Final verdict across all heads.
print('\n' + '=' * 50)
print(f'OVERALL: {"PASS" if all_pass else "FAIL"} (threshold {THRESHOLD})')
print('=' * 50)
# Exit non-zero on failure so shell scripts / CI can detect it.
if not all_pass:
    raise SystemExit(1)