# oneocr/_archive/test_onnx_inference.py
# OneOCR Dev
# OneOCR - reverse engineering complete, ONNX pipeline 53% match rate
# ce847d4
"""
Test extracted ONNX models on image.png — pure onnxruntime inference.
No Windows DLL needed — runs on any platform with onnxruntime + Pillow.
Pipeline:
1. model_00 (detector) — finds text regions via PixelLink FPN
2. model_01 (script ID) — identifies writing script (Latin, CJK, etc.)
3. model_02..10 (recognizers) — CTC character recognition per script
4. char2ind.txt → decode character indices to text
Preprocessing: RGB, height=60px, pixels / 255.0 (range [0, 1])
Postprocessing: CTC greedy decode with <blank> token removal
"""
import sys
from pathlib import Path
import numpy as np
import onnxruntime as ort
from PIL import Image
MODELS_DIR = Path("oneocr_extracted/onnx_models")
CONFIG_DIR = Path("oneocr_extracted/config_data")
# ─── Model registry ─────────────────────────────────────────────────────────
# model_idx -> (role, script, char2ind_file)
MODEL_REGISTRY: dict[int, tuple[str, str, str | None]] = {
0: ("detector", "universal", None),
1: ("script_id", "universal", None),
2: ("recognizer", "Latin", "chunk_37_char2ind.char2ind.txt"),
3: ("recognizer", "CJK", "chunk_40_char2ind.char2ind.txt"),
4: ("recognizer", "Arabic", "chunk_43_char2ind.char2ind.txt"),
5: ("recognizer", "Cyrillic", "chunk_47_char2ind.char2ind.txt"),
6: ("recognizer", "Devanagari", "chunk_50_char2ind.char2ind.txt"),
7: ("recognizer", "Greek", "chunk_53_char2ind.char2ind.txt"),
8: ("recognizer", "Hebrew", "chunk_57_char2ind.char2ind.txt"),
9: ("recognizer", "Tamil", "chunk_61_char2ind.char2ind.txt"),
10: ("recognizer", "Thai", "chunk_64_char2ind.char2ind.txt"),
}
def load_char_map(path: str) -> tuple[dict[int, str], int]:
    """Parse a char2ind file into an index -> character table.

    Each useful line has the form ``<char> <index>``; the index is whatever
    follows the last space on the line. Two tokens are special:
    ``<space>`` maps to a literal space and ``<blank>`` marks the CTC blank.

    Args:
        path: Location of the UTF-8 encoded char2ind text file.

    Returns:
        ``(idx2char, blank_idx)``. ``blank_idx`` stays 0 when the file has no
        ``<blank>`` entry. The blank itself is not stored in ``idx2char``.

    Raises:
        ValueError: If the text after the last space is not an integer.
    """
    table: dict[int, str] = {}
    blank = 0
    with open(path, "r", encoding="utf-8") as handle:
        for raw in handle:
            # Split on the LAST space so the token itself may contain spaces.
            token, _, number = raw.rstrip("\n").rpartition(" ")
            if not token:
                # Empty line, no space at all, or nothing before the space.
                continue
            index = int(number)
            if token == "<blank>":
                blank = index
                continue
            table[index] = " " if token == "<space>" else token
    return table, blank
def ctc_greedy_decode(logprobs: np.ndarray, idx2char: dict, blank_idx: int) -> str:
    """Greedy CTC decoding of per-timestep class scores.

    Takes the argmax class at every timestep, collapses consecutive repeats
    and drops the blank class.

    Args:
        logprobs: Class scores with the class axis last. A 3-D array is
            reduced to 2-D first: the middle axis is dropped when its length
            is 1, otherwise the first entry of the leading axis is used.
        idx2char: Class index -> character table.
        blank_idx: Index of the CTC blank class.

    Returns:
        The decoded string. Indices missing from ``idx2char`` are rendered
        as ``[<index>]``.
    """
    frames = logprobs
    if frames.ndim == 3:
        frames = frames[:, 0, :] if frames.shape[1] == 1 else frames[0]

    best = np.argmax(frames, axis=-1)

    # A timestep emits a symbol only when its class differs from the
    # previous timestep's class (the first timestep always qualifies) ...
    starts_run = np.ones(best.shape, dtype=bool)
    starts_run[1:] = best[1:] != best[:-1]
    # ... and when that class is not the blank.
    emitted = best[starts_run & (best != blank_idx)]

    return "".join(idx2char.get(int(i), f"[{i}]") for i in emitted)
def preprocess_for_recognizer(img: Image.Image, target_h: int = 60) -> tuple[np.ndarray, np.ndarray]:
    """Resize an image to the recognizer input height and convert it to NCHW.

    The image is scaled to ``target_h`` pixels high with its aspect ratio
    kept; the resulting width is clamped to at least 32 and rounded up to a
    multiple of 4. Pixel values are only divided by 255 (no mean/std
    normalisation).

    Args:
        img: Source image. Images in a mode other than RGB (grayscale,
            palette, RGBA, ...) are converted to RGB first so the output
            always has exactly three channels.
        target_h: Height, in pixels, of the resized image.

    Returns:
        ``(data, seq_lengths)`` where ``data`` is a float32 array of shape
        ``[1, 3, target_h, W]`` and ``seq_lengths`` is an int32 array of
        shape ``[1]`` holding ``W // 4``.
    """
    # Without this, a 2-D (grayscale) array breaks the HWC -> CHW transpose
    # and an RGBA image would yield four channels instead of three.
    if img.mode != "RGB":
        img = img.convert("RGB")

    src_w, src_h = img.size
    scale = target_h / src_h
    new_w = max(int(src_w * scale), 32)
    new_w = (new_w + 3) // 4 * 4  # round up to a multiple of 4

    resized = img.resize((new_w, target_h), Image.LANCZOS)
    pixels = np.array(resized, dtype=np.float32) / 255.0  # plain /255, no ImageNet stats
    data = pixels.transpose(2, 0, 1)[np.newaxis]  # HWC -> NCHW
    seq_lengths = np.array([new_w // 4], dtype=np.int32)
    return data, seq_lengths
def run_recognizer(
    model_path: str, data: np.ndarray, seq_lengths: np.ndarray,
    idx2char: dict, blank_idx: int
) -> tuple[str, float]:
    """Run one recognizer model and decode its output.

    A new CPU inference session is created for ``model_path`` on every call.
    The model is fed the inputs ``data`` and ``seq_lengths`` and its first
    output is decoded with greedy CTC.

    Args:
        model_path: Path to the recognizer ONNX file.
        data: Image tensor passed as the ``data`` input.
        seq_lengths: Sequence lengths passed as the ``seq_lengths`` input.
        idx2char: Class index -> character table used for decoding.
        blank_idx: Index of the CTC blank class.

    Returns:
        ``(text, avg_confidence)``. The confidence is the mean, over the
        timesteps whose best class is not the blank, of ``exp`` of the best
        score (i.e. the output is treated as log-probabilities). It is 0.0
        when every timestep is blank.
    """
    sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    logprobs = sess.run(None, {"data": data, "seq_lengths": seq_lengths})[0]
    text = ctc_greedy_decode(logprobs, idx2char, blank_idx)

    # Reduce to [T, C] with the same layout rule the decoder applies, so the
    # confidence is computed from exactly the frames that produced the text.
    frames = logprobs
    if frames.ndim == 3:
        frames = frames[:, 0, :] if frames.shape[1] == 1 else frames[0]

    best_idx = np.argmax(frames, axis=-1)
    # exp is monotonic, so exponentiating only the per-frame maximum gives
    # the same value as max(exp(frame)) without touching every class.
    best_prob = np.exp(frames.max(axis=-1))
    non_blank = best_idx != blank_idx
    conf = best_prob[non_blank].mean() if non_blank.any() else 0.0
    return text, float(conf)
def find_model_file(model_idx: int) -> str | None:
    """Locate the ONNX model file for a model index.

    Looks in ``MODELS_DIR`` for entries named ``model_<NN>_*`` where ``<NN>``
    is the zero-padded two-digit index.

    Args:
        model_idx: Index of the model to look up.

    Returns:
        The path of the matching file as a string, or ``None`` when nothing
        matches. If several entries match, the smallest path in sort order is
        returned so the choice does not depend on directory listing order.
    """
    first = min(MODELS_DIR.glob(f"model_{model_idx:02d}_*"), default=None)
    return str(first) if first is not None else None
def _run_detector_check(img: "Image.Image") -> None:
    """Run the detector (model_00) once on ``img`` and print a score summary."""
    print("[1/3] DETECTOR (model_00)")
    det_path = find_model_file(0)
    if not det_path:
        # Report the skip explicitly, like the recognizer loop does.
        print(" SKIP - model file missing\n")
        return

    w, h = img.size
    try:
        sess = ort.InferenceSession(det_path, providers=["CPUExecutionProvider"])
        # Scale the longer side to 800 px, then round both sides up to a
        # multiple of 32.
        scale = 800 / max(h, w)
        dh = (int(h * scale) + 31) // 32 * 32
        dw = (int(w * scale) + 31) // 32 * 32
        img_d = img.resize((dw, dh), Image.LANCZOS)
        arr_d = np.array(img_d, dtype=np.float32)
        # Reverse the channel order (RGB -> BGR) and subtract per-channel means.
        arr_d = arr_d[:, :, ::-1] - [102.9801, 115.9465, 122.7717]
        data_d = arr_d.transpose(2, 0, 1)[np.newaxis].astype(np.float32)
        im_info = np.array([[dh, dw, scale]], dtype=np.float32)
        outputs = sess.run(None, {"data": data_d, "im_info": im_info})
        # First output is passed through a sigmoid before thresholding.
        scores = 1.0 / (1.0 + np.exp(-outputs[0]))
        max_score = scores.max()
        hot = (scores > 0.5).sum()
        print(f" FPN2 scores: shape={scores.shape}, max={max_score:.3f}, hot_pixels={hot}")
        print(" OK - detector runs on onnxruntime\n")
    except Exception as e:  # smoke test: report the failure and keep going
        print(f" ERROR: {e}\n")


def _run_recognizers(img: "Image.Image") -> list[tuple[int, str, str, float]]:
    """Run every recognizer (model_02..10) on the whole image.

    Prints one line per model and returns ``(model_idx, script, text, conf)``
    for each model that ran successfully.
    """
    print("[2/3] RECOGNIZERS (model_02..10) on full image")
    data, seq_lengths = preprocess_for_recognizer(img)
    print(f" Input: {data.shape}, seq_lengths={seq_lengths}\n")

    results = []
    for model_idx in range(2, 11):
        info = MODEL_REGISTRY.get(model_idx)
        if not info:
            continue
        _, script, char_file = info
        model_path = find_model_file(model_idx)
        char_path = CONFIG_DIR / char_file if char_file else None
        if not model_path or not char_path or not char_path.exists():
            print(f" model_{model_idx:02d} ({script:12s}): SKIP - files missing")
            continue
        try:
            idx2char, blank_idx = load_char_map(str(char_path))
            text, conf = run_recognizer(model_path, data, seq_lengths, idx2char, blank_idx)
            mark = "OK" if conf > 0.8 else "LOW" if conf > 0.5 else "--"
            print(f' model_{model_idx:02d} ({script:12s}): [{mark}] conf={conf:.3f} "{text}"')
            results.append((model_idx, script, text, conf))
        except Exception as e:  # one failing model must not stop the others
            print(f" model_{model_idx:02d} ({script:12s}): ERR {e}")
    return results


def _print_summary() -> None:
    """Print the static overview of the extracted models."""
    rule = "=" * 70
    # model_00, model_01 and model_02..10 are 11 models; with the 23
    # custom-op models that adds up to the 34 extracted models.
    print(f"""
{rule}
ONEOCR MODEL SUMMARY
{rule}
Extracted: 34 ONNX models from oneocr.onemodel
Cross-platform (onnxruntime):
model_00 Detector (PixelLink FPN, 11MB)
model_01 Script ID predictor (3.3MB)
model_02..10 Character recognizers (1.7-13MB each)
= 11 models, core OCR pipeline works on Linux/Mac/Windows
Needs custom ops (com.microsoft.oneocr):
model_11..32 Language models (26-28KB each)
model_33 Line layout predictor (857KB)
= 23 models use DynamicQuantizeLSTM custom op
Preprocessing: RGB -> resize H=60 -> /255 -> NCHW float32
Postprocessing: CTC greedy decode with char2ind mapping
Files: oneocr_extracted/onnx_models/ (34 .onnx)
oneocr_extracted/config_data/ (33 configs)
""")


def main():
    """Smoke-test the extracted ONNX models on one image.

    The image path is taken from the first command-line argument and defaults
    to ``image.png``. The detector is run once, every recognizer is run on
    the whole image, the highest-confidence recognizer result is printed, and
    a static model summary follows.
    """
    image_path = sys.argv[1] if len(sys.argv) > 1 else "image.png"
    rule = "=" * 70
    print(rule)
    print(" ONEOCR Cross-Platform ONNX Inference Test")
    print(f" Image: {image_path}")
    print(f"{rule}\n")

    img = Image.open(image_path).convert("RGB")
    w, h = img.size
    print(f" Image size: {w}x{h}\n")

    # ── Test 1: Detector ─────────────────────────────────────────────────
    _run_detector_check(img)

    # ── Test 2: All Recognizers ──────────────────────────────────────────
    results = _run_recognizers(img)

    # ── Best result ──────────────────────────────────────────────────────
    print("\n[3/3] RESULT")
    if results:
        best = max(results, key=lambda x: x[3])
        print(f" Best: {best[1]} (model_{best[0]:02d}), conf={best[3]:.1%}")
        print(f' Text: "{best[2]}"')

    # ── Summary ──────────────────────────────────────────────────────────
    _print_summary()
# Run the inference test only when executed as a script, not on import.
if __name__ == "__main__":
    main()