|
|
""" |
|
|
Cross-platform OneOCR engine — pure Python/ONNX implementation. |
|
|
|
|
|
Reimplements the full OneOCR DLL pipeline using extracted ONNX models: |
|
|
1. Detector (model_00): PixelLink FPN text detection → bounding boxes |
|
|
2. ScriptID (model_01): Script/handwriting/flip classification |
|
|
3. Recognizers (02-10): Per-script CTC character recognition |
|
|
4. Line grouping: Heuristic grouping of words into lines |
|
|
|
|
|
Preprocessing: |
|
|
- Detector: BGR, mean-subtracted [102.98, 115.95, 122.77], NCHW |
|
|
- Recognizers: RGB / 255.0, height=60px, NCHW |
|
|
- ScriptID: Same as recognizers |
|
|
|
|
|
Output format matches the DLL: OcrResult with BoundingRect, OcrLine, OcrWord. |
|
|
|
|
|
Usage: |
|
|
from ocr.engine_onnx import OcrEngineOnnx |
|
|
engine = OcrEngineOnnx() |
|
|
result = engine.recognize_pil(pil_image) |
|
|
print(result.text) |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import math |
|
|
import os |
|
|
from pathlib import Path |
|
|
from typing import TYPE_CHECKING |
|
|
|
|
|
import cv2 |
|
|
import numpy as np |
|
|
import onnxruntime as ort |
|
|
|
|
|
from ocr.models import BoundingRect, OcrLine, OcrResult, OcrWord |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# BGR channel means subtracted during detector preprocessing (Caffe-style
# values, matching the constants listed in the module docstring).
_DET_MEAN = np.array([102.9801, 115.9465, 122.7717], dtype=np.float32)

# PixelLink decode thresholds and minimum box geometry (detector-image px).
_PIXEL_SCORE_THRESH = 0.7
_LINK_SCORE_THRESH = 0.5
_MIN_AREA = 50
_MIN_HEIGHT = 5
_MIN_WIDTH = 4

# FPN-style detector resize: short side → 800px, long side capped at 1333px.
_DET_TARGET_SHORT = 800
_DET_MAX_LONG = 1333

# IoU threshold for quad NMS (per _detect docstring: from DLL protobuf config).
_NMS_IOU_THRESH = 0.2

# Recognizer input geometry: fixed 60px height, minimum width after resize.
_REC_TARGET_H = 60
_REC_MIN_WIDTH = 32

# Line grouping: words whose vertical overlap ratio exceeds this share a line.
_LINE_IOU_Y_THRESH = 0.5
# NOTE(review): _LINE_MERGE_GAP appears unused in this module — confirm
# before removing.
_LINE_MERGE_GAP = 2.0

# Script classes in the order the ScriptID model's scores are indexed
# (see _identify_script).
SCRIPT_NAMES = [
    "Latin", "CJK", "Arabic", "Cyrillic", "Devanagari",
    "Greek", "Hebrew", "Thai", "Tamil", "Unknown"
]

# model index → (role, script, char2ind filename or None for non-recognizers).
MODEL_REGISTRY: dict[int, tuple[str, str, str | None]] = {
    0: ("detector", "universal", None),
    1: ("script_id", "universal", None),
    2: ("recognizer", "Latin", "chunk_37_char2ind.char2ind.txt"),
    3: ("recognizer", "CJK", "chunk_40_char2ind.char2ind.txt"),
    4: ("recognizer", "Arabic", "chunk_43_char2ind.char2ind.txt"),
    5: ("recognizer", "Cyrillic", "chunk_47_char2ind.char2ind.txt"),
    6: ("recognizer", "Devanagari", "chunk_50_char2ind.char2ind.txt"),
    7: ("recognizer", "Greek", "chunk_53_char2ind.char2ind.txt"),
    8: ("recognizer", "Hebrew", "chunk_57_char2ind.char2ind.txt"),
    9: ("recognizer", "Tamil", "chunk_61_char2ind.char2ind.txt"),
    10: ("recognizer", "Thai", "chunk_64_char2ind.char2ind.txt"),
}

# Script name → recognizer model index (keys must match MODEL_REGISTRY).
SCRIPT_TO_MODEL: dict[str, int] = {
    "Latin": 2, "CJK": 3, "Arabic": 4, "Cyrillic": 5,
    "Devanagari": 6, "Greek": 7, "Hebrew": 8, "Tamil": 9, "Thai": 10,
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_char_map(path: str | Path) -> tuple[dict[int, str], int]:
    """Load char2ind.txt → (idx→char map, blank_index).

    Format: '<char> <index>' per line.
    Special tokens: <space>=space char, <blank>=CTC blank.
    """
    mapping: dict[int, str] = {}
    blank = 0
    with open(path, "r", encoding="utf-8") as fh:
        for raw in fh:
            entry = raw.rstrip("\n")
            if not entry:
                continue
            # Split on the LAST space so characters containing spaces survive.
            cut = entry.rfind(" ")
            if cut <= 0:
                continue
            token = entry[:cut]
            index = int(entry[cut + 1:])
            if token == "<blank>":
                blank = index
            elif token == "<space>":
                mapping[index] = " "
            else:
                mapping[index] = token
    return mapping, blank
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 8-connected neighbour offsets (dy, dx); the position of each offset in this
# list is the channel index used to read link_scores in _pixellink_decode.
_NEIGHBOR_OFFSETS = [(-1, 0), (-1, 1), (0, 1), (1, 1),
                     (1, 0), (1, -1), (0, -1), (-1, -1)]
|
|
|
|
|
|
|
|
class _UnionFind: |
|
|
"""Union-Find (Disjoint Set) for connected component labeling.""" |
|
|
|
|
|
__slots__ = ("parent", "rank") |
|
|
|
|
|
def __init__(self, n: int): |
|
|
self.parent = list(range(n)) |
|
|
self.rank = [0] * n |
|
|
|
|
|
def find(self, x: int) -> int: |
|
|
while self.parent[x] != x: |
|
|
self.parent[x] = self.parent[self.parent[x]] |
|
|
x = self.parent[x] |
|
|
return x |
|
|
|
|
|
def union(self, a: int, b: int) -> None: |
|
|
ra, rb = self.find(a), self.find(b) |
|
|
if ra == rb: |
|
|
return |
|
|
if self.rank[ra] < self.rank[rb]: |
|
|
ra, rb = rb, ra |
|
|
self.parent[rb] = ra |
|
|
if self.rank[ra] == self.rank[rb]: |
|
|
self.rank[ra] += 1 |
|
|
|
|
|
|
|
|
def _pixellink_decode(
    pixel_scores: np.ndarray,
    link_scores: np.ndarray,
    bbox_deltas: np.ndarray,
    stride: int,
    pixel_thresh: float = _PIXEL_SCORE_THRESH,
    link_thresh: float = _LINK_SCORE_THRESH,
    min_area: float = _MIN_AREA,
    min_component_pixels: int = 3,
) -> list[np.ndarray]:
    """Decode PixelLink outputs into oriented bounding boxes.

    Uses connected-component labeling with Union-Find for linked text pixels,
    then refines box positions using bbox_deltas regression (matching DLL behavior).

    Each text pixel predicts 4 corner offsets (as fraction of stride) via
    bbox_deltas[8, H, W] = [TL.x, TL.y, TR.x, TR.y, BR.x, BR.y, BL.x, BL.y].
    The actual corner position for pixel (r, c) is:
        corner = (pixel_coord + delta) * stride

    For a connected component, the bounding box corners are computed by taking
    the extremes of all per-pixel predictions (min over both left/top corners,
    max over both right/bottom corners).

    Args:
        pixel_scores: [H, W] text/non-text scores (already sigmoid'd)
        link_scores: [8, H, W] neighbor link scores (already sigmoid'd)
        bbox_deltas: [8, H, W] corner offsets — 4 corners × 2 coords (x, y)
        stride: FPN stride (4, 8, or 16)
        pixel_thresh: minimum pixel score for a cell to count as text
        link_thresh: minimum link score to connect two text pixels
        min_area: minimum box area in detector-image pixels
        min_component_pixels: minimum number of pixels in a connected component

    Returns:
        List of (4, 2) arrays — quadrilateral corners in detector-image coordinates.
    """
    h, w = pixel_scores.shape

    # Threshold text pixels.
    text_mask = pixel_scores > pixel_thresh
    text_pixels = np.argwhere(text_mask)
    if len(text_pixels) == 0:
        return []

    # Map (row, col) → index into text_pixels for O(1) neighbor lookups.
    pixel_map = np.full((h, w), -1, dtype=np.int32)
    for i, (r, c) in enumerate(text_pixels):
        pixel_map[r, c] = i

    # Connect neighboring text pixels whose link score passes the threshold.
    uf = _UnionFind(len(text_pixels))
    for i, (r, c) in enumerate(text_pixels):
        for ni, (dy, dx) in enumerate(_NEIGHBOR_OFFSETS):
            nr, nc = r + dy, c + dx
            if 0 <= nr < h and 0 <= nc < w:
                j = pixel_map[nr, nc]
                if j >= 0 and link_scores[ni, r, c] > link_thresh:
                    uf.union(i, j)

    # Gather connected components by root.
    components: dict[int, list[int]] = {}
    for i in range(len(text_pixels)):
        components.setdefault(uf.find(i), []).append(i)

    quads: list[np.ndarray] = []

    for indices in components.values():
        if len(indices) < min_component_pixels:
            continue

        # Aggregate per-pixel corner regressions into one axis-aligned box:
        # left/top edges from the minima over both left/top corner
        # predictions, right/bottom edges from the maxima over both
        # right/bottom corner predictions.
        tl_x_min = float("inf")
        tl_y_min = float("inf")
        br_x_max = float("-inf")
        br_y_max = float("-inf")

        for idx in indices:
            r, c = int(text_pixels[idx][0]), int(text_pixels[idx][1])

            tl_x = (c + float(bbox_deltas[0, r, c])) * stride
            tl_y = (r + float(bbox_deltas[1, r, c])) * stride
            tr_x = (c + float(bbox_deltas[2, r, c])) * stride
            tr_y = (r + float(bbox_deltas[3, r, c])) * stride
            br_x = (c + float(bbox_deltas[4, r, c])) * stride
            br_y = (r + float(bbox_deltas[5, r, c])) * stride
            # BUGFIX: channel 6 (BL.x) was previously never read, so the
            # left boundary ignored the bottom-left corner regression and
            # boxes were biased rightward whenever BL.x < TL.x. Every other
            # edge already combined both of its corners (top: TL.y+TR.y,
            # right: BR.x+TR.x, bottom: BR.y+BL.y); the left edge now does
            # the same with TL.x and BL.x.
            bl_x = (c + float(bbox_deltas[6, r, c])) * stride
            bl_y = (r + float(bbox_deltas[7, r, c])) * stride

            tl_x_min = min(tl_x_min, tl_x, bl_x)
            tl_y_min = min(tl_y_min, tl_y, tr_y)
            br_x_max = max(br_x_max, br_x, tr_x)
            br_y_max = max(br_y_max, br_y, bl_y)

        # Clamp to the image origin; box size is validated below.
        x1 = max(0.0, tl_x_min)
        y1 = max(0.0, tl_y_min)
        x2 = br_x_max
        y2 = br_y_max

        box_w = x2 - x1
        box_h = y2 - y1
        area = box_w * box_h

        # Discard degenerate or tiny detections.
        if area < min_area:
            continue
        if box_h < _MIN_HEIGHT:
            continue
        if box_w < _MIN_WIDTH:
            continue

        quad = np.array([
            [x1, y1],
            [x2, y1],
            [x2, y2],
            [x1, y2],
        ], dtype=np.float32)
        quads.append(quad)

    return quads
|
|
|
|
|
|
|
|
def _order_corners(pts: np.ndarray) -> np.ndarray: |
|
|
"""Order 4 corners as: top-left, top-right, bottom-right, bottom-left.""" |
|
|
|
|
|
s = pts.sum(axis=1) |
|
|
d = np.diff(pts, axis=1).ravel() |
|
|
|
|
|
ordered = np.zeros((4, 2), dtype=np.float32) |
|
|
ordered[0] = pts[np.argmin(s)] |
|
|
ordered[2] = pts[np.argmax(s)] |
|
|
ordered[1] = pts[np.argmin(d)] |
|
|
ordered[3] = pts[np.argmax(d)] |
|
|
return ordered |
|
|
|
|
|
|
|
|
def _nms_quads(quads: list[np.ndarray], iou_thresh: float = 0.3) -> list[np.ndarray]:
    """Non-maximum suppression on quadrilateral boxes using contour IoU.

    Greedy pass from largest contour area to smallest: each unsuppressed quad
    is kept and suppresses every remaining quad overlapping it above
    iou_thresh.
    """
    if len(quads) <= 1:
        return quads

    order = np.argsort([cv2.contourArea(q) for q in quads])[::-1]

    kept: list[np.ndarray] = []
    suppressed: set = set()

    for i in order:
        if i in suppressed:
            continue
        kept.append(quads[i])
        suppressed.add(i)
        for j in order:
            if j in suppressed:
                continue
            if _quad_iou(quads[i], quads[j]) > iou_thresh:
                suppressed.add(j)

    return kept
|
|
|
|
|
|
|
|
def _quad_iou(q1: np.ndarray, q2: np.ndarray) -> float:
    """Compute IoU between two quadrilaterals.

    Best-effort: any OpenCV failure (e.g. degenerate polygons) yields 0.0.
    """
    try:
        ret, region = cv2.intersectConvexConvex(
            q1.astype(np.float32), q2.astype(np.float32)
        )
        if ret <= 0:
            return 0.0
        inter = cv2.contourArea(region)
        union = cv2.contourArea(q1) + cv2.contourArea(q2) - inter
        if union <= 0:
            return 0.0
        return inter / union
    except Exception:
        return 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ctc_greedy_decode(
    logprobs: np.ndarray,
    idx2char: dict[int, str],
    blank_idx: int,
) -> tuple[str, float, list[float]]:
    """CTC greedy decode: argmax per timestep, merge repeats, remove blanks.

    Accepts [T, C] log-probabilities or [T, 1, C] (the batch axis is dropped).

    Returns (decoded_text, average_confidence, per_char_confidences).
    """
    if logprobs.ndim == 3:
        logprobs = logprobs[:, 0, :]

    best = np.argmax(logprobs, axis=-1)
    best_probs = np.exp(logprobs)[np.arange(len(best)), best]

    pieces: list[str] = []
    piece_probs: list[float] = []
    last = -1

    for t, sym in enumerate(best):
        # Emit only on a symbol change that is not the blank token; unknown
        # indices fall back to a bracketed placeholder.
        if sym != last and sym != blank_idx:
            pieces.append(idx2char.get(int(sym), f"[{sym}]"))
            piece_probs.append(float(best_probs[t]))
        last = sym

    avg = float(np.mean(piece_probs)) if piece_probs else 0.0
    return "".join(pieces), avg, piece_probs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _group_words_into_lines(
    words: list[tuple[str, np.ndarray, float]],
) -> list[list[tuple[str, np.ndarray, float]]]:
    """Group detected words into lines based on Y-overlap.

    Words are visited top-to-bottom by vertical center; each unused word
    seeds a new line and absorbs every later word whose vertical overlap
    with the (growing) line band exceeds _LINE_IOU_Y_THRESH of the smaller
    height.

    Args:
        words: List of (text, quad[4,2], confidence).

    Returns:
        List of line groups, each a list of (text, quad, conf) tuples,
        ordered left-to-right within each line.
    """
    if not words:
        return []

    # Pair each word with its vertical center and sort top-to-bottom.
    ordered = sorted(
        ((entry[1][:, 1].mean(), entry) for entry in words),
        key=lambda pair: pair[0],
    )

    grouped: list[list[tuple[str, np.ndarray, float]]] = []
    consumed: set = set()

    for i, (_, seed) in enumerate(ordered):
        if i in consumed:
            continue

        band_top = seed[1][:, 1].min()
        band_bot = seed[1][:, 1].max()

        current = [seed]
        consumed.add(i)

        for j in range(i + 1, len(ordered)):
            if j in consumed:
                continue
            cand = ordered[j][1]
            top_j = cand[1][:, 1].min()
            bot_j = cand[1][:, 1].max()

            # Overlap of the candidate with the current line band, measured
            # against the shorter of the two heights.
            overlap = min(band_bot, bot_j) - max(band_top, top_j)
            shorter = min(band_bot - band_top, bot_j - top_j)
            if shorter > 0 and overlap / shorter > _LINE_IOU_Y_THRESH:
                current.append(cand)
                consumed.add(j)
                # Grow the line's vertical band to cover the new word.
                band_top = min(band_top, top_j)
                band_bot = max(band_bot, bot_j)

        # Left-to-right reading order within the line.
        current.sort(key=lambda entry: entry[1][:, 0].min())
        grouped.append(current)

    return grouped
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _estimate_image_angle(quads: list[np.ndarray]) -> float: |
|
|
"""Estimate overall text angle from detected quads. |
|
|
|
|
|
Uses the average angle of the top edges of all boxes. |
|
|
""" |
|
|
if not quads: |
|
|
return 0.0 |
|
|
|
|
|
angles = [] |
|
|
for q in quads: |
|
|
|
|
|
dx = q[1][0] - q[0][0] |
|
|
dy = q[1][1] - q[0][1] |
|
|
if abs(dx) < 1: |
|
|
continue |
|
|
angle = math.degrees(math.atan2(dy, dx)) |
|
|
angles.append(angle) |
|
|
|
|
|
if not angles: |
|
|
return 0.0 |
|
|
|
|
|
return float(np.median(angles)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OcrEngineOnnx: |
|
|
"""Cross-platform OCR engine using extracted ONNX models. |
|
|
|
|
|
Provides the same API as OcrEngine (DLL version) but runs on any OS. |
|
|
|
|
|
Args: |
|
|
models_dir: Path to extracted ONNX models directory. |
|
|
config_dir: Path to extracted config data directory. |
|
|
providers: ONNX Runtime providers (default: CPU only). |
|
|
""" |
|
|
|
|
|
    def __init__(
        self,
        models_dir: str | Path | None = None,
        config_dir: str | Path | None = None,
        providers: list[str] | None = None,
    ) -> None:
        # Default layout: <base>/oneocr_extracted/{onnx_models,config_data},
        # where <base> is this file's grandparent directory.
        base = Path(__file__).resolve().parent.parent
        self._models_dir = Path(models_dir) if models_dir else base / "oneocr_extracted" / "onnx_models"
        self._config_dir = Path(config_dir) if config_dir else base / "oneocr_extracted" / "config_data"
        # Optional sibling directory with re-exported ("unlocked") models;
        # consulted first by _find_model for indices >= 11.
        self._unlocked_dir = self._models_dir.parent / "onnx_models_unlocked"
        self._providers = providers or ["CPUExecutionProvider"]

        # Shared ONNX Runtime session options: full graph optimization, half
        # the CPU cores for intra-op parallelism, memory arena enabled.
        self._sess_opts = ort.SessionOptions()
        self._sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        self._sess_opts.intra_op_num_threads = max(1, (os.cpu_count() or 4) // 2)
        self._sess_opts.inter_op_num_threads = 2
        self._sess_opts.enable_mem_pattern = True
        self._sess_opts.enable_cpu_mem_arena = True

        # Sessions and char maps are created lazily on first use
        # (see the _get_* helpers below).
        self._detector: ort.InferenceSession | None = None
        self._script_id: ort.InferenceSession | None = None
        self._recognizers: dict[int, ort.InferenceSession] = {}
        self._char_maps: dict[int, tuple[dict[int, str], int]] = {}
        self._line_layout: ort.InferenceSession | None = None

        # Fail fast if the extracted model/config assets are missing.
        if not self._models_dir.exists():
            raise FileNotFoundError(f"Models directory not found: {self._models_dir}")
        if not self._config_dir.exists():
            raise FileNotFoundError(f"Config directory not found: {self._config_dir}")
|
|
|
|
|
|
|
|
|
|
|
def recognize_pil(self, image: "Image.Image") -> OcrResult: |
|
|
"""Run OCR on a PIL Image. |
|
|
|
|
|
Args: |
|
|
image: PIL Image (any mode — will be converted to RGB). |
|
|
|
|
|
Returns: |
|
|
OcrResult with recognized text, lines, words, bounding boxes, |
|
|
confidence values, and detected text angle. |
|
|
""" |
|
|
if any(x < 10 or x > 10000 for x in image.size): |
|
|
return OcrResult(error="Unsupported image size (must be 10-10000px)") |
|
|
|
|
|
img_rgb = image.convert("RGB") |
|
|
img_arr = np.array(img_rgb) |
|
|
|
|
|
try: |
|
|
return self._run_pipeline(img_arr) |
|
|
except Exception as e: |
|
|
return OcrResult(error=f"Pipeline error: {e}") |
|
|
|
|
|
def recognize_bytes(self, image_bytes: bytes) -> OcrResult: |
|
|
"""Run OCR on raw image bytes (PNG/JPEG/etc).""" |
|
|
from io import BytesIO |
|
|
from PIL import Image as PILImage |
|
|
img = PILImage.open(BytesIO(image_bytes)) |
|
|
return self.recognize_pil(img) |
|
|
|
|
|
def recognize_numpy(self, img_rgb: np.ndarray) -> OcrResult: |
|
|
"""Run OCR on a numpy array (H, W, 3) in RGB format.""" |
|
|
if img_rgb.ndim != 3 or img_rgb.shape[2] != 3: |
|
|
return OcrResult(error="Expected (H, W, 3) RGB array") |
|
|
try: |
|
|
return self._run_pipeline(img_rgb) |
|
|
except Exception as e: |
|
|
return OcrResult(error=f"Pipeline error: {e}") |
|
|
|
|
|
|
|
|
|
|
|
    def _run_pipeline(self, img_rgb: np.ndarray) -> OcrResult:
        """Full OCR pipeline: detect → crop → scriptID → recognize → group."""
        # NOTE(review): h, w (and the scale returned by _detect) are
        # currently unused here — confirm before removing.
        h, w = img_rgb.shape[:2]

        # 1. Text detection → quads in original-image coordinates.
        quads, scale = self._detect(img_rgb)

        if not quads:
            return OcrResult(text="", text_angle=0.0, lines=[])

        # Global text angle estimated from the top edges of all detections.
        text_angle = _estimate_image_angle(quads)

        # 2. Per-quad: crop → script classification → CTC recognition.
        line_results: list[tuple[str, np.ndarray, float, list[float]]] = []

        for quad in quads:
            crop = self._crop_quad(img_rgb, quad)
            if crop is None or crop.shape[0] < 5 or crop.shape[1] < 5:
                continue

            # Tall, narrow crops (height > 2× width) are treated as vertical
            # text and rotated upright before recognition.
            ch, cw = crop.shape[:2]
            is_vertical = ch > cw * 2

            rec_crop = crop
            if is_vertical:
                rec_crop = cv2.rotate(crop, cv2.ROTATE_90_COUNTERCLOCKWISE)

            # Pick the recognizer for the classified script; unknown scripts
            # fall back to the Latin model (index 2).
            script_idx = self._identify_script(rec_crop)
            script_name = SCRIPT_NAMES[script_idx] if script_idx < len(SCRIPT_NAMES) else "Latin"

            model_idx = SCRIPT_TO_MODEL.get(script_name, 2)

            text, conf, char_confs = self._recognize(rec_crop, model_idx)

            # Low-confidence vertical text is retried with the CJK model;
            # the higher-confidence result wins.
            if is_vertical and conf < 0.7 and model_idx != 3:
                text_cjk, conf_cjk, char_confs_cjk = self._recognize(rec_crop, 3)
                if conf_cjk > conf:
                    text, conf, char_confs = text_cjk, conf_cjk, char_confs_cjk

            if text.strip():
                # Confidence gating: single characters need conf >= 0.35,
                # longer strings conf >= 0.3.
                text_stripped = text.strip()
                n_chars = len(text_stripped)
                if n_chars <= 1 and conf < 0.35:
                    continue
                elif conf < 0.3:
                    continue

                line_results.append((text, quad, conf, char_confs))

        if not line_results:
            return OcrResult(text="", text_angle=text_angle, lines=[])

        # 3. Build OcrLine/OcrWord objects; each detection quad is one line.
        lines: list[OcrLine] = []
        for line_text, quad, conf, char_confs in line_results:
            word_texts = line_text.split()
            if not word_texts:
                continue

            # Word boxes are interpolated along the line quad proportionally
            # to character counts (see _split_into_words).
            words = self._split_into_words(word_texts, quad, conf, char_confs)

            line_bbox = BoundingRect(
                x1=float(quad[0][0]), y1=float(quad[0][1]),
                x2=float(quad[1][0]), y2=float(quad[1][1]),
                x3=float(quad[2][0]), y3=float(quad[2][1]),
                x4=float(quad[3][0]), y4=float(quad[3][1]),
            )

            full_line_text = " ".join(w.text for w in words)
            lines.append(OcrLine(text=full_line_text, bounding_rect=line_bbox, words=words))

        # 4. Reading order: sort lines by the vertical center of their box.
        lines.sort(key=lambda l: (
            (l.bounding_rect.y1 + l.bounding_rect.y3) / 2
            if l.bounding_rect else 0
        ))

        full_text = "\n".join(line.text for line in lines)

        return OcrResult(text=full_text, text_angle=text_angle, lines=lines)
|
|
|
|
|
|
|
|
|
|
|
    def _detect(self, img_rgb: np.ndarray) -> tuple[list[np.ndarray], float]:
        """Run PixelLink detector and decode bounding boxes.

        Scaling: FPN-style — scale shortest side to 800, cap longest at 1333.
        Two-phase detection:
        - Phase 1: FPN3 (stride=8) + FPN4 (stride=16) — primary detections
        - Phase 2: FPN2 (stride=4) — supplementary for small text ("I", "...")
        Only keeps novel FPN2 detections that don't overlap with primary.
        NMS: IoU threshold 0.2 (from DLL protobuf config).

        Returns (list_of_quads, scale_factor).
        """
        h, w = img_rgb.shape[:2]

        # FPN resize rule: short side → 800 unless that pushes the long side
        # past 1333; overall upscaling is capped at 6x.
        # NOTE(review): the 6x cap is applied unconditionally here — confirm
        # it was not intended to apply only inside the long-side branch.
        short_side = min(h, w)
        long_side = max(h, w)
        scale = _DET_TARGET_SHORT / short_side
        if long_side * scale > _DET_MAX_LONG:
            scale = _DET_MAX_LONG / long_side
        scale = min(scale, 6.0)

        # Round target dims up to a multiple of 32.
        dh = (int(h * scale) + 31) // 32 * 32
        dw = (int(w * scale) + 31) // 32 * 32

        img_resized = cv2.resize(img_rgb, (dw, dh), interpolation=cv2.INTER_LINEAR)

        # Detector preprocessing: BGR channel order, mean subtraction, NCHW.
        img_bgr = img_resized[:, :, ::-1].astype(np.float32) - _DET_MEAN
        data = img_bgr.transpose(2, 0, 1)[np.newaxis]
        im_info = np.array([[dh, dw, scale]], dtype=np.float32)

        # Run the detector and index its outputs by name.
        sess = self._get_detector()
        outputs = sess.run(None, {"data": data, "im_info": im_info})
        output_names = [o.name for o in sess.get_outputs()]
        out_dict = dict(zip(output_names, outputs))

        # Phase 1: decode FPN3/FPN4, horizontal + vertical heads.
        primary_quads: list[np.ndarray] = []

        for level, stride in [("fpn3", 8), ("fpn4", 16)]:
            # Minimum area is defined in detector-image pixels, so rescale it.
            min_area_scaled = _MIN_AREA * (scale ** 2)

            for orientation in ("hori", "vert"):
                scores_key = f"scores_{orientation}_{level}"
                links_key = f"link_scores_{orientation}_{level}"
                deltas_key = f"bbox_deltas_{orientation}_{level}"

                if scores_key not in out_dict:
                    continue

                pixel_scores = out_dict[scores_key][0, 0]
                # Cheap early-out: skip the vertical head entirely when no
                # pixel clears the threshold.
                if orientation == "vert" and pixel_scores.max() <= _PIXEL_SCORE_THRESH:
                    continue

                link_scores = out_dict[links_key][0]
                bbox_deltas = out_dict[deltas_key][0]

                quads = _pixellink_decode(
                    pixel_scores, link_scores, bbox_deltas, stride,
                    pixel_thresh=_PIXEL_SCORE_THRESH,
                    min_area=min_area_scaled,
                )
                primary_quads.extend(quads)

        primary_quads = _nms_quads(primary_quads, iou_thresh=_NMS_IOU_THRESH)

        # Phase 2: FPN2 (stride 4) with stricter thresholds (0.85 pixel
        # score, 5-pixel components); only detections that don't overlap
        # anything from Phase 1 are kept.
        fpn2_quads: list[np.ndarray] = []
        min_area_scaled = _MIN_AREA * (scale ** 2)

        for orientation in ("hori", "vert"):
            scores_key = f"scores_{orientation}_fpn2"
            links_key = f"link_scores_{orientation}_fpn2"
            deltas_key = f"bbox_deltas_{orientation}_fpn2"

            if scores_key not in out_dict:
                continue

            pixel_scores = out_dict[scores_key][0, 0]
            if pixel_scores.max() <= 0.85:
                continue

            link_scores = out_dict[links_key][0]
            bbox_deltas = out_dict[deltas_key][0]

            quads = _pixellink_decode(
                pixel_scores, link_scores, bbox_deltas, 4,
                pixel_thresh=0.85,
                min_area=min_area_scaled,
                min_component_pixels=5,
            )

            for q in quads:
                # Keep only novel small-text detections.
                overlaps = any(_quad_iou(q, p) > 0.1 for p in primary_quads)
                if not overlaps:
                    fpn2_quads.append(q)

        # Merge phases; re-run NMS only when Phase 2 contributed anything.
        all_quads = primary_quads + fpn2_quads
        if fpn2_quads:
            all_quads = _nms_quads(all_quads, iou_thresh=_NMS_IOU_THRESH)

        # Map detector coordinates back to original-image coordinates.
        for i in range(len(all_quads)):
            all_quads[i] = all_quads[i] / scale

        return all_quads, scale
|
|
|
|
|
|
|
|
|
|
|
def _identify_script(self, crop_rgb: np.ndarray) -> int: |
|
|
"""Identify script of a cropped text region. |
|
|
|
|
|
Returns script index (0=Latin, 1=CJK, ..., 9=Unknown). |
|
|
""" |
|
|
sess = self._get_script_id() |
|
|
|
|
|
|
|
|
data = self._preprocess_recognizer(crop_rgb) |
|
|
|
|
|
outputs = sess.run(None, {"data": data}) |
|
|
|
|
|
script_scores = outputs[3] |
|
|
script_idx = int(np.argmax(script_scores.flatten()[:10])) |
|
|
return script_idx |
|
|
|
|
|
|
|
|
|
|
|
def _recognize( |
|
|
self, crop_rgb: np.ndarray, model_idx: int |
|
|
) -> tuple[str, float, list[float]]: |
|
|
"""Recognize text in a cropped region using the specified model. |
|
|
|
|
|
Returns (text, confidence, per_char_confidences). |
|
|
""" |
|
|
sess = self._get_recognizer(model_idx) |
|
|
idx2char, blank_idx = self._get_char_map(model_idx) |
|
|
|
|
|
data = self._preprocess_recognizer(crop_rgb) |
|
|
h, w = data.shape[2], data.shape[3] |
|
|
seq_lengths = np.array([w // 4], dtype=np.int32) |
|
|
|
|
|
logprobs = sess.run(None, {"data": data, "seq_lengths": seq_lengths})[0] |
|
|
text, conf, char_confs = ctc_greedy_decode(logprobs, idx2char, blank_idx) |
|
|
return text, conf, char_confs |
|
|
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
def _split_into_words( |
|
|
word_texts: list[str], |
|
|
quad: np.ndarray, |
|
|
confidence: float, |
|
|
char_confs: list[float] | None = None, |
|
|
) -> list[OcrWord]: |
|
|
"""Split a line into word-level OcrWord objects with estimated bboxes. |
|
|
|
|
|
Distributes the line quad proportionally by character count. |
|
|
Per-word confidence is computed from character-level CTC confidences. |
|
|
""" |
|
|
if not word_texts: |
|
|
return [] |
|
|
|
|
|
|
|
|
total_chars = sum(len(w) for w in word_texts) + len(word_texts) - 1 |
|
|
if total_chars <= 0: |
|
|
total_chars = 1 |
|
|
|
|
|
|
|
|
word_confidences: list[float] = [] |
|
|
if char_confs and len(char_confs) >= sum(len(w) for w in word_texts): |
|
|
idx = 0 |
|
|
for word_text in word_texts: |
|
|
wc = char_confs[idx:idx + len(word_text)] |
|
|
word_confidences.append(float(np.mean(wc)) if wc else confidence) |
|
|
idx += len(word_text) |
|
|
|
|
|
if idx < len(char_confs): |
|
|
idx += 1 |
|
|
else: |
|
|
word_confidences = [confidence] * len(word_texts) |
|
|
|
|
|
|
|
|
top_start, top_end = quad[0], quad[1] |
|
|
bot_start, bot_end = quad[3], quad[2] |
|
|
|
|
|
words: list[OcrWord] = [] |
|
|
char_pos = 0 |
|
|
|
|
|
for i, word_text in enumerate(word_texts): |
|
|
t_start = char_pos / total_chars |
|
|
t_end = (char_pos + len(word_text)) / total_chars |
|
|
|
|
|
|
|
|
tl = top_start + (top_end - top_start) * t_start |
|
|
tr = top_start + (top_end - top_start) * t_end |
|
|
bl = bot_start + (bot_end - bot_start) * t_start |
|
|
br = bot_start + (bot_end - bot_start) * t_end |
|
|
|
|
|
bbox = BoundingRect( |
|
|
x1=float(tl[0]), y1=float(tl[1]), |
|
|
x2=float(tr[0]), y2=float(tr[1]), |
|
|
x3=float(br[0]), y3=float(br[1]), |
|
|
x4=float(bl[0]), y4=float(bl[1]), |
|
|
) |
|
|
words.append(OcrWord( |
|
|
text=word_text, |
|
|
bounding_rect=bbox, |
|
|
confidence=word_confidences[i], |
|
|
)) |
|
|
|
|
|
char_pos += len(word_text) + 1 |
|
|
|
|
|
return words |
|
|
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
def _preprocess_recognizer(img_rgb: np.ndarray) -> np.ndarray: |
|
|
"""Preprocess image for recognizer/scriptID input. |
|
|
|
|
|
Process: Resize height to 60px → RGB / 255.0 → NCHW float32. |
|
|
""" |
|
|
h, w = img_rgb.shape[:2] |
|
|
target_h = _REC_TARGET_H |
|
|
scale = target_h / h |
|
|
new_w = max(int(w * scale), _REC_MIN_WIDTH) |
|
|
new_w = (new_w + 3) // 4 * 4 |
|
|
|
|
|
resized = cv2.resize(img_rgb, (new_w, target_h), interpolation=cv2.INTER_LINEAR) |
|
|
data = resized.astype(np.float32) / 255.0 |
|
|
data = data.transpose(2, 0, 1)[np.newaxis] |
|
|
return data |
|
|
|
|
|
|
|
|
|
|
|
    @staticmethod
    def _crop_quad(
        img_rgb: np.ndarray,
        quad: np.ndarray,
        padding_ratio: float = 0.15,
    ) -> np.ndarray | None:
        """Crop a text region from the image using the bounding quad.

        Uses axis-aligned rectangle crop with padding (matching DLL behavior).
        Falls back to perspective transform only for heavily rotated text (>15°).

        Args:
            img_rgb: Source image (H, W, 3).
            quad: 4 corners as (4, 2) array.
            padding_ratio: How much to expand (fraction of height).

        Returns:
            Cropped RGB image or None.
        """
        try:
            img_h, img_w = img_rgb.shape[:2]

            # The rotation of the top edge decides which crop strategy is used.
            dx = quad[1][0] - quad[0][0]
            dy = quad[1][1] - quad[0][1]
            angle = abs(math.atan2(dy, dx)) * 180 / math.pi

            if angle < 15:
                # Fast path: axis-aligned bounding box of the quad.
                x_min = float(quad[:, 0].min())
                x_max = float(quad[:, 0].max())
                y_min = float(quad[:, 1].min())
                y_max = float(quad[:, 1].max())

                box_h = y_max - y_min
                box_w = x_max - x_min

                if box_h < 3 or box_w < 5:
                    return None

                # Padding proportional to box height: padding_ratio
                # vertically, 25% of height horizontally, min 3px each.
                pad_h = max(box_h * padding_ratio, 3)
                pad_w = max(box_h * 0.25, 3)

                y1 = max(0, int(y_min - pad_h))
                y2 = min(img_h, int(y_max + pad_h))
                x1 = max(0, int(x_min - pad_w))
                x2 = min(img_w, int(x_max + pad_w))

                if y2 - y1 < 3 or x2 - x1 < 5:
                    return None

                return img_rgb[y1:y2, x1:x2].copy()

            # Rotated path: perspective-rectify the quad.
            w1 = np.linalg.norm(quad[1] - quad[0])
            w2 = np.linalg.norm(quad[2] - quad[3])
            h1 = np.linalg.norm(quad[3] - quad[0])
            h2 = np.linalg.norm(quad[2] - quad[1])

            target_w = int(max(w1, w2))
            target_h = int(max(h1, h2))

            if target_w < 5 or target_h < 3:
                return None

            # Expand the quad outward before rectification.
            pad = max(h1, h2) * padding_ratio

            # Unit vector along the left edge (bottom-left → top-left),
            # with a fallback for degenerate (zero-length) edges.
            top_dir = quad[0] - quad[3]
            if np.linalg.norm(top_dir) > 0:
                top_dir = top_dir / np.linalg.norm(top_dir)
            else:
                top_dir = np.array([0, -1], dtype=np.float32)

            # Unit vector along the top edge (top-right → top-left),
            # same fallback scheme.
            left_dir = quad[0] - quad[1]
            if np.linalg.norm(left_dir) > 0:
                left_dir = left_dir / np.linalg.norm(left_dir)
            else:
                left_dir = np.array([-1, 0], dtype=np.float32)

            # Push each corner outward: full pad along the text's vertical
            # axis, 30% of pad along the horizontal axis.
            expanded = quad.copy().astype(np.float32)
            expanded[0] = quad[0] + top_dir * pad + left_dir * pad * 0.3
            expanded[1] = quad[1] + top_dir * pad - left_dir * pad * 0.3
            expanded[2] = quad[2] - top_dir * pad - left_dir * pad * 0.3
            expanded[3] = quad[3] - top_dir * pad + left_dir * pad * 0.3

            # Keep the expanded quad inside the image bounds.
            expanded[:, 0] = np.clip(expanded[:, 0], 0, img_w - 1)
            expanded[:, 1] = np.clip(expanded[:, 1], 0, img_h - 1)

            # Recompute the output size from the (clipped) expanded quad.
            w1 = np.linalg.norm(expanded[1] - expanded[0])
            w2 = np.linalg.norm(expanded[2] - expanded[3])
            h1 = np.linalg.norm(expanded[3] - expanded[0])
            h2 = np.linalg.norm(expanded[2] - expanded[1])
            target_w = int(max(w1, w2))
            target_h = int(max(h1, h2))

            if target_w < 5 or target_h < 3:
                return None

            dst = np.array([
                [0, 0],
                [target_w - 1, 0],
                [target_w - 1, target_h - 1],
                [0, target_h - 1],
            ], dtype=np.float32)

            M = cv2.getPerspectiveTransform(expanded.astype(np.float32), dst)
            crop = cv2.warpPerspective(
                img_rgb, M, (target_w, target_h),
                flags=cv2.INTER_LINEAR,
                borderMode=cv2.BORDER_REPLICATE,
            )
            return crop
        except Exception:
            # Best-effort: any geometry/OpenCV failure means "no crop".
            return None
|
|
|
|
|
|
|
|
|
|
|
def _find_model(self, model_idx: int) -> Path: |
|
|
"""Find ONNX model file by index. Checks unlocked dir first for models 11-33.""" |
|
|
if model_idx >= 11 and self._unlocked_dir.exists(): |
|
|
matches = list(self._unlocked_dir.glob(f"model_{model_idx:02d}_*")) |
|
|
if matches: |
|
|
return matches[0] |
|
|
matches = list(self._models_dir.glob(f"model_{model_idx:02d}_*")) |
|
|
if not matches: |
|
|
raise FileNotFoundError(f"Model file not found for index {model_idx}") |
|
|
return matches[0] |
|
|
|
|
|
def _get_detector(self) -> ort.InferenceSession: |
|
|
"""Get or create detector session.""" |
|
|
if self._detector is None: |
|
|
path = self._find_model(0) |
|
|
self._detector = ort.InferenceSession( |
|
|
str(path), sess_options=self._sess_opts, providers=self._providers |
|
|
) |
|
|
return self._detector |
|
|
|
|
|
def _get_script_id(self) -> ort.InferenceSession: |
|
|
"""Get or create ScriptID session.""" |
|
|
if self._script_id is None: |
|
|
path = self._find_model(1) |
|
|
self._script_id = ort.InferenceSession( |
|
|
str(path), sess_options=self._sess_opts, providers=self._providers |
|
|
) |
|
|
return self._script_id |
|
|
|
|
|
def _get_recognizer(self, model_idx: int) -> ort.InferenceSession: |
|
|
"""Get or create recognizer session.""" |
|
|
if model_idx not in self._recognizers: |
|
|
path = self._find_model(model_idx) |
|
|
self._recognizers[model_idx] = ort.InferenceSession( |
|
|
str(path), sess_options=self._sess_opts, providers=self._providers |
|
|
) |
|
|
return self._recognizers[model_idx] |
|
|
|
|
|
def _get_char_map(self, model_idx: int) -> tuple[dict[int, str], int]: |
|
|
"""Get or load character map for model.""" |
|
|
if model_idx not in self._char_maps: |
|
|
info = MODEL_REGISTRY.get(model_idx) |
|
|
if not info or not info[2]: |
|
|
raise ValueError(f"No char2ind file for model {model_idx}") |
|
|
char_path = self._config_dir / info[2] |
|
|
self._char_maps[model_idx] = load_char_map(char_path) |
|
|
return self._char_maps[model_idx] |
|
|
|
|
|
def _get_line_layout(self) -> ort.InferenceSession | None: |
|
|
"""Get or create LineLayout session (model 33). Returns None if unavailable.""" |
|
|
if self._line_layout is None: |
|
|
try: |
|
|
path = self._find_model(33) |
|
|
self._line_layout = ort.InferenceSession( |
|
|
str(path), sess_options=self._sess_opts, providers=self._providers |
|
|
) |
|
|
except FileNotFoundError: |
|
|
return None |
|
|
return self._line_layout |
|
|
|
|
|
def _run_line_layout(self, crop_rgb: np.ndarray) -> float | None: |
|
|
"""Run LineLayout model to get line boundary score. |
|
|
|
|
|
Args: |
|
|
crop_rgb: Cropped text line image. |
|
|
|
|
|
Returns: |
|
|
Line layout score (higher = more confident this is a complete line), |
|
|
or None if LineLayout model unavailable. |
|
|
""" |
|
|
sess = self._get_line_layout() |
|
|
if sess is None: |
|
|
return None |
|
|
|
|
|
try: |
|
|
data = self._preprocess_recognizer(crop_rgb) |
|
|
outputs = sess.run(None, {"data": data}) |
|
|
|
|
|
score = float(outputs[1].flatten()[0]) |
|
|
return score |
|
|
except Exception: |
|
|
return None |
|
|
|