# oneocr/ocr/engine_onnx.py — OneOCR Dev
# OneOCR — reverse engineering complete; ONNX pipeline 53% match rate
# commit ce847d4
"""
Cross-platform OneOCR engine — pure Python/ONNX implementation.
Reimplements the full OneOCR DLL pipeline using extracted ONNX models:
1. Detector (model_00): PixelLink FPN text detection → bounding boxes
2. ScriptID (model_01): Script/handwriting/flip classification
3. Recognizers (02-10): Per-script CTC character recognition
4. Line grouping: Heuristic grouping of words into lines
Preprocessing:
- Detector: BGR, mean-subtracted [102.98, 115.95, 122.77], NCHW
- Recognizers: RGB / 255.0, height=60px, NCHW
- ScriptID: Same as recognizers
Output format matches the DLL: OcrResult with BoundingRect, OcrLine, OcrWord.
Usage:
from ocr.engine_onnx import OcrEngineOnnx
engine = OcrEngineOnnx()
result = engine.recognize_pil(pil_image)
print(result.text)
"""
from __future__ import annotations
import math
import os
from pathlib import Path
from typing import TYPE_CHECKING
import cv2
import numpy as np
import onnxruntime as ort
from ocr.models import BoundingRect, OcrLine, OcrResult, OcrWord
if TYPE_CHECKING:
from PIL import Image
# ─── Constants ───────────────────────────────────────────────────────────────
# BGR mean values subtracted before detection (ImageNet-style, matching the DLL)
_DET_MEAN = np.array([102.9801, 115.9465, 122.7717], dtype=np.float32)
# Detector thresholds (from DLL protobuf config — segment_conf_threshold)
_PIXEL_SCORE_THRESH = 0.7  # text/non-text pixel threshold (DLL: field 8 = 0.7)
_LINK_SCORE_THRESH = 0.5  # neighbor link threshold
_MIN_AREA = 50  # minimum text region area (detector-image pixels)
_MIN_HEIGHT = 5  # minimum text region height (pixels)
_MIN_WIDTH = 4  # minimum text region width (pixels)
# Detector scaling (Faster R-CNN / FPN style — short side target, long side cap)
_DET_TARGET_SHORT = 800  # scale shortest side to this
_DET_MAX_LONG = 1333  # cap longest side (Faster R-CNN standard)
# NMS thresholds (from DLL protobuf config)
_NMS_IOU_THRESH = 0.2  # textline_nms_threshold (DLL: field 10 = 0.2)
# Recognizer settings
_REC_TARGET_H = 60  # target height for recognizer input
_REC_MIN_WIDTH = 32  # minimum width after resize
# Line grouping
_LINE_IOU_Y_THRESH = 0.5  # Y-overlap threshold for same-line grouping
_LINE_MERGE_GAP = 2.0  # max gap between same-line words (ratio of avg char height); appears unused in this file
# Script names for ScriptID output (10 classes); list index = ScriptID class index
SCRIPT_NAMES = [
    "Latin", "CJK", "Arabic", "Cyrillic", "Devanagari",
    "Greek", "Hebrew", "Thai", "Tamil", "Unknown"
]
# Model index → (role, script, char2ind_file) — entries are 3-tuples
# (char2ind_file is None for the detector and ScriptID models)
MODEL_REGISTRY: dict[int, tuple[str, str, str | None]] = {
    0: ("detector", "universal", None),
    1: ("script_id", "universal", None),
    2: ("recognizer", "Latin", "chunk_37_char2ind.char2ind.txt"),
    3: ("recognizer", "CJK", "chunk_40_char2ind.char2ind.txt"),
    4: ("recognizer", "Arabic", "chunk_43_char2ind.char2ind.txt"),
    5: ("recognizer", "Cyrillic", "chunk_47_char2ind.char2ind.txt"),
    6: ("recognizer", "Devanagari", "chunk_50_char2ind.char2ind.txt"),
    7: ("recognizer", "Greek", "chunk_53_char2ind.char2ind.txt"),
    8: ("recognizer", "Hebrew", "chunk_57_char2ind.char2ind.txt"),
    9: ("recognizer", "Tamil", "chunk_61_char2ind.char2ind.txt"),
    10: ("recognizer", "Thai", "chunk_64_char2ind.char2ind.txt"),
}
# Script name → recognizer model index (inverse of the registry's recognizers)
SCRIPT_TO_MODEL: dict[str, int] = {
    "Latin": 2, "CJK": 3, "Arabic": 4, "Cyrillic": 5,
    "Devanagari": 6, "Greek": 7, "Hebrew": 8, "Tamil": 9, "Thai": 10,
}
# ─── Helper: Character map ──────────────────────────────────────────────────
def load_char_map(path: str | Path) -> tuple[dict[int, str], int]:
    """Parse a char2ind.txt file into an index→character table.

    Each non-empty line has the form ``<char> <index>``; the split happens at
    the LAST space so the index is always the trailing field. Two special
    tokens exist: ``<space>`` maps to a literal space character, and
    ``<blank>`` records the CTC blank index instead of entering the table.

    Returns:
        (idx2char mapping, CTC blank index — 0 if no <blank> line is found).
    """
    table: dict[int, str] = {}
    blank = 0
    with open(path, "r", encoding="utf-8") as handle:
        for raw in handle:
            entry = raw.rstrip("\n")
            if not entry:
                continue
            cut = entry.rfind(" ")
            if cut <= 0:
                # No separator (or empty char field) — skip malformed line.
                continue
            token = entry[:cut]
            index = int(entry[cut + 1:])
            if token == "<blank>":
                blank = index
            elif token == "<space>":
                table[index] = " "
            else:
                table[index] = token
    return table, blank
# ─── PixelLink Post-Processing ──────────────────────────────────────────────
# 8-connected neighbor offsets as (dy, dx), clockwise from north:
# N, NE, E, SE, S, SW, W, NW.
# The list order must match the 8 link-score channels fed to _pixellink_decode.
_NEIGHBOR_OFFSETS = [(-1, 0), (-1, 1), (0, 1), (1, 1),
                     (1, 0), (1, -1), (0, -1), (-1, -1)]
class _UnionFind:
    """Disjoint-set forest with path halving and union by rank."""

    __slots__ = ("parent", "rank")

    def __init__(self, n: int):
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x: int) -> int:
        """Return the representative of x's set (with path halving)."""
        parent = self.parent
        while parent[x] != x:
            # Point every visited node at its grandparent as we walk up.
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    def union(self, a: int, b: int) -> None:
        """Merge the sets containing a and b (no-op if already joined)."""
        root_a = self.find(a)
        root_b = self.find(b)
        if root_a == root_b:
            return
        # Attach the shallower tree under the deeper one.
        if self.rank[root_a] < self.rank[root_b]:
            root_a, root_b = root_b, root_a
        self.parent[root_b] = root_a
        if self.rank[root_a] == self.rank[root_b]:
            self.rank[root_a] += 1
def _pixellink_decode(
    pixel_scores: np.ndarray,
    link_scores: np.ndarray,
    bbox_deltas: np.ndarray,
    stride: int,
    pixel_thresh: float = _PIXEL_SCORE_THRESH,
    link_thresh: float = _LINK_SCORE_THRESH,
    min_area: float = _MIN_AREA,
    min_component_pixels: int = 3,
) -> list[np.ndarray]:
    """Decode PixelLink outputs into axis-aligned bounding quads.

    Pipeline:
      1. Threshold ``pixel_scores`` to get candidate text pixels.
      2. Union-Find over the 8-neighborhood, joining pixels whose link score
         (channel order matching ``_NEIGHBOR_OFFSETS``) passes ``link_thresh``.
      3. For each connected component, regress the box from ``bbox_deltas``:
         pixel (r, c) predicts each corner as ``(coord + delta) * stride``,
         and the component box is the extreme over all pixels' predictions.

    Fix vs. previous revision: the x-minimum now also folds in the BL.x
    predictions (delta channel 6), symmetric with the TR.y / TR.x / BL.y
    folds already applied to the other three extremes. Channel 6 was
    previously never read, which could clip left edges of boxes whose
    bottom-left corner extends past the top-left one.

    Args:
        pixel_scores: [H, W] text/non-text scores (already sigmoid'd).
        link_scores: [8, H, W] neighbor link scores (already sigmoid'd).
        bbox_deltas: [8, H, W] corner offsets, as fractions of ``stride``,
            ordered [TL.x, TL.y, TR.x, TR.y, BR.x, BR.y, BL.x, BL.y].
        stride: FPN stride of this feature map (4, 8, or 16).
        pixel_thresh: text/non-text pixel threshold.
        link_thresh: neighbor link threshold.
        min_area: minimum box area in detector-image pixels.
        min_component_pixels: minimum pixel count per connected component.

    Returns:
        List of (4, 2) float32 arrays — quad corners (TL, TR, BR, BL) in
        detector-image coordinates.
    """
    h, w = pixel_scores.shape
    # Step 1: Threshold pixels.
    text_pixels = np.argwhere(pixel_scores > pixel_thresh)  # (N, 2) as (row, col)
    if len(text_pixels) == 0:
        return []
    # Step 2: (row, col) -> index into text_pixels; -1 marks non-text pixels.
    pixel_map = np.full((h, w), -1, dtype=np.int32)
    pixel_map[text_pixels[:, 0], text_pixels[:, 1]] = np.arange(
        len(text_pixels), dtype=np.int32
    )
    # Step 3: Union-Find over linked neighbors.
    uf = _UnionFind(len(text_pixels))
    for i, (r, c) in enumerate(text_pixels):
        for ni, (dy, dx) in enumerate(_NEIGHBOR_OFFSETS):
            nr, nc = r + dy, c + dx
            if 0 <= nr < h and 0 <= nc < w:
                j = pixel_map[nr, nc]
                if j >= 0 and link_scores[ni, r, c] > link_thresh:
                    uf.union(i, int(j))
    # Step 4: Group pixel indices by component root.
    components: dict[int, list[int]] = {}
    for i in range(len(text_pixels)):
        components.setdefault(uf.find(i), []).append(i)
    # Step 5: Per component, regress the box from the per-pixel corner deltas.
    quads: list[np.ndarray] = []
    for indices in components.values():
        if len(indices) < min_component_pixels:
            continue
        ri = text_pixels[indices, 0]
        ci = text_pixels[indices, 1]
        # Per-pixel corner predictions: corner = (pixel_coord + delta) * stride.
        tl_x = (ci + bbox_deltas[0, ri, ci]) * stride
        tl_y = (ri + bbox_deltas[1, ri, ci]) * stride
        tr_x = (ci + bbox_deltas[2, ri, ci]) * stride
        tr_y = (ri + bbox_deltas[3, ri, ci]) * stride
        br_x = (ci + bbox_deltas[4, ri, ci]) * stride
        br_y = (ri + bbox_deltas[5, ri, ci]) * stride
        bl_x = (ci + bbox_deltas[6, ri, ci]) * stride  # was ignored (bug)
        bl_y = (ri + bbox_deltas[7, ri, ci]) * stride
        # Extremes over all predictions; clamp the top-left to non-negative.
        x1 = max(0.0, float(min(tl_x.min(), bl_x.min())))
        y1 = max(0.0, float(min(tl_y.min(), tr_y.min())))
        x2 = float(max(br_x.max(), tr_x.max()))
        y2 = float(max(br_y.max(), bl_y.max()))
        box_w = x2 - x1
        box_h = y2 - y1
        # Reject tiny / degenerate components.
        if box_w * box_h < min_area:
            continue
        if box_h < _MIN_HEIGHT or box_w < _MIN_WIDTH:
            continue
        # Axis-aligned quad, ordered TL, TR, BR, BL.
        quads.append(np.array(
            [[x1, y1], [x2, y1], [x2, y2], [x1, y2]], dtype=np.float32
        ))
    return quads
def _order_corners(pts: np.ndarray) -> np.ndarray:
    """Arrange 4 corner points as (top-left, top-right, bottom-right, bottom-left).

    Classic sum/diff heuristic: TL minimizes x+y, BR maximizes x+y,
    TR minimizes y-x, BL maximizes y-x.
    """
    sums = pts.sum(axis=1)
    diffs = np.diff(pts, axis=1).ravel()  # y - x for each point
    out = np.zeros((4, 2), dtype=np.float32)
    out[0] = pts[np.argmin(sums)]   # top-left
    out[1] = pts[np.argmin(diffs)]  # top-right
    out[2] = pts[np.argmax(sums)]   # bottom-right
    out[3] = pts[np.argmax(diffs)]  # bottom-left
    return out
def _nms_quads(quads: list[np.ndarray], iou_thresh: float = 0.3) -> list[np.ndarray]:
    """Greedy NMS over quads: keep largest-area boxes, suppress overlaps.

    Quads are visited in descending contour-area order; each kept quad
    suppresses every remaining quad whose IoU with it exceeds ``iou_thresh``.
    """
    if len(quads) <= 1:
        return quads
    by_area_desc = np.argsort([cv2.contourArea(q) for q in quads])[::-1]
    done: set = set()
    kept: list[np.ndarray] = []
    for winner in by_area_desc:
        if winner in done:
            continue
        kept.append(quads[winner])
        done.add(winner)
        # Suppress every smaller quad that overlaps the kept one.
        for other in by_area_desc:
            if other in done:
                continue
            if _quad_iou(quads[winner], quads[other]) > iou_thresh:
                done.add(other)
    return kept
def _quad_iou(q1: np.ndarray, q2: np.ndarray) -> float:
    """Intersection-over-union of two convex quadrilaterals (0.0 on failure)."""
    try:
        ret, region = cv2.intersectConvexConvex(
            q1.astype(np.float32), q2.astype(np.float32)
        )
        if ret <= 0:
            return 0.0
        inter_area = cv2.contourArea(region)
        union_area = cv2.contourArea(q1) + cv2.contourArea(q2) - inter_area
        if union_area <= 0:
            return 0.0
        return inter_area / union_area
    except Exception:
        # Best-effort: degenerate geometry counts as non-overlapping.
        return 0.0
# ─── CTC Decoding ───────────────────────────────────────────────────────────
def ctc_greedy_decode(
    logprobs: np.ndarray,
    idx2char: dict[int, str],
    blank_idx: int,
) -> tuple[str, float, list[float]]:
    """Greedy CTC decoding: per-step argmax, collapse repeats, drop blanks.

    Args:
        logprobs: [T, C] log-probabilities (a [T, 1, C] batch dim is squeezed).
        idx2char: class index → character; unknown indices render as "[idx]".
        blank_idx: CTC blank class index.

    Returns:
        (decoded_text, mean per-character confidence, per-character confidences).
        Confidence is 0.0 when nothing decodes.
    """
    if logprobs.ndim == 3:
        logprobs = logprobs[:, 0, :]
    best = np.argmax(logprobs, axis=-1)
    best_probs = np.exp(logprobs)[np.arange(len(best)), best]
    pieces: list[str] = []
    piece_probs: list[float] = []
    last = -1
    for step, cls in enumerate(best):
        # Emit only on class change, and never for the blank class.
        if cls != last and cls != blank_idx:
            pieces.append(idx2char.get(int(cls), f"[{cls}]"))
            piece_probs.append(float(best_probs[step]))
        last = cls
    mean_conf = float(np.mean(piece_probs)) if piece_probs else 0.0
    return "".join(pieces), mean_conf, piece_probs
# ─── Line Grouping ──────────────────────────────────────────────────────────
def _group_words_into_lines(
    words: list[tuple[str, np.ndarray, float]],
) -> list[list[tuple[str, np.ndarray, float]]]:
    """Cluster words into text lines by vertical overlap.

    Words are scanned top-to-bottom (by quad center Y); each unassigned word
    seeds a line and absorbs every later word whose Y-overlap ratio
    (overlap / smaller height) exceeds ``_LINE_IOU_Y_THRESH``, growing the
    line's Y-range as it goes. Each finished line is sorted left-to-right.

    Args:
        words: list of (text, quad[4, 2], confidence) tuples.

    Returns:
        Lines as lists of (text, quad, confidence), reading order within each.
    """
    if not words:
        return []
    # Stable sort by quad center Y.
    ordered = sorted(
        ((entry[1][:, 1].mean(), entry) for entry in words),
        key=lambda item: item[0],
    )
    assigned: set = set()
    lines: list[list[tuple[str, np.ndarray, float]]] = []
    for seed_pos, (_, seed) in enumerate(ordered):
        if seed_pos in assigned:
            continue
        seed_quad = seed[1]
        lo = seed_quad[:, 1].min()
        hi = seed_quad[:, 1].max()
        current = [seed]
        assigned.add(seed_pos)
        for cand_pos in range(seed_pos + 1, len(ordered)):
            if cand_pos in assigned:
                continue
            cand = ordered[cand_pos][1]
            cand_quad = cand[1]
            c_lo = cand_quad[:, 1].min()
            c_hi = cand_quad[:, 1].max()
            # Vertical overlap relative to the shorter of the two spans.
            overlap = min(hi, c_hi) - max(lo, c_lo)
            shorter = min(hi - lo, c_hi - c_lo)
            if shorter > 0 and overlap / shorter > _LINE_IOU_Y_THRESH:
                current.append(cand)
                assigned.add(cand_pos)
                # Grow the line's Y-range to cover the new member.
                lo = min(lo, c_lo)
                hi = max(hi, c_hi)
        # Reading order: left to right by each word's leftmost X.
        current.sort(key=lambda entry: entry[1][:, 0].min())
        lines.append(current)
    return lines
# ─── Image Angle Detection ──────────────────────────────────────────────────
def _estimate_image_angle(quads: list[np.ndarray]) -> float:
    """Median inclination (degrees) of the detected boxes' top edges.

    Edges with a horizontal run shorter than 1px are ignored; returns 0.0
    when no usable edge exists (including an empty quad list).
    """
    slopes: list[float] = []
    for quad in quads:
        run = quad[1][0] - quad[0][0]   # top edge: corner 0 → corner 1
        rise = quad[1][1] - quad[0][1]
        if abs(run) >= 1:
            slopes.append(math.degrees(math.atan2(rise, run)))
    return float(np.median(slopes)) if slopes else 0.0
# ═══════════════════════════════════════════════════════════════════════════
# Main Engine
# ═══════════════════════════════════════════════════════════════════════════
class OcrEngineOnnx:
"""Cross-platform OCR engine using extracted ONNX models.
Provides the same API as OcrEngine (DLL version) but runs on any OS.
Args:
models_dir: Path to extracted ONNX models directory.
config_dir: Path to extracted config data directory.
providers: ONNX Runtime providers (default: CPU only).
"""
def __init__(
self,
models_dir: str | Path | None = None,
config_dir: str | Path | None = None,
providers: list[str] | None = None,
) -> None:
base = Path(__file__).resolve().parent.parent
self._models_dir = Path(models_dir) if models_dir else base / "oneocr_extracted" / "onnx_models"
self._config_dir = Path(config_dir) if config_dir else base / "oneocr_extracted" / "config_data"
self._unlocked_dir = self._models_dir.parent / "onnx_models_unlocked"
self._providers = providers or ["CPUExecutionProvider"]
# Optimized session options (matching DLL's ORT configuration)
self._sess_opts = ort.SessionOptions()
self._sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
self._sess_opts.intra_op_num_threads = max(1, (os.cpu_count() or 4) // 2)
self._sess_opts.inter_op_num_threads = 2
self._sess_opts.enable_mem_pattern = True
self._sess_opts.enable_cpu_mem_arena = True
# Lazy-loaded sessions
self._detector: ort.InferenceSession | None = None
self._script_id: ort.InferenceSession | None = None
self._recognizers: dict[int, ort.InferenceSession] = {}
self._char_maps: dict[int, tuple[dict[int, str], int]] = {}
self._line_layout: ort.InferenceSession | None = None
# Validate paths exist
if not self._models_dir.exists():
raise FileNotFoundError(f"Models directory not found: {self._models_dir}")
if not self._config_dir.exists():
raise FileNotFoundError(f"Config directory not found: {self._config_dir}")
# ─── Public API ───────────────────────────────────────────────────────
def recognize_pil(self, image: "Image.Image") -> OcrResult:
"""Run OCR on a PIL Image.
Args:
image: PIL Image (any mode — will be converted to RGB).
Returns:
OcrResult with recognized text, lines, words, bounding boxes,
confidence values, and detected text angle.
"""
if any(x < 10 or x > 10000 for x in image.size):
return OcrResult(error="Unsupported image size (must be 10-10000px)")
img_rgb = image.convert("RGB")
img_arr = np.array(img_rgb)
try:
return self._run_pipeline(img_arr)
except Exception as e:
return OcrResult(error=f"Pipeline error: {e}")
def recognize_bytes(self, image_bytes: bytes) -> OcrResult:
"""Run OCR on raw image bytes (PNG/JPEG/etc)."""
from io import BytesIO
from PIL import Image as PILImage
img = PILImage.open(BytesIO(image_bytes))
return self.recognize_pil(img)
def recognize_numpy(self, img_rgb: np.ndarray) -> OcrResult:
"""Run OCR on a numpy array (H, W, 3) in RGB format."""
if img_rgb.ndim != 3 or img_rgb.shape[2] != 3:
return OcrResult(error="Expected (H, W, 3) RGB array")
try:
return self._run_pipeline(img_rgb)
except Exception as e:
return OcrResult(error=f"Pipeline error: {e}")
    # ─── Pipeline ─────────────────────────────────────────────────────────
    def _run_pipeline(self, img_rgb: np.ndarray) -> OcrResult:
        """Full OCR pipeline: detect → crop → scriptID → recognize → group.

        Args:
            img_rgb: (H, W, 3) RGB array in original-image coordinates.

        Returns:
            OcrResult with lines sorted top-to-bottom; ``text`` joins the
            per-line texts with newlines. Empty-but-valid results are
            returned when nothing is detected or everything is filtered.
        """
        h, w = img_rgb.shape[:2]  # NOTE(review): h/w are currently unused below
        # Step 1: Detect text regions (quads come back in original coordinates)
        quads, scale = self._detect(img_rgb)
        if not quads:
            return OcrResult(text="", text_angle=0.0, lines=[])
        # Step 2: Estimate overall image/text angle from the detected quads
        text_angle = _estimate_image_angle(quads)
        # Step 3: For each detected region (line), crop and recognize
        line_results: list[tuple[str, np.ndarray, float, list[float]]] = []
        for quad in quads:
            # Crop text region from original image; skip degenerate crops
            crop = self._crop_quad(img_rgb, quad)
            if crop is None or crop.shape[0] < 5 or crop.shape[1] < 5:
                continue
            # Detect vertical text: if crop is much taller than wide, rotate 90° CCW
            ch, cw = crop.shape[:2]
            is_vertical = ch > cw * 2
            rec_crop = crop
            if is_vertical:
                rec_crop = cv2.rotate(crop, cv2.ROTATE_90_COUNTERCLOCKWISE)
            # ScriptID (on properly oriented crop)
            script_idx = self._identify_script(rec_crop)
            script_name = SCRIPT_NAMES[script_idx] if script_idx < len(SCRIPT_NAMES) else "Latin"
            # Map script to its recognizer model (default Latin)
            model_idx = SCRIPT_TO_MODEL.get(script_name, 2)
            # Recognize the full line
            text, conf, char_confs = self._recognize(rec_crop, model_idx)
            # Vertical-text fallback: if confidence is low, also try the CJK
            # recognizer and keep whichever result scores higher
            if is_vertical and conf < 0.7 and model_idx != 3:
                text_cjk, conf_cjk, char_confs_cjk = self._recognize(rec_crop, 3)
                if conf_cjk > conf:
                    text, conf, char_confs = text_cjk, conf_cjk, char_confs_cjk
            if text.strip():
                # Graded confidence filter.
                # Short noise detections from FPN2 (single chars like "1", "C")
                # typically have lower confidence than genuine text like "I", "A".
                text_stripped = text.strip()
                n_chars = len(text_stripped)
                if n_chars <= 1 and conf < 0.35:
                    continue  # single char needs some confidence
                elif conf < 0.3:
                    continue  # very low confidence = noise
                line_results.append((text, quad, conf, char_confs))
        if not line_results:
            return OcrResult(text="", text_angle=text_angle, lines=[])
        # Step 4: Build OcrResult — split recognized text into words
        lines: list[OcrLine] = []
        for line_text, quad, conf, char_confs in line_results:
            # Split text on whitespace to get the word list
            word_texts = line_text.split()
            if not word_texts:
                continue
            # Estimate per-word bounding boxes and confidence by distributing
            # quad width and char confidences proportionally to words
            words = self._split_into_words(word_texts, quad, conf, char_confs)
            line_bbox = BoundingRect(
                x1=float(quad[0][0]), y1=float(quad[0][1]),
                x2=float(quad[1][0]), y2=float(quad[1][1]),
                x3=float(quad[2][0]), y3=float(quad[2][1]),
                x4=float(quad[3][0]), y4=float(quad[3][1]),
            )
            full_line_text = " ".join(w.text for w in words)
            lines.append(OcrLine(text=full_line_text, bounding_rect=line_bbox, words=words))
        # Sort lines top-to-bottom by center Y coordinate (y1/y3 are TL/BR)
        lines.sort(key=lambda l: (
            (l.bounding_rect.y1 + l.bounding_rect.y3) / 2
            if l.bounding_rect else 0
        ))
        full_text = "\n".join(line.text for line in lines)
        return OcrResult(text=full_text, text_angle=text_angle, lines=lines)
    # ─── Detection ────────────────────────────────────────────────────────
    def _detect(self, img_rgb: np.ndarray) -> tuple[list[np.ndarray], float]:
        """Run PixelLink detector and decode bounding boxes.

        Scaling: FPN-style — scale shortest side to 800, cap longest at 1333.
        Two-phase detection:
        - Phase 1: FPN3 (stride=8) + FPN4 (stride=16) — primary detections
        - Phase 2: FPN2 (stride=4) — supplementary for small text ("I", "...");
          only novel FPN2 detections that don't overlap primary ones are kept.
        NMS: IoU threshold 0.2 (from DLL protobuf config).

        Returns:
            (list_of_quads, scale_factor) — quads already mapped back to
            original-image coordinates; scale_factor maps original→detector.
        """
        h, w = img_rgb.shape[:2]
        # FPN-style scaling: target short side = 800, cap long side at 1333
        short_side = min(h, w)
        long_side = max(h, w)
        scale = _DET_TARGET_SHORT / short_side
        if long_side * scale > _DET_MAX_LONG:
            scale = _DET_MAX_LONG / long_side
        # Safety cap: never upscale tiny inputs by more than 6x
        scale = min(scale, 6.0)
        # Round detector input dims up to a multiple of 32
        dh = (int(h * scale) + 31) // 32 * 32
        dw = (int(w * scale) + 31) // 32 * 32
        img_resized = cv2.resize(img_rgb, (dw, dh), interpolation=cv2.INTER_LINEAR)
        # BGR + mean subtraction (ImageNet-style, matching DLL)
        img_bgr = img_resized[:, :, ::-1].astype(np.float32) - _DET_MEAN
        data = img_bgr.transpose(2, 0, 1)[np.newaxis]
        im_info = np.array([[dh, dw, scale]], dtype=np.float32)
        # Run detector and key outputs by name
        sess = self._get_detector()
        outputs = sess.run(None, {"data": data, "im_info": im_info})
        output_names = [o.name for o in sess.get_outputs()]
        out_dict = dict(zip(output_names, outputs))
        # ── Phase 1: Primary detections from FPN3 + FPN4 ──
        primary_quads: list[np.ndarray] = []
        for level, stride in [("fpn3", 8), ("fpn4", 16)]:
            min_area_scaled = _MIN_AREA * (scale ** 2)
            for orientation in ("hori", "vert"):
                scores_key = f"scores_{orientation}_{level}"
                links_key = f"link_scores_{orientation}_{level}"
                deltas_key = f"bbox_deltas_{orientation}_{level}"
                if scores_key not in out_dict:
                    continue
                pixel_scores = out_dict[scores_key][0, 0]
                # Cheap skip: decode the vertical map only when some pixel
                # actually clears the threshold
                if orientation == "vert" and pixel_scores.max() <= _PIXEL_SCORE_THRESH:
                    continue
                link_scores = out_dict[links_key][0]
                bbox_deltas = out_dict[deltas_key][0]
                quads = _pixellink_decode(
                    pixel_scores, link_scores, bbox_deltas, stride,
                    pixel_thresh=_PIXEL_SCORE_THRESH,
                    min_area=min_area_scaled,
                )
                primary_quads.extend(quads)
        primary_quads = _nms_quads(primary_quads, iou_thresh=_NMS_IOU_THRESH)
        # ── Phase 2: FPN2 supplementary detections ──
        # Higher threshold (0.85) to reduce false positives from panel borders.
        # Only keep novel detections that don't overlap with primary.
        fpn2_quads: list[np.ndarray] = []
        min_area_scaled = _MIN_AREA * (scale ** 2)
        for orientation in ("hori", "vert"):
            scores_key = f"scores_{orientation}_fpn2"
            links_key = f"link_scores_{orientation}_fpn2"
            deltas_key = f"bbox_deltas_{orientation}_fpn2"
            if scores_key not in out_dict:
                continue
            pixel_scores = out_dict[scores_key][0, 0]
            if pixel_scores.max() <= 0.85:
                continue
            link_scores = out_dict[links_key][0]
            bbox_deltas = out_dict[deltas_key][0]
            quads = _pixellink_decode(
                pixel_scores, link_scores, bbox_deltas, 4,
                pixel_thresh=0.85,
                min_area=min_area_scaled,
                min_component_pixels=5,
            )
            for q in quads:
                # Only keep if not overlapping with primary detections
                overlaps = any(_quad_iou(q, p) > 0.1 for p in primary_quads)
                if not overlaps:
                    fpn2_quads.append(q)
        # Combine and run a final NMS pass only when FPN2 added anything
        all_quads = primary_quads + fpn2_quads
        if fpn2_quads:
            all_quads = _nms_quads(all_quads, iou_thresh=_NMS_IOU_THRESH)
        # Scale quads back to original image coordinates
        for i in range(len(all_quads)):
            all_quads[i] = all_quads[i] / scale
        return all_quads, scale
# ─── Script Identification ────────────────────────────────────────────
def _identify_script(self, crop_rgb: np.ndarray) -> int:
"""Identify script of a cropped text region.
Returns script index (0=Latin, 1=CJK, ..., 9=Unknown).
"""
sess = self._get_script_id()
# Preprocess: RGB -> height=60 -> /255 -> NCHW
data = self._preprocess_recognizer(crop_rgb)
outputs = sess.run(None, {"data": data})
# Output: script_id_score [1, 1, 10]
script_scores = outputs[3] # script_id_score
script_idx = int(np.argmax(script_scores.flatten()[:10]))
return script_idx
# ─── Recognition ──────────────────────────────────────────────────────
def _recognize(
self, crop_rgb: np.ndarray, model_idx: int
) -> tuple[str, float, list[float]]:
"""Recognize text in a cropped region using the specified model.
Returns (text, confidence, per_char_confidences).
"""
sess = self._get_recognizer(model_idx)
idx2char, blank_idx = self._get_char_map(model_idx)
data = self._preprocess_recognizer(crop_rgb)
h, w = data.shape[2], data.shape[3]
seq_lengths = np.array([w // 4], dtype=np.int32)
logprobs = sess.run(None, {"data": data, "seq_lengths": seq_lengths})[0]
text, conf, char_confs = ctc_greedy_decode(logprobs, idx2char, blank_idx)
return text, conf, char_confs
    # ─── Word splitting ─────────────────────────────────────────────────
    @staticmethod
    def _split_into_words(
        word_texts: list[str],
        quad: np.ndarray,
        confidence: float,
        char_confs: list[float] | None = None,
    ) -> list[OcrWord]:
        """Split a line into word-level OcrWord objects with estimated bboxes.

        The line quad is divided proportionally by character count (one space
        counted between adjacent words), interpolating along the top (q0→q1)
        and bottom (q3→q2) edges. Per-word confidence is the mean of that
        word's character-level CTC confidences when enough are available;
        otherwise the line confidence is reused for every word.

        Args:
            word_texts: the line's words (already split on whitespace).
            quad: (4, 2) line quad, ordered TL, TR, BR, BL.
            confidence: line-level confidence used as a fallback.
            char_confs: per-character CTC confidences for the whole line.

        Returns:
            One OcrWord per input word, in left-to-right order.
        """
        if not word_texts:
            return []
        # Include spaces in character counting for positioning
        total_chars = sum(len(w) for w in word_texts) + len(word_texts) - 1
        if total_chars <= 0:
            total_chars = 1
        # Build per-word confidence from char_confs (only when there are at
        # least as many confidences as non-space characters)
        word_confidences: list[float] = []
        if char_confs and len(char_confs) >= sum(len(w) for w in word_texts):
            idx = 0
            for word_text in word_texts:
                wc = char_confs[idx:idx + len(word_text)]
                word_confidences.append(float(np.mean(wc)) if wc else confidence)
                idx += len(word_text)
                # Skip space character confidence (if present in the list)
                if idx < len(char_confs):
                    idx += 1  # skip space
        else:
            word_confidences = [confidence] * len(word_texts)
        # Interpolate along top edge (q0→q1) and bottom edge (q3→q2)
        top_start, top_end = quad[0], quad[1]
        bot_start, bot_end = quad[3], quad[2]
        words: list[OcrWord] = []
        char_pos = 0
        for i, word_text in enumerate(word_texts):
            # Fractional span of this word along the line
            t_start = char_pos / total_chars
            t_end = (char_pos + len(word_text)) / total_chars
            # Interpolate the word's four corners from the line edges
            tl = top_start + (top_end - top_start) * t_start
            tr = top_start + (top_end - top_start) * t_end
            bl = bot_start + (bot_end - bot_start) * t_start
            br = bot_start + (bot_end - bot_start) * t_end
            bbox = BoundingRect(
                x1=float(tl[0]), y1=float(tl[1]),
                x2=float(tr[0]), y2=float(tr[1]),
                x3=float(br[0]), y3=float(br[1]),
                x4=float(bl[0]), y4=float(bl[1]),
            )
            words.append(OcrWord(
                text=word_text,
                bounding_rect=bbox,
                confidence=word_confidences[i],
            ))
            char_pos += len(word_text) + 1  # +1 for space
        return words
# ─── Preprocessing ────────────────────────────────────────────────────
@staticmethod
def _preprocess_recognizer(img_rgb: np.ndarray) -> np.ndarray:
"""Preprocess image for recognizer/scriptID input.
Process: Resize height to 60px → RGB / 255.0 → NCHW float32.
"""
h, w = img_rgb.shape[:2]
target_h = _REC_TARGET_H
scale = target_h / h
new_w = max(int(w * scale), _REC_MIN_WIDTH)
new_w = (new_w + 3) // 4 * 4 # align to 4
resized = cv2.resize(img_rgb, (new_w, target_h), interpolation=cv2.INTER_LINEAR)
data = resized.astype(np.float32) / 255.0
data = data.transpose(2, 0, 1)[np.newaxis] # HWC → NCHW
return data
    # ─── Crop ─────────────────────────────────────────────────────────────
    @staticmethod
    def _crop_quad(
        img_rgb: np.ndarray,
        quad: np.ndarray,
        padding_ratio: float = 0.15,
    ) -> np.ndarray | None:
        """Crop a text region from the image using the bounding quad.

        Uses axis-aligned rectangle crop with padding (matching DLL behavior).
        Falls back to perspective transform only for heavily rotated text (>15°).

        Args:
            img_rgb: Source image (H, W, 3).
            quad: 4 corners as (4, 2) array, ordered TL, TR, BR, BL.
            padding_ratio: How much to expand (fraction of height).

        Returns:
            Cropped RGB image, or None for degenerate/too-small regions and
            on any unexpected failure (deliberately best-effort).
        """
        try:
            img_h, img_w = img_rgb.shape[:2]
            # Check rotation angle of top edge
            dx = quad[1][0] - quad[0][0]
            dy = quad[1][1] - quad[0][1]
            angle = abs(math.atan2(dy, dx)) * 180 / math.pi
            # For near-horizontal text (<15°), use simple axis-aligned rectangle
            if angle < 15:
                # Axis-aligned bounding box from quad
                x_min = float(quad[:, 0].min())
                x_max = float(quad[:, 0].max())
                y_min = float(quad[:, 1].min())
                y_max = float(quad[:, 1].max())
                box_h = y_max - y_min
                box_w = x_max - x_min
                if box_h < 3 or box_w < 5:
                    return None
                # Apply padding (use height-proportional padding on ALL sides)
                # Minimum 3px padding for very small text regions
                pad_h = max(box_h * padding_ratio, 3)
                pad_w = max(box_h * 0.25, 3)  # wider horizontal padding for stride alignment
                y1 = max(0, int(y_min - pad_h))
                y2 = min(img_h, int(y_max + pad_h))
                x1 = max(0, int(x_min - pad_w))
                x2 = min(img_w, int(x_max + pad_w))
                if y2 - y1 < 3 or x2 - x1 < 5:
                    return None
                return img_rgb[y1:y2, x1:x2].copy()
            # For rotated text, use perspective transform
            w1 = np.linalg.norm(quad[1] - quad[0])
            w2 = np.linalg.norm(quad[2] - quad[3])
            h1 = np.linalg.norm(quad[3] - quad[0])
            h2 = np.linalg.norm(quad[2] - quad[1])
            target_w = int(max(w1, w2))
            target_h = int(max(h1, h2))
            if target_w < 5 or target_h < 3:
                return None
            # Expand the quad outward by padding along its own edge directions
            pad = max(h1, h2) * padding_ratio
            # Unit vector pointing "up" along the left edge (BL → TL)
            top_dir = quad[0] - quad[3]
            if np.linalg.norm(top_dir) > 0:
                top_dir = top_dir / np.linalg.norm(top_dir)
            else:
                top_dir = np.array([0, -1], dtype=np.float32)
            # Unit vector pointing "left" along the top edge (TR → TL)
            left_dir = quad[0] - quad[1]
            if np.linalg.norm(left_dir) > 0:
                left_dir = left_dir / np.linalg.norm(left_dir)
            else:
                left_dir = np.array([-1, 0], dtype=np.float32)
            expanded = quad.copy().astype(np.float32)
            # Full pad vertically, 0.3*pad horizontally
            expanded[0] = quad[0] + top_dir * pad + left_dir * pad * 0.3
            expanded[1] = quad[1] + top_dir * pad - left_dir * pad * 0.3
            expanded[2] = quad[2] - top_dir * pad - left_dir * pad * 0.3
            expanded[3] = quad[3] - top_dir * pad + left_dir * pad * 0.3
            expanded[:, 0] = np.clip(expanded[:, 0], 0, img_w - 1)
            expanded[:, 1] = np.clip(expanded[:, 1], 0, img_h - 1)
            # Recompute target size from the clipped quad
            w1 = np.linalg.norm(expanded[1] - expanded[0])
            w2 = np.linalg.norm(expanded[2] - expanded[3])
            h1 = np.linalg.norm(expanded[3] - expanded[0])
            h2 = np.linalg.norm(expanded[2] - expanded[1])
            target_w = int(max(w1, w2))
            target_h = int(max(h1, h2))
            if target_w < 5 or target_h < 3:
                return None
            dst = np.array([
                [0, 0],
                [target_w - 1, 0],
                [target_w - 1, target_h - 1],
                [0, target_h - 1],
            ], dtype=np.float32)
            M = cv2.getPerspectiveTransform(expanded.astype(np.float32), dst)
            crop = cv2.warpPerspective(
                img_rgb, M, (target_w, target_h),
                flags=cv2.INTER_LINEAR,
                borderMode=cv2.BORDER_REPLICATE,
            )
            return crop
        except Exception:
            # Best-effort: any failure is reported as "no crop"
            return None
# ─── Model Loading ────────────────────────────────────────────────────
def _find_model(self, model_idx: int) -> Path:
"""Find ONNX model file by index. Checks unlocked dir first for models 11-33."""
if model_idx >= 11 and self._unlocked_dir.exists():
matches = list(self._unlocked_dir.glob(f"model_{model_idx:02d}_*"))
if matches:
return matches[0]
matches = list(self._models_dir.glob(f"model_{model_idx:02d}_*"))
if not matches:
raise FileNotFoundError(f"Model file not found for index {model_idx}")
return matches[0]
def _get_detector(self) -> ort.InferenceSession:
"""Get or create detector session."""
if self._detector is None:
path = self._find_model(0)
self._detector = ort.InferenceSession(
str(path), sess_options=self._sess_opts, providers=self._providers
)
return self._detector
def _get_script_id(self) -> ort.InferenceSession:
"""Get or create ScriptID session."""
if self._script_id is None:
path = self._find_model(1)
self._script_id = ort.InferenceSession(
str(path), sess_options=self._sess_opts, providers=self._providers
)
return self._script_id
def _get_recognizer(self, model_idx: int) -> ort.InferenceSession:
"""Get or create recognizer session."""
if model_idx not in self._recognizers:
path = self._find_model(model_idx)
self._recognizers[model_idx] = ort.InferenceSession(
str(path), sess_options=self._sess_opts, providers=self._providers
)
return self._recognizers[model_idx]
def _get_char_map(self, model_idx: int) -> tuple[dict[int, str], int]:
"""Get or load character map for model."""
if model_idx not in self._char_maps:
info = MODEL_REGISTRY.get(model_idx)
if not info or not info[2]:
raise ValueError(f"No char2ind file for model {model_idx}")
char_path = self._config_dir / info[2]
self._char_maps[model_idx] = load_char_map(char_path)
return self._char_maps[model_idx]
def _get_line_layout(self) -> ort.InferenceSession | None:
"""Get or create LineLayout session (model 33). Returns None if unavailable."""
if self._line_layout is None:
try:
path = self._find_model(33)
self._line_layout = ort.InferenceSession(
str(path), sess_options=self._sess_opts, providers=self._providers
)
except FileNotFoundError:
return None
return self._line_layout
def _run_line_layout(self, crop_rgb: np.ndarray) -> float | None:
"""Run LineLayout model to get line boundary score.
Args:
crop_rgb: Cropped text line image.
Returns:
Line layout score (higher = more confident this is a complete line),
or None if LineLayout model unavailable.
"""
sess = self._get_line_layout()
if sess is None:
return None
try:
data = self._preprocess_recognizer(crop_rgb)
outputs = sess.run(None, {"data": data})
# output[1] = line_layout_score [1, 1, 1]
score = float(outputs[1].flatten()[0])
return score
except Exception:
return None