""" Cross-platform OneOCR engine — pure Python/ONNX implementation. Reimplements the full OneOCR DLL pipeline using extracted ONNX models: 1. Detector (model_00): PixelLink FPN text detection → bounding boxes 2. ScriptID (model_01): Script/handwriting/flip classification 3. Recognizers (02-10): Per-script CTC character recognition 4. Line grouping: Heuristic grouping of words into lines Preprocessing: - Detector: BGR, mean-subtracted [102.98, 115.95, 122.77], NCHW - Recognizers: RGB / 255.0, height=60px, NCHW - ScriptID: Same as recognizers Output format matches the DLL: OcrResult with BoundingRect, OcrLine, OcrWord. Usage: from ocr.engine_onnx import OcrEngineOnnx engine = OcrEngineOnnx() result = engine.recognize_pil(pil_image) print(result.text) """ from __future__ import annotations import math import os from pathlib import Path from typing import TYPE_CHECKING import cv2 import numpy as np import onnxruntime as ort from ocr.models import BoundingRect, OcrLine, OcrResult, OcrWord if TYPE_CHECKING: from PIL import Image # ─── Constants ─────────────────────────────────────────────────────────────── # BGR mean values for detector (ImageNet-style) _DET_MEAN = np.array([102.9801, 115.9465, 122.7717], dtype=np.float32) # Detector thresholds (from DLL protobuf config — segment_conf_threshold) _PIXEL_SCORE_THRESH = 0.7 # text/non-text pixel threshold (DLL: field 8 = 0.7) _LINK_SCORE_THRESH = 0.5 # neighbor link threshold _MIN_AREA = 50 # minimum text region area (pixels) _MIN_HEIGHT = 5 # minimum text region height _MIN_WIDTH = 4 # minimum text region width # Detector scaling (Faster R-CNN / FPN style — short side target, long side cap) _DET_TARGET_SHORT = 800 # scale shortest side to this _DET_MAX_LONG = 1333 # cap longest side (Faster R-CNN standard) # NMS thresholds (from DLL protobuf config) _NMS_IOU_THRESH = 0.2 # textline_nms_threshold (DLL: field 10 = 0.2) # Recognizer settings _REC_TARGET_H = 60 # target height for recognizer input _REC_MIN_WIDTH = 32 # minimum width after resize # Line grouping _LINE_IOU_Y_THRESH = 0.5 # Y-overlap threshold for same-line grouping _LINE_MERGE_GAP = 2.0 # max gap between words on same line (as ratio of avg char height) # Script names for ScriptID output (10 classes) SCRIPT_NAMES = [ "Latin", "CJK", "Arabic", "Cyrillic", "Devanagari", "Greek", "Hebrew", "Thai", "Tamil", "Unknown" ] # Model index → (role, script, char2ind_file, rnn_info_file) MODEL_REGISTRY: dict[int, tuple[str, str, str | None]] = { 0: ("detector", "universal", None), 1: ("script_id", "universal", None), 2: ("recognizer", "Latin", "chunk_37_char2ind.char2ind.txt"), 3: ("recognizer", "CJK", "chunk_40_char2ind.char2ind.txt"), 4: ("recognizer", "Arabic", "chunk_43_char2ind.char2ind.txt"), 5: ("recognizer", "Cyrillic", "chunk_47_char2ind.char2ind.txt"), 6: ("recognizer", "Devanagari", "chunk_50_char2ind.char2ind.txt"), 7: ("recognizer", "Greek", "chunk_53_char2ind.char2ind.txt"), 8: ("recognizer", "Hebrew", "chunk_57_char2ind.char2ind.txt"), 9: ("recognizer", "Tamil", "chunk_61_char2ind.char2ind.txt"), 10: ("recognizer", "Thai", "chunk_64_char2ind.char2ind.txt"), } # Script name → recognizer model index SCRIPT_TO_MODEL: dict[str, int] = { "Latin": 2, "CJK": 3, "Arabic": 4, "Cyrillic": 5, "Devanagari": 6, "Greek": 7, "Hebrew": 8, "Tamil": 9, "Thai": 10, } # ─── Helper: Character map ────────────────────────────────────────────────── def load_char_map(path: str | Path) -> tuple[dict[int, str], int]: """Load char2ind.txt → (idx→char map, blank_index). Format: ' ' per line. Special tokens: =space char, =CTC blank. """ idx2char: dict[int, str] = {} blank_idx = 0 with open(path, "r", encoding="utf-8") as f: for raw_line in f: line = raw_line.rstrip("\n") if not line: continue sp = line.rfind(" ") if sp <= 0: continue char_str, idx_str = line[:sp], line[sp + 1:] idx = int(idx_str) if char_str == "": blank_idx = idx elif char_str == "": idx2char[idx] = " " else: idx2char[idx] = char_str return idx2char, blank_idx # ─── PixelLink Post-Processing ────────────────────────────────────────────── # 8-connected neighbors: (dy, dx) for N, NE, E, SE, S, SW, W, NW _NEIGHBOR_OFFSETS = [(-1, 0), (-1, 1), (0, 1), (1, 1), (1, 0), (1, -1), (0, -1), (-1, -1)] class _UnionFind: """Union-Find (Disjoint Set) for connected component labeling.""" __slots__ = ("parent", "rank") def __init__(self, n: int): self.parent = list(range(n)) self.rank = [0] * n def find(self, x: int) -> int: while self.parent[x] != x: self.parent[x] = self.parent[self.parent[x]] x = self.parent[x] return x def union(self, a: int, b: int) -> None: ra, rb = self.find(a), self.find(b) if ra == rb: return if self.rank[ra] < self.rank[rb]: ra, rb = rb, ra self.parent[rb] = ra if self.rank[ra] == self.rank[rb]: self.rank[ra] += 1 def _pixellink_decode( pixel_scores: np.ndarray, link_scores: np.ndarray, bbox_deltas: np.ndarray, stride: int, pixel_thresh: float = _PIXEL_SCORE_THRESH, link_thresh: float = _LINK_SCORE_THRESH, min_area: float = _MIN_AREA, min_component_pixels: int = 3, ) -> list[np.ndarray]: """Decode PixelLink outputs into oriented bounding boxes. Uses connected-component labeling with Union-Find for linked text pixels, then refines box positions using bbox_deltas regression (matching DLL behavior). Each text pixel predicts 4 corner offsets (as fraction of stride) via bbox_deltas[8, H, W] = [TL.x, TL.y, TR.x, TR.y, BR.x, BR.y, BL.x, BL.y]. The actual corner position for pixel (r, c) is: corner = (pixel_coord + delta) * stride For a connected component, the bounding box corners are computed by taking the extremes of all per-pixel predictions (min TL, max BR). Args: pixel_scores: [H, W] text/non-text scores (already sigmoid'd) link_scores: [8, H, W] neighbor link scores (already sigmoid'd) bbox_deltas: [8, H, W] corner offsets — 4 corners × 2 coords (x, y) stride: FPN stride (4, 8, or 16) min_area: minimum box area in detector-image pixels min_component_pixels: minimum number of pixels in a connected component Returns: List of (4, 2) arrays — quadrilateral corners in detector-image coordinates. """ h, w = pixel_scores.shape # Step 1: Threshold pixels text_mask = pixel_scores > pixel_thresh text_pixels = np.argwhere(text_mask) # (N, 2) — (row, col) pairs if len(text_pixels) == 0: return [] # Step 2: Build pixel index map for quick lookup pixel_map = np.full((h, w), -1, dtype=np.int32) for i, (r, c) in enumerate(text_pixels): pixel_map[r, c] = i # Step 3: Union-Find to group linked pixels uf = _UnionFind(len(text_pixels)) for i, (r, c) in enumerate(text_pixels): for ni, (dy, dx) in enumerate(_NEIGHBOR_OFFSETS): nr, nc = r + dy, c + dx if 0 <= nr < h and 0 <= nc < w: j = pixel_map[nr, nc] if j >= 0 and link_scores[ni, r, c] > link_thresh: uf.union(i, j) # Step 4: Group pixels by component components: dict[int, list[int]] = {} for i in range(len(text_pixels)): root = uf.find(i) if root not in components: components[root] = [] components[root].append(i) # Step 5: For each component, compute bbox using delta regression quads: list[np.ndarray] = [] for indices in components.values(): if len(indices) < min_component_pixels: continue # Compute per-pixel corner predictions using bbox_deltas # Each pixel at (r, c) predicts 4 corners: # TL = ((c + d0) * stride, (r + d1) * stride) # TR = ((c + d2) * stride, (r + d3) * stride) # BR = ((c + d4) * stride, (r + d5) * stride) # BL = ((c + d6) * stride, (r + d7) * stride) tl_x_min = float("inf") tl_y_min = float("inf") br_x_max = float("-inf") br_y_max = float("-inf") for idx in indices: r, c = int(text_pixels[idx][0]), int(text_pixels[idx][1]) # TL corner tl_x = (c + float(bbox_deltas[0, r, c])) * stride tl_y = (r + float(bbox_deltas[1, r, c])) * stride # TR corner tr_x = (c + float(bbox_deltas[2, r, c])) * stride # BR corner br_x = (c + float(bbox_deltas[4, r, c])) * stride br_y = (r + float(bbox_deltas[5, r, c])) * stride # BL corner bl_y = (r + float(bbox_deltas[7, r, c])) * stride tl_x_min = min(tl_x_min, tl_x) tl_y_min = min(tl_y_min, tl_y, (r + float(bbox_deltas[3, r, c])) * stride) br_x_max = max(br_x_max, br_x, tr_x) br_y_max = max(br_y_max, br_y, bl_y) # Clamp to positive coordinates x1 = max(0.0, tl_x_min) y1 = max(0.0, tl_y_min) x2 = br_x_max y2 = br_y_max box_w = x2 - x1 box_h = y2 - y1 area = box_w * box_h if area < min_area: continue if box_h < _MIN_HEIGHT: continue if box_w < _MIN_WIDTH: continue # Create axis-aligned quad (TL, TR, BR, BL) quad = np.array([ [x1, y1], [x2, y1], [x2, y2], [x1, y2], ], dtype=np.float32) quads.append(quad) return quads def _order_corners(pts: np.ndarray) -> np.ndarray: """Order 4 corners as: top-left, top-right, bottom-right, bottom-left.""" # Sort by y first, then by x s = pts.sum(axis=1) d = np.diff(pts, axis=1).ravel() ordered = np.zeros((4, 2), dtype=np.float32) ordered[0] = pts[np.argmin(s)] # top-left: smallest sum ordered[2] = pts[np.argmax(s)] # bottom-right: largest sum ordered[1] = pts[np.argmin(d)] # top-right: smallest diff ordered[3] = pts[np.argmax(d)] # bottom-left: largest diff return ordered def _nms_quads(quads: list[np.ndarray], iou_thresh: float = 0.3) -> list[np.ndarray]: """Non-maximum suppression on quadrilateral boxes using contour IoU.""" if len(quads) <= 1: return quads # Sort by area (largest first) areas = [cv2.contourArea(q) for q in quads] order = np.argsort(areas)[::-1] keep: list[np.ndarray] = [] used = set() for i in order: if i in used: continue keep.append(quads[i]) used.add(i) for j in order: if j in used: continue # Compute IoU between quads[i] and quads[j] iou = _quad_iou(quads[i], quads[j]) if iou > iou_thresh: # Merge: keep the larger one used.add(j) return keep def _quad_iou(q1: np.ndarray, q2: np.ndarray) -> float: """Compute IoU between two quadrilaterals.""" try: ret, region = cv2.intersectConvexConvex( q1.astype(np.float32), q2.astype(np.float32) ) if ret <= 0: return 0.0 inter = cv2.contourArea(region) a1 = cv2.contourArea(q1) a2 = cv2.contourArea(q2) union = a1 + a2 - inter return inter / union if union > 0 else 0.0 except Exception: return 0.0 # ─── CTC Decoding ─────────────────────────────────────────────────────────── def ctc_greedy_decode( logprobs: np.ndarray, idx2char: dict[int, str], blank_idx: int, ) -> tuple[str, float, list[float]]: """CTC greedy decode: argmax per timestep, merge repeats, remove blanks. Returns (decoded_text, average_confidence, per_char_confidences). """ if logprobs.ndim == 3: logprobs = logprobs[:, 0, :] indices = np.argmax(logprobs, axis=-1) probs = np.exp(logprobs) max_probs = probs[np.arange(len(indices)), indices] chars: list[str] = [] char_probs: list[float] = [] prev = -1 for t, idx in enumerate(indices): if idx != prev and idx != blank_idx: chars.append(idx2char.get(int(idx), f"[{idx}]")) char_probs.append(float(max_probs[t])) prev = idx text = "".join(chars) confidence = float(np.mean(char_probs)) if char_probs else 0.0 return text, confidence, char_probs # ─── Line Grouping ────────────────────────────────────────────────────────── def _group_words_into_lines( words: list[tuple[str, np.ndarray, float]], ) -> list[list[tuple[str, np.ndarray, float]]]: """Group detected words into lines based on Y-overlap. Args: words: List of (text, quad[4,2], confidence). Returns: List of line groups, each a list of (text, quad, conf) tuples. """ if not words: return [] # Sort by center Y words_with_cy = [] for w in words: _, quad, _ = w cy = quad[:, 1].mean() words_with_cy.append((cy, w)) words_with_cy.sort(key=lambda x: x[0]) lines: list[list[tuple[str, np.ndarray, float]]] = [] used = set() for i, (cy_i, w_i) in enumerate(words_with_cy): if i in used: continue _, qi, _ = w_i y_min_i = qi[:, 1].min() y_max_i = qi[:, 1].max() h_i = y_max_i - y_min_i line = [w_i] used.add(i) for j in range(i + 1, len(words_with_cy)): if j in used: continue _, w_j = words_with_cy[j] _, qj, _ = w_j y_min_j = qj[:, 1].min() y_max_j = qj[:, 1].max() # Check Y overlap overlap = min(y_max_i, y_max_j) - max(y_min_i, y_min_j) min_height = min(y_max_i - y_min_i, y_max_j - y_min_j) if min_height > 0 and overlap / min_height > _LINE_IOU_Y_THRESH: line.append(w_j) used.add(j) # Expand y range y_min_i = min(y_min_i, y_min_j) y_max_i = max(y_max_i, y_max_j) # Sort line words by X position (left to right) line.sort(key=lambda w: w[1][:, 0].min()) lines.append(line) return lines # ─── Image Angle Detection ────────────────────────────────────────────────── def _estimate_image_angle(quads: list[np.ndarray]) -> float: """Estimate overall text angle from detected quads. Uses the average angle of the top edges of all boxes. """ if not quads: return 0.0 angles = [] for q in quads: # Top edge: q[0] → q[1] dx = q[1][0] - q[0][0] dy = q[1][1] - q[0][1] if abs(dx) < 1: continue angle = math.degrees(math.atan2(dy, dx)) angles.append(angle) if not angles: return 0.0 return float(np.median(angles)) # ═══════════════════════════════════════════════════════════════════════════ # Main Engine # ═══════════════════════════════════════════════════════════════════════════ class OcrEngineOnnx: """Cross-platform OCR engine using extracted ONNX models. Provides the same API as OcrEngine (DLL version) but runs on any OS. Args: models_dir: Path to extracted ONNX models directory. config_dir: Path to extracted config data directory. providers: ONNX Runtime providers (default: CPU only). """ def __init__( self, models_dir: str | Path | None = None, config_dir: str | Path | None = None, providers: list[str] | None = None, ) -> None: base = Path(__file__).resolve().parent.parent self._models_dir = Path(models_dir) if models_dir else base / "oneocr_extracted" / "onnx_models" self._config_dir = Path(config_dir) if config_dir else base / "oneocr_extracted" / "config_data" self._unlocked_dir = self._models_dir.parent / "onnx_models_unlocked" self._providers = providers or ["CPUExecutionProvider"] # Optimized session options (matching DLL's ORT configuration) self._sess_opts = ort.SessionOptions() self._sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL self._sess_opts.intra_op_num_threads = max(1, (os.cpu_count() or 4) // 2) self._sess_opts.inter_op_num_threads = 2 self._sess_opts.enable_mem_pattern = True self._sess_opts.enable_cpu_mem_arena = True # Lazy-loaded sessions self._detector: ort.InferenceSession | None = None self._script_id: ort.InferenceSession | None = None self._recognizers: dict[int, ort.InferenceSession] = {} self._char_maps: dict[int, tuple[dict[int, str], int]] = {} self._line_layout: ort.InferenceSession | None = None # Validate paths exist if not self._models_dir.exists(): raise FileNotFoundError(f"Models directory not found: {self._models_dir}") if not self._config_dir.exists(): raise FileNotFoundError(f"Config directory not found: {self._config_dir}") # ─── Public API ─────────────────────────────────────────────────────── def recognize_pil(self, image: "Image.Image") -> OcrResult: """Run OCR on a PIL Image. Args: image: PIL Image (any mode — will be converted to RGB). Returns: OcrResult with recognized text, lines, words, bounding boxes, confidence values, and detected text angle. """ if any(x < 10 or x > 10000 for x in image.size): return OcrResult(error="Unsupported image size (must be 10-10000px)") img_rgb = image.convert("RGB") img_arr = np.array(img_rgb) try: return self._run_pipeline(img_arr) except Exception as e: return OcrResult(error=f"Pipeline error: {e}") def recognize_bytes(self, image_bytes: bytes) -> OcrResult: """Run OCR on raw image bytes (PNG/JPEG/etc).""" from io import BytesIO from PIL import Image as PILImage img = PILImage.open(BytesIO(image_bytes)) return self.recognize_pil(img) def recognize_numpy(self, img_rgb: np.ndarray) -> OcrResult: """Run OCR on a numpy array (H, W, 3) in RGB format.""" if img_rgb.ndim != 3 or img_rgb.shape[2] != 3: return OcrResult(error="Expected (H, W, 3) RGB array") try: return self._run_pipeline(img_rgb) except Exception as e: return OcrResult(error=f"Pipeline error: {e}") # ─── Pipeline ───────────────────────────────────────────────────────── def _run_pipeline(self, img_rgb: np.ndarray) -> OcrResult: """Full OCR pipeline: detect → crop → scriptID → recognize → group.""" h, w = img_rgb.shape[:2] # Step 1: Detect text regions quads, scale = self._detect(img_rgb) if not quads: return OcrResult(text="", text_angle=0.0, lines=[]) # Step 2: Estimate image angle text_angle = _estimate_image_angle(quads) # Step 3: For each detected region (line), crop and recognize line_results: list[tuple[str, np.ndarray, float, list[float]]] = [] for quad in quads: # Crop text region from original image crop = self._crop_quad(img_rgb, quad) if crop is None or crop.shape[0] < 5 or crop.shape[1] < 5: continue # Detect vertical text: if crop is much taller than wide, rotate 90° CCW ch, cw = crop.shape[:2] is_vertical = ch > cw * 2 rec_crop = crop if is_vertical: rec_crop = cv2.rotate(crop, cv2.ROTATE_90_COUNTERCLOCKWISE) # ScriptID (on properly oriented crop) script_idx = self._identify_script(rec_crop) script_name = SCRIPT_NAMES[script_idx] if script_idx < len(SCRIPT_NAMES) else "Latin" # Map to recognizer model_idx = SCRIPT_TO_MODEL.get(script_name, 2) # default Latin # Recognize full line text, conf, char_confs = self._recognize(rec_crop, model_idx) # For vertical text fallback: if confidence is low, try CJK recognizer if is_vertical and conf < 0.7 and model_idx != 3: text_cjk, conf_cjk, char_confs_cjk = self._recognize(rec_crop, 3) if conf_cjk > conf: text, conf, char_confs = text_cjk, conf_cjk, char_confs_cjk if text.strip(): # Graded confidence filter. # Short noise detections from FPN2 (single chars like "1", "C") # typically have lower confidence than genuine text like "I", "A". text_stripped = text.strip() n_chars = len(text_stripped) if n_chars <= 1 and conf < 0.35: continue # single char needs some confidence elif conf < 0.3: continue # very low confidence = noise line_results.append((text, quad, conf, char_confs)) if not line_results: return OcrResult(text="", text_angle=text_angle, lines=[]) # Step 4: Build OcrResult — split recognized text into words lines: list[OcrLine] = [] for line_text, quad, conf, char_confs in line_results: # Split text by spaces to get words word_texts = line_text.split() if not word_texts: continue # Estimate per-word bounding boxes and confidence by distributing # quad width and char confidences proportionally to words words = self._split_into_words(word_texts, quad, conf, char_confs) line_bbox = BoundingRect( x1=float(quad[0][0]), y1=float(quad[0][1]), x2=float(quad[1][0]), y2=float(quad[1][1]), x3=float(quad[2][0]), y3=float(quad[2][1]), x4=float(quad[3][0]), y4=float(quad[3][1]), ) full_line_text = " ".join(w.text for w in words) lines.append(OcrLine(text=full_line_text, bounding_rect=line_bbox, words=words)) # Sort lines top-to-bottom by center Y coordinate lines.sort(key=lambda l: ( (l.bounding_rect.y1 + l.bounding_rect.y3) / 2 if l.bounding_rect else 0 )) full_text = "\n".join(line.text for line in lines) return OcrResult(text=full_text, text_angle=text_angle, lines=lines) # ─── Detection ──────────────────────────────────────────────────────── def _detect(self, img_rgb: np.ndarray) -> tuple[list[np.ndarray], float]: """Run PixelLink detector and decode bounding boxes. Scaling: FPN-style — scale shortest side to 800, cap longest at 1333. Two-phase detection: - Phase 1: FPN3 (stride=8) + FPN4 (stride=16) — primary detections - Phase 2: FPN2 (stride=4) — supplementary for small text ("I", "...") Only keeps novel FPN2 detections that don't overlap with primary. NMS: IoU threshold 0.2 (from DLL protobuf config). Returns (list_of_quads, scale_factor). """ h, w = img_rgb.shape[:2] # FPN-style scaling: target short side = 800, cap long side at 1333 short_side = min(h, w) long_side = max(h, w) scale = _DET_TARGET_SHORT / short_side if long_side * scale > _DET_MAX_LONG: scale = _DET_MAX_LONG / long_side scale = min(scale, 6.0) dh = (int(h * scale) + 31) // 32 * 32 dw = (int(w * scale) + 31) // 32 * 32 img_resized = cv2.resize(img_rgb, (dw, dh), interpolation=cv2.INTER_LINEAR) # BGR + mean subtraction (ImageNet-style, matching DLL) img_bgr = img_resized[:, :, ::-1].astype(np.float32) - _DET_MEAN data = img_bgr.transpose(2, 0, 1)[np.newaxis] im_info = np.array([[dh, dw, scale]], dtype=np.float32) # Run detector sess = self._get_detector() outputs = sess.run(None, {"data": data, "im_info": im_info}) output_names = [o.name for o in sess.get_outputs()] out_dict = dict(zip(output_names, outputs)) # ── Phase 1: Primary detections from FPN3 + FPN4 ── primary_quads: list[np.ndarray] = [] for level, stride in [("fpn3", 8), ("fpn4", 16)]: min_area_scaled = _MIN_AREA * (scale ** 2) for orientation in ("hori", "vert"): scores_key = f"scores_{orientation}_{level}" links_key = f"link_scores_{orientation}_{level}" deltas_key = f"bbox_deltas_{orientation}_{level}" if scores_key not in out_dict: continue pixel_scores = out_dict[scores_key][0, 0] if orientation == "vert" and pixel_scores.max() <= _PIXEL_SCORE_THRESH: continue link_scores = out_dict[links_key][0] bbox_deltas = out_dict[deltas_key][0] quads = _pixellink_decode( pixel_scores, link_scores, bbox_deltas, stride, pixel_thresh=_PIXEL_SCORE_THRESH, min_area=min_area_scaled, ) primary_quads.extend(quads) primary_quads = _nms_quads(primary_quads, iou_thresh=_NMS_IOU_THRESH) # ── Phase 2: FPN2 supplementary detections ── # Higher threshold (0.85) to reduce false positives from panel borders. # Only keep novel detections that don't overlap with primary. fpn2_quads: list[np.ndarray] = [] min_area_scaled = _MIN_AREA * (scale ** 2) for orientation in ("hori", "vert"): scores_key = f"scores_{orientation}_fpn2" links_key = f"link_scores_{orientation}_fpn2" deltas_key = f"bbox_deltas_{orientation}_fpn2" if scores_key not in out_dict: continue pixel_scores = out_dict[scores_key][0, 0] if pixel_scores.max() <= 0.85: continue link_scores = out_dict[links_key][0] bbox_deltas = out_dict[deltas_key][0] quads = _pixellink_decode( pixel_scores, link_scores, bbox_deltas, 4, pixel_thresh=0.85, min_area=min_area_scaled, min_component_pixels=5, ) for q in quads: # Only keep if not overlapping with primary detections overlaps = any(_quad_iou(q, p) > 0.1 for p in primary_quads) if not overlaps: fpn2_quads.append(q) # Combine and final NMS all_quads = primary_quads + fpn2_quads if fpn2_quads: all_quads = _nms_quads(all_quads, iou_thresh=_NMS_IOU_THRESH) # Scale quads back to original image coordinates for i in range(len(all_quads)): all_quads[i] = all_quads[i] / scale return all_quads, scale # ─── Script Identification ──────────────────────────────────────────── def _identify_script(self, crop_rgb: np.ndarray) -> int: """Identify script of a cropped text region. Returns script index (0=Latin, 1=CJK, ..., 9=Unknown). """ sess = self._get_script_id() # Preprocess: RGB -> height=60 -> /255 -> NCHW data = self._preprocess_recognizer(crop_rgb) outputs = sess.run(None, {"data": data}) # Output: script_id_score [1, 1, 10] script_scores = outputs[3] # script_id_score script_idx = int(np.argmax(script_scores.flatten()[:10])) return script_idx # ─── Recognition ────────────────────────────────────────────────────── def _recognize( self, crop_rgb: np.ndarray, model_idx: int ) -> tuple[str, float, list[float]]: """Recognize text in a cropped region using the specified model. Returns (text, confidence, per_char_confidences). """ sess = self._get_recognizer(model_idx) idx2char, blank_idx = self._get_char_map(model_idx) data = self._preprocess_recognizer(crop_rgb) h, w = data.shape[2], data.shape[3] seq_lengths = np.array([w // 4], dtype=np.int32) logprobs = sess.run(None, {"data": data, "seq_lengths": seq_lengths})[0] text, conf, char_confs = ctc_greedy_decode(logprobs, idx2char, blank_idx) return text, conf, char_confs # ─── Word splitting ───────────────────────────────────────────────── @staticmethod def _split_into_words( word_texts: list[str], quad: np.ndarray, confidence: float, char_confs: list[float] | None = None, ) -> list[OcrWord]: """Split a line into word-level OcrWord objects with estimated bboxes. Distributes the line quad proportionally by character count. Per-word confidence is computed from character-level CTC confidences. """ if not word_texts: return [] # Include spaces in character counting for positioning total_chars = sum(len(w) for w in word_texts) + len(word_texts) - 1 if total_chars <= 0: total_chars = 1 # Build per-word confidence from char_confs word_confidences: list[float] = [] if char_confs and len(char_confs) >= sum(len(w) for w in word_texts): idx = 0 for word_text in word_texts: wc = char_confs[idx:idx + len(word_text)] word_confidences.append(float(np.mean(wc)) if wc else confidence) idx += len(word_text) # Skip space character confidence (if present in the list) if idx < len(char_confs): idx += 1 # skip space else: word_confidences = [confidence] * len(word_texts) # Interpolate along top edge (q0→q1) and bottom edge (q3→q2) top_start, top_end = quad[0], quad[1] bot_start, bot_end = quad[3], quad[2] words: list[OcrWord] = [] char_pos = 0 for i, word_text in enumerate(word_texts): t_start = char_pos / total_chars t_end = (char_pos + len(word_text)) / total_chars # Interpolate corners tl = top_start + (top_end - top_start) * t_start tr = top_start + (top_end - top_start) * t_end bl = bot_start + (bot_end - bot_start) * t_start br = bot_start + (bot_end - bot_start) * t_end bbox = BoundingRect( x1=float(tl[0]), y1=float(tl[1]), x2=float(tr[0]), y2=float(tr[1]), x3=float(br[0]), y3=float(br[1]), x4=float(bl[0]), y4=float(bl[1]), ) words.append(OcrWord( text=word_text, bounding_rect=bbox, confidence=word_confidences[i], )) char_pos += len(word_text) + 1 # +1 for space return words # ─── Preprocessing ──────────────────────────────────────────────────── @staticmethod def _preprocess_recognizer(img_rgb: np.ndarray) -> np.ndarray: """Preprocess image for recognizer/scriptID input. Process: Resize height to 60px → RGB / 255.0 → NCHW float32. """ h, w = img_rgb.shape[:2] target_h = _REC_TARGET_H scale = target_h / h new_w = max(int(w * scale), _REC_MIN_WIDTH) new_w = (new_w + 3) // 4 * 4 # align to 4 resized = cv2.resize(img_rgb, (new_w, target_h), interpolation=cv2.INTER_LINEAR) data = resized.astype(np.float32) / 255.0 data = data.transpose(2, 0, 1)[np.newaxis] # HWC → NCHW return data # ─── Crop ───────────────────────────────────────────────────────────── @staticmethod def _crop_quad( img_rgb: np.ndarray, quad: np.ndarray, padding_ratio: float = 0.15, ) -> np.ndarray | None: """Crop a text region from the image using the bounding quad. Uses axis-aligned rectangle crop with padding (matching DLL behavior). Falls back to perspective transform only for heavily rotated text (>15°). Args: img_rgb: Source image (H, W, 3). quad: 4 corners as (4, 2) array. padding_ratio: How much to expand (fraction of height). Returns: Cropped RGB image or None. """ try: img_h, img_w = img_rgb.shape[:2] # Check rotation angle of top edge dx = quad[1][0] - quad[0][0] dy = quad[1][1] - quad[0][1] angle = abs(math.atan2(dy, dx)) * 180 / math.pi # For near-horizontal text (<15°), use simple axis-aligned rectangle if angle < 15: # Axis-aligned bounding box from quad x_min = float(quad[:, 0].min()) x_max = float(quad[:, 0].max()) y_min = float(quad[:, 1].min()) y_max = float(quad[:, 1].max()) box_h = y_max - y_min box_w = x_max - x_min if box_h < 3 or box_w < 5: return None # Apply padding (use height-proportional padding on ALL sides) # Minimum 3px padding for very small text regions pad_h = max(box_h * padding_ratio, 3) pad_w = max(box_h * 0.25, 3) # wider horizontal padding for stride alignment y1 = max(0, int(y_min - pad_h)) y2 = min(img_h, int(y_max + pad_h)) x1 = max(0, int(x_min - pad_w)) x2 = min(img_w, int(x_max + pad_w)) if y2 - y1 < 3 or x2 - x1 < 5: return None return img_rgb[y1:y2, x1:x2].copy() # For rotated text, use perspective transform w1 = np.linalg.norm(quad[1] - quad[0]) w2 = np.linalg.norm(quad[2] - quad[3]) h1 = np.linalg.norm(quad[3] - quad[0]) h2 = np.linalg.norm(quad[2] - quad[1]) target_w = int(max(w1, w2)) target_h = int(max(h1, h2)) if target_w < 5 or target_h < 3: return None # Expand the quad by padding pad = max(h1, h2) * padding_ratio top_dir = quad[0] - quad[3] if np.linalg.norm(top_dir) > 0: top_dir = top_dir / np.linalg.norm(top_dir) else: top_dir = np.array([0, -1], dtype=np.float32) left_dir = quad[0] - quad[1] if np.linalg.norm(left_dir) > 0: left_dir = left_dir / np.linalg.norm(left_dir) else: left_dir = np.array([-1, 0], dtype=np.float32) expanded = quad.copy().astype(np.float32) expanded[0] = quad[0] + top_dir * pad + left_dir * pad * 0.3 expanded[1] = quad[1] + top_dir * pad - left_dir * pad * 0.3 expanded[2] = quad[2] - top_dir * pad - left_dir * pad * 0.3 expanded[3] = quad[3] - top_dir * pad + left_dir * pad * 0.3 expanded[:, 0] = np.clip(expanded[:, 0], 0, img_w - 1) expanded[:, 1] = np.clip(expanded[:, 1], 0, img_h - 1) w1 = np.linalg.norm(expanded[1] - expanded[0]) w2 = np.linalg.norm(expanded[2] - expanded[3]) h1 = np.linalg.norm(expanded[3] - expanded[0]) h2 = np.linalg.norm(expanded[2] - expanded[1]) target_w = int(max(w1, w2)) target_h = int(max(h1, h2)) if target_w < 5 or target_h < 3: return None dst = np.array([ [0, 0], [target_w - 1, 0], [target_w - 1, target_h - 1], [0, target_h - 1], ], dtype=np.float32) M = cv2.getPerspectiveTransform(expanded.astype(np.float32), dst) crop = cv2.warpPerspective( img_rgb, M, (target_w, target_h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE, ) return crop except Exception: return None # ─── Model Loading ──────────────────────────────────────────────────── def _find_model(self, model_idx: int) -> Path: """Find ONNX model file by index. Checks unlocked dir first for models 11-33.""" if model_idx >= 11 and self._unlocked_dir.exists(): matches = list(self._unlocked_dir.glob(f"model_{model_idx:02d}_*")) if matches: return matches[0] matches = list(self._models_dir.glob(f"model_{model_idx:02d}_*")) if not matches: raise FileNotFoundError(f"Model file not found for index {model_idx}") return matches[0] def _get_detector(self) -> ort.InferenceSession: """Get or create detector session.""" if self._detector is None: path = self._find_model(0) self._detector = ort.InferenceSession( str(path), sess_options=self._sess_opts, providers=self._providers ) return self._detector def _get_script_id(self) -> ort.InferenceSession: """Get or create ScriptID session.""" if self._script_id is None: path = self._find_model(1) self._script_id = ort.InferenceSession( str(path), sess_options=self._sess_opts, providers=self._providers ) return self._script_id def _get_recognizer(self, model_idx: int) -> ort.InferenceSession: """Get or create recognizer session.""" if model_idx not in self._recognizers: path = self._find_model(model_idx) self._recognizers[model_idx] = ort.InferenceSession( str(path), sess_options=self._sess_opts, providers=self._providers ) return self._recognizers[model_idx] def _get_char_map(self, model_idx: int) -> tuple[dict[int, str], int]: """Get or load character map for model.""" if model_idx not in self._char_maps: info = MODEL_REGISTRY.get(model_idx) if not info or not info[2]: raise ValueError(f"No char2ind file for model {model_idx}") char_path = self._config_dir / info[2] self._char_maps[model_idx] = load_char_map(char_path) return self._char_maps[model_idx] def _get_line_layout(self) -> ort.InferenceSession | None: """Get or create LineLayout session (model 33). Returns None if unavailable.""" if self._line_layout is None: try: path = self._find_model(33) self._line_layout = ort.InferenceSession( str(path), sess_options=self._sess_opts, providers=self._providers ) except FileNotFoundError: return None return self._line_layout def _run_line_layout(self, crop_rgb: np.ndarray) -> float | None: """Run LineLayout model to get line boundary score. Args: crop_rgb: Cropped text line image. Returns: Line layout score (higher = more confident this is a complete line), or None if LineLayout model unavailable. """ sess = self._get_line_layout() if sess is None: return None try: data = self._preprocess_recognizer(crop_rgb) outputs = sess.run(None, {"data": data}) # output[1] = line_layout_score [1, 1, 1] score = float(outputs[1].flatten()[0]) return score except Exception: return None