|
|
"""Benchmark all ONNX models — timing, outputs, DLL vs pure ONNX comparison.""" |
|
|
import time |
|
|
import numpy as np |
|
|
import onnxruntime as ort |
|
|
import onnx |
|
|
from PIL import Image |
|
|
from pathlib import Path |
|
|
|
|
|
# Locations of the models/config extracted from oneocr.dll resources:
# MODELS_DIR holds the per-script ONNX graphs, CONFIG_DIR the char-map files.
MODELS_DIR = Path("oneocr_extracted/onnx_models")
CONFIG_DIR = Path("oneocr_extracted/config_data")
|
|
|
|
|
|
|
|
def load_char_map(path):
    """Read a '<char> <index>' mapping file into a decode table.

    Returns (idx2char, blank_idx): a dict from integer index to character,
    with the "<space>" token mapped to a literal space, and the index
    reserved for the CTC blank symbol ("<blank>", defaulting to 0).
    Blank lines and lines without a usable separator are skipped.
    """
    table = {}
    blank = 0
    with open(path, "r", encoding="utf-8") as handle:
        for raw in handle:
            entry = raw.rstrip("\n")
            if not entry:
                continue
            # Split on the LAST space so characters that are themselves
            # spaces/tokens keep their full text.
            cut = entry.rfind(" ")
            if cut <= 0:
                # No separator, or an empty char field -- skip the line.
                continue
            symbol = entry[:cut]
            index = int(entry[cut + 1:])
            if symbol == "<blank>":
                blank = index
            elif symbol == "<space>":
                table[index] = " "
            else:
                table[index] = symbol
    return table, blank
|
|
|
|
|
|
|
|
def ctc_decode(lp, idx2char, blank):
    """Greedy CTC decode: argmax per step, collapse repeats, drop blanks.

    ``lp`` is a (T, C) matrix of per-step class scores; a 3-D input is
    squeezed to 2-D first (batch axis of size 1 in either position).
    Indices missing from ``idx2char`` render as ``[<index>]`` so gaps in
    the char map are visible in the output.
    """
    if lp.ndim == 3:
        lp = lp[:, 0, :] if lp.shape[1] == 1 else lp[0]
    best_path = np.argmax(lp, axis=-1)
    pieces = []
    last = -1
    for step in best_path:
        # Emit only on label change, and never for the blank symbol.
        if step != last and step != blank:
            pieces.append(idx2char.get(int(step), f"[{step}]"))
        last = step
    return "".join(pieces)
|
|
|
|
|
|
|
|
# --- Banner and test-image load ----------------------------------------------
print("=" * 100)
print(" ONEOCR FULL MODEL BENCHMARK")
print("=" * 100)

# Load the benchmark image; RGB conversion normalizes palette/grayscale/RGBA
# inputs to 3 channels. NOTE(review): hard-coded filename -- raises if
# "image.png" is missing from the working directory.
img = Image.open("image.png").convert("RGB")
w, h = img.size
print(f"\n Test image: {w}x{h}\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Detector (model_00) preprocessing ---------------------------------------
# Fit the longest image side to 800 px, then round each dimension UP to a
# multiple of 32 (typical FPN stride requirement -- TODO confirm).
scale_det = 800 / max(h, w)
dh = (int(h * scale_det) + 31) // 32 * 32
dw = (int(w * scale_det) + 31) // 32 * 32
img_det = img.resize((dw, dh), Image.LANCZOS)
# RGB -> BGR channel flip, then subtract per-channel means (Caffe-style
# constants; presumably in BGR order to match the flip -- verify vs the DLL).
arr_det = np.array(img_det, dtype=np.float32)[:, :, ::-1] - [102.9801, 115.9465, 122.7717]
# HWC -> NCHW with a batch axis. The list subtraction above promoted the
# array to float64, so cast back down to float32 for the session input.
det_data = arr_det.transpose(2, 0, 1)[np.newaxis].astype(np.float32)
# im_info row: (height, width, scale), one row per batch item.
det_iminfo = np.array([[dh, dw, scale_det]], dtype=np.float32)
|
|
|
|
|
|
|
|
# --- Recognizer preprocessing ------------------------------------------------
# Scale to a fixed 60 px height; width follows the aspect ratio with a floor
# of 32 px, rounded UP to a multiple of 4 (the CNN downsamples width by 4x).
target_h = 60
scale_rec = target_h / h
new_w = max(int(w * scale_rec), 32)
new_w = (new_w + 3) // 4 * 4
img_rec = img.resize((new_w, target_h), Image.LANCZOS)
# Normalize to [0, 1] and reorder HWC -> NCHW with a batch axis.
rec_data = (np.array(img_rec, dtype=np.float32) / 255.0).transpose(2, 0, 1)[np.newaxis]
# CTC sequence length after the 4x horizontal downsampling (per batch item).
rec_seq = np.array([new_w // 4], dtype=np.int32)
|
|
|
|
|
|
|
|
# Recognizer model index -> char2ind file in CONFIG_DIR. The chunk numbers
# are irregular because they follow the extracted DLL resource layout.
CHAR_MAPS = {
    2: "chunk_37_char2ind.char2ind.txt",
    3: "chunk_40_char2ind.char2ind.txt",
    4: "chunk_43_char2ind.char2ind.txt",
    5: "chunk_47_char2ind.char2ind.txt",
    6: "chunk_50_char2ind.char2ind.txt",
    7: "chunk_53_char2ind.char2ind.txt",
    8: "chunk_57_char2ind.char2ind.txt",
    9: "chunk_61_char2ind.char2ind.txt",
    10: "chunk_64_char2ind.char2ind.txt",
}
|
|
|
|
|
# Human-readable role for each extracted model index (0-33).
SCRIPT_NAMES = {
    0: "Detector", 1: "ScriptID", 2: "Latin", 3: "CJK", 4: "Arabic",
    5: "Cyrillic", 6: "Devanagari", 7: "Greek", 8: "Hebrew", 9: "Tamil", 10: "Thai",
}
# 11-21 are small language models, 22-32 are medium ones, 33 is line layout.
SCRIPT_NAMES.update({11 + k: f"LangSm_{k}" for k in range(11)})
SCRIPT_NAMES.update({22 + k: f"LangMd_{k}" for k in range(11)})
SCRIPT_NAMES[33] = "LineLayout"
|
|
|
|
|
|
|
|
# --- Benchmark table header ---------------------------------------------------
print(f"{'#':>3} {'Name':15s} {'KB':>8} {'Load ms':>8} {'Run ms':>8} {'Runs':>5} {'Output Shape':30s} {'RT':>6} {'Text'}")
print("-" * 100)

# Totals accumulated across every model that loads / runs successfully.
total_load = 0
total_run = 0
|
|
|
|
|
# --- Benchmark every ONNX model ------------------------------------------------
for f in sorted(MODELS_DIR.glob("*.onnx")):
    # File names look like "model_<idx>_*.onnx"; the index picks feeds + name.
    idx = int(f.name.split("_")[1])
    size_kb = f.stat().st_size // 1024
    name = SCRIPT_NAMES.get(idx, f"model_{idx}")

    # Session creation fails for graphs that use custom (non-standard) ops;
    # those are reported as "CUSTOM" and skipped.
    t0 = time.perf_counter()
    try:
        sess = ort.InferenceSession(str(f), providers=["CPUExecutionProvider"])
        t_load = (time.perf_counter() - t0) * 1000
        rt = "OK"
    except Exception as e:
        t_load = (time.perf_counter() - t0) * 1000
        rt = "CUSTOM"
        print(f"{idx:>3} {name:15s} {size_kb:>8} {t_load:>8.1f} {'---':>8} {'---':>5} {'N/A (custom ops)':30s} {rt:>6}")
        continue

    total_load += t_load

    input_names = [i.name for i in sess.get_inputs()]

    # Build model-specific input feeds.
    if idx == 0:
        # Detector: preprocessed image + im_info (height, width, scale).
        feeds = {"data": det_data, "im_info": det_iminfo}
    elif idx == 1:
        # Script classifier: image only; some exports also take seq_lengths.
        feeds = {"data": rec_data}
        if "seq_lengths" in input_names:
            feeds["seq_lengths"] = rec_seq
    elif idx <= 10:
        # Per-script recognizers: image + CTC sequence lengths.
        feeds = {"data": rec_data, "seq_lengths": rec_seq}
    else:
        # Unknown models: synthesize dummy inputs, substituting 1 for every
        # dynamic (non-positive / symbolic) dimension.
        feeds = {}
        for inp in sess.get_inputs():
            shape = []
            for d in inp.shape:
                shape.append(d if isinstance(d, int) and d > 0 else 1)
            if "int" in inp.type:
                feeds[inp.name] = np.ones(shape, dtype=np.int32)
            else:
                feeds[inp.name] = np.random.randn(*shape).astype(np.float32)

    # Warm-up run; also catches shape/type mismatches before timing starts.
    try:
        sess.run(None, feeds)
    except Exception as e:
        print(f"{idx:>3} {name:15s} {size_kb:>8} {t_load:>8.1f} {'ERR':>8} {'---':>5} {str(e)[:30]:30s} {rt:>6}")
        continue

    # Timed runs: report the average of n_runs in milliseconds.
    n_runs = 5
    t0 = time.perf_counter()
    for _ in range(n_runs):
        outputs = sess.run(None, feeds)
    t_run = (time.perf_counter() - t0) / n_runs * 1000
    total_run += t_run

    out_shape = str(outputs[0].shape)

    # For recognizer models with a char map on disk, CTC-decode the first
    # output into visible text for the table.
    text = ""
    if 2 <= idx <= 10 and idx in CHAR_MAPS:
        cm_path = CONFIG_DIR / CHAR_MAPS[idx]
        if cm_path.exists():
            idx2char, blank_idx = load_char_map(str(cm_path))
            text = ctc_decode(outputs[0], idx2char, blank_idx)

    print(f"{idx:>3} {name:15s} {size_kb:>8} {t_load:>8.1f} {t_run:>8.1f} {n_runs:>5} {out_shape:30s} {rt:>6} {text}")
|
|
|
|
|
print("-" * 100)
# Aggregate load/run times across all models that ran successfully.
print(f" {'TOTAL':15s} {'':>8} {total_load:>8.1f} {total_run:>8.1f}")

print(f"\n{'=' * 100}")
print(" DLL vs ONNX COMPARISON")
print("=" * 100)

# Static feature-comparison table (informational; not measured by this run).
print("""
┌─────────────────────────────────────────────────────────────────────────────┐
│ Feature │ DLL (oneocr.dll) │ Pure ONNX │
├───────────────────────────┼────────────────────────┼────────────────────────┤
│ Platform │ Windows only │ Any (Linux/Mac/Win) │
│ Text detection (boxes) │ YES (4-point quads) │ Raw FPN maps (need PP) │
│ Text recognition │ YES (full pipeline) │ YES (CTC decode) │
│ Word confidence │ YES (per-word float) │ YES (from logprobs) │
│ Line bounding boxes │ YES (quadrilateral) │ NO (need PixelLink PP) │
│ Word bounding boxes │ YES (quadrilateral) │ NO (need PixelLink PP) │
│ Image angle/rotation │ YES (GetImageAngle) │ NO (not in models) │
│ Script detection │ YES (automatic) │ model_01 (standalone) │
│ Language models (LM) │ YES (built-in custom) │ NO (custom ops needed) │
│ Multi-script support │ YES (auto-switch) │ Manual script select │
│ Dependencies │ oneocr.dll + ort.dll │ onnxruntime + numpy │
│ Size │ ~100MB (DLL+model) │ ~45MB (12 ONNX models) │
│ Latency (typical) │ ~50-100ms │ ~30-80ms (recog only) │
│ Custom ops needed │ NO (built-in) │ 23/34 models blocked │
└─────────────────────────────────────────────────────────────────────────────┘
""")
|
|
|
|
|
|
|
|
print("=" * 100)
print(" DETAILED MODEL OUTPUT ANALYSIS")
print("=" * 100)

# ONNX TensorProto elem_type codes -> readable names.
# BUG FIX: this mapping used to be (re)bound inside the graph-input loop, so a
# model with zero graph inputs raised NameError when its outputs were printed.
# Hoisted to a single module-level constant.
TYPE_NAMES = {1: "float32", 6: "int32", 7: "int64", 10: "float16"}


def _tensor_dims(value_info):
    """Return a value_info's tensor dims as strings.

    Static dims render as their number; dynamic dims render as their symbolic
    dim_param, or "?" when no name is present. Shared by the input and output
    reporting below.
    """
    dims = []
    if value_info.type.tensor_type.HasField("shape"):
        for d in value_info.type.tensor_type.shape.dim:
            dims.append(str(d.dim_value) if d.dim_value else d.dim_param or "?")
    return dims


for f in sorted(MODELS_DIR.glob("*.onnx")):
    idx = int(f.name.split("_")[1])
    m = onnx.load(str(f))
    name = SCRIPT_NAMES.get(idx, f"model_{idx}")

    print(f"\n model_{idx:02d} ({name}):")

    # Graph inputs and outputs with shapes and element types.
    for inp in m.graph.input:
        dims = _tensor_dims(inp)
        elem_type = inp.type.tensor_type.elem_type
        print(f" IN: {inp.name:20s} [{','.join(dims):20s}] {TYPE_NAMES.get(elem_type, f'type{elem_type}')}")

    for out in m.graph.output:
        dims = _tensor_dims(out)
        elem_type = out.type.tensor_type.elem_type
        print(f" OUT: {out.name:20s} [{','.join(dims):20s}] {TYPE_NAMES.get(elem_type, f'type{elem_type}')}")

    # Non-empty opset domains mean the graph uses custom (non-standard) ops.
    domains = [o.domain for o in m.opset_import if o.domain]
    if domains:
        print(f" CUSTOM OPS: {', '.join(domains)}")

    # Operator histogram, keyed "domain::op" for custom-domain nodes.
    op_counts = {}
    for node in m.graph.node:
        key = f"{node.domain}::{node.op_type}" if node.domain else node.op_type
        op_counts[key] = op_counts.get(key, 0) + 1

    # Only show the op breakdown for the large/interesting graphs
    # (detector, script ID, line layout).
    if idx <= 1 or idx == 33:
        top_ops = sorted(op_counts.items(), key=lambda x: -x[1])[:8]
        print(f" TOP OPS: {', '.join(f'{op}({n})' for op, n in top_ops)}")
|
|
|