"""OCR text extraction: PP-OCRv5 detection/classification/recognition plus PP-DocLayout-S (sim-cut) layout analysis, run as .axmodel models through axengine."""
| from __future__ import annotations |
|
|
| from dataclasses import dataclass, field |
| from enum import Enum |
| import math |
| from pathlib import Path |
| from threading import RLock |
| from typing import List, Optional, Sequence, Tuple |
|
|
| import cv2 |
| import numpy as np |
| from PIL import Image, ImageEnhance |
|
|
| from ..preprocess.scene_classifier import Scene, SceneResult |
|
|
# Directory containing this module; PROJECT_ROOT is two directory levels above it.
PERCEPTION_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = PERCEPTION_DIR.parents[1]


# Compiled model files, loaded lazily by the _get_*_sess helpers via axengine.
AXMODEL_DIR = PROJECT_ROOT / "axmodel"
DET_MODEL_PATH = AXMODEL_DIR / "ppocrv5" / "det_npu1.axmodel"
CLS_MODEL_PATH = AXMODEL_DIR / "ppocrv5" / "cls_npu1.axmodel"
REC_MODEL_PATH = AXMODEL_DIR / "ppocrv5" / "rec_npu1.axmodel"
# Recognition vocabulary, one character per line (see _get_char_list).
DICT_PATH = PERCEPTION_DIR / "dict" / "ppocrv5_dict.txt"
LAYOUT_MODEL_PATH = AXMODEL_DIR / "ppstructurev3" / "ppstructure_npu1.axmodel"
|
|
|
|
# Side length of the square detector input; images are resized to
# DET_INPUT_SIZE x DET_INPUT_SIZE without preserving aspect ratio.
DET_INPUT_SIZE = 960
# (height, width) of the direction-classifier and recognizer inputs. Crops are
# aspect-resized to this height and right-padded with zeros to this width.
CLS_INPUT_SIZE = (80, 160)
REC_INPUT_SIZE = (48, 320)
# (height, width) of the layout-model input; resized without padding.
LAYOUT_INPUT_SIZE = (480, 480)
# Per-channel normalisation constants.
# NOTE(review): the preprocessing in this module feeds raw pixel values and never
# applies these (LAYOUT_MEAN/STD are passed to _resize_no_pad, which ignores them).
# Presumably the normalisation is compiled into the axmodels -- TODO confirm.
DET_MEAN = np.array([123.675, 116.28, 103.53], dtype=np.float32)
DET_STD = np.array([58.395, 57.12, 57.375], dtype=np.float32)
CLS_MEAN = np.array([127.5, 127.5, 127.5], dtype=np.float32)
CLS_STD = np.array([127.5, 127.5, 127.5], dtype=np.float32)
REC_MEAN = np.array([127.5, 127.5, 127.5], dtype=np.float32)
REC_STD = np.array([127.5, 127.5, 127.5], dtype=np.float32)
LAYOUT_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
LAYOUT_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


# Threshold used by _det_postprocess to binarise the detector probability map.
DET_BIN_THRESHOLD = 0.30
# NOTE(review): DET_MIN_COMPONENT_AREA and DET_SCORE_THRESHOLD are not referenced
# in this file as shown; candidate-box filtering in _boxes_from_bitmap uses its
# own box score threshold (0.6).
DET_MIN_COMPONENT_AREA = 16
DET_SCORE_THRESHOLD = 0.35
# A crop is rotated 180 degrees only when class 1 wins with probability above this.
CLS_ROTATE_THRESHOLD = 0.90
# Minimum class score for a layout detection to be kept.
LAYOUT_SCORE_THRESHOLD = 0.35
# Overlap (intersection / smaller box area, as computed by _iou) at or above
# which two matching layout regions are considered duplicates.
LAYOUT_DEDUP_IOU = 0.85
|
|
# Layout-model class names, indexed by predicted class id (_append_region looks up
# LAYOUT_LABELS[cls_id]). Presumably this order must match the exported model's
# label order -- verify when swapping models.
LAYOUT_LABELS = [
    "paragraph_title",
    "image",
    "text",
    "number",
    "abstract",
    "content",
    "figure_title",
    "formula",
    "table",
    "table_title",
    "reference",
    "doc_title",
    "footnote",
    "header",
    "algorithm",
    "footer",
    "seal",
    "chart_title",
    "chart",
    "formula_number",
    "header_image",
    "footer_image",
    "aside_text",
]
|
|
# Collapses raw layout labels into the coarse region types used downstream
# ("title", "text", "figure_caption", "figure", "table", "unknown"). Labels not
# listed here map to "unknown" (LAYOUT_TYPE_MAP.get default in _append_region).
# Only "figure" regions get an image crop attached.
LAYOUT_TYPE_MAP = {
    "paragraph_title": "title",
    "doc_title": "title",
    "text": "text",
    "content": "text",
    "reference": "text",
    "abstract": "text",
    "number": "text",
    "aside_text": "text",
    "figure_title": "figure_caption",
    "table_title": "figure_caption",
    "chart_title": "figure_caption",
    "formula_number": "figure_caption",
    "image": "figure",
    "chart": "figure",
    "header_image": "figure",
    "footer_image": "figure",
    "table": "table",
    "formula": "figure",
    "header": "unknown",
    "footer": "unknown",
    "seal": "unknown",
    "algorithm": "text",
    "footnote": "text",
}
|
|
# Raw labels that _append_region keeps; detections with any other label
# (header, footer, seal, formula_number, header_image, footer_image, aside_text)
# are discarded before mapping through LAYOUT_TYPE_MAP.
LAYOUT_ENABLED_LABELS = {
    "paragraph_title",
    "image",
    "text",
    "number",
    "abstract",
    "content",
    "figure_title",
    "formula",
    "table",
    "table_title",
    "reference",
    "doc_title",
    "footnote",
    "algorithm",
    "chart_title",
    "chart",
}
|
|
# Lazily initialised module-level caches, filled on first use by the _get_* helpers.
_ort = None  # the imported axengine module
_runtime_lock = RLock()  # re-entrant lock held while creating inference sessions
_det_sess = None
_cls_sess = None
_rec_sess = None
_layout_sess = None
_char_list = None  # recognition vocabulary: "blank", dictionary lines, then " "
_pyclipper = None  # pyclipper module, or False once its import has failed
|
|
|
|
def _get_ort():
    """Return the ``axengine`` module, importing it on first use.

    ``axengine`` is used in this module through an onnxruntime-style API
    (``InferenceSession``, ``run``, ``get_inputs``), hence the ``ort`` alias.

    Raises:
        RuntimeError: if ``axengine`` cannot be imported.
    """
    global _ort
    if _ort is None:
        try:
            import axengine as ort
        except ImportError as exc:
            # The message must name the package actually imported above.
            raise RuntimeError(
                "axengine is required for PP-OCRv5 axmodel inference. "
                "Please install it in the runtime environment."
            ) from exc
        _ort = ort
    return _ort
|
|
|
|
def _get_session(path: Path):
    """Create a new axengine inference session for the model file at ``path``.

    Creation happens while holding ``_runtime_lock``; each call builds a fresh
    session (caching is done by the ``_get_*_sess`` wrappers).
    """
    with _runtime_lock:
        engine = _get_ort()
        model_file = str(path)
        return engine.InferenceSession(model_file, providers=["AxEngineExecutionProvider"])
|
|
|
|
def _get_det_sess():
    """Return the shared text-detection session, creating it on first use.

    The ``None`` check is repeated under ``_runtime_lock`` so concurrent first
    callers cannot each load the model. The lock is an RLock, so re-acquiring
    it inside ``_get_session`` is safe.
    """
    global _det_sess
    if _det_sess is None:
        with _runtime_lock:
            if _det_sess is None:
                _det_sess = _get_session(DET_MODEL_PATH)
    return _det_sess
|
|
|
|
def _get_cls_sess():
    """Return the shared text-direction classifier session, creating it on first use.

    The ``None`` check is repeated under ``_runtime_lock`` so concurrent first
    callers cannot each load the model. The lock is an RLock, so re-acquiring
    it inside ``_get_session`` is safe.
    """
    global _cls_sess
    if _cls_sess is None:
        with _runtime_lock:
            if _cls_sess is None:
                _cls_sess = _get_session(CLS_MODEL_PATH)
    return _cls_sess
|
|
|
|
def _get_rec_sess():
    """Return the shared text-recognition session, creating it on first use.

    The ``None`` check is repeated under ``_runtime_lock`` so concurrent first
    callers cannot each load the model. The lock is an RLock, so re-acquiring
    it inside ``_get_session`` is safe.
    """
    global _rec_sess
    if _rec_sess is None:
        with _runtime_lock:
            if _rec_sess is None:
                _rec_sess = _get_session(REC_MODEL_PATH)
    return _rec_sess
|
|
|
|
def _get_layout_sess():
    """Return the shared layout-analysis session, creating it on first use.

    The ``None`` check is repeated under ``_runtime_lock`` so concurrent first
    callers cannot each load the model. The lock is an RLock, so re-acquiring
    it inside ``_get_session`` is safe.
    """
    global _layout_sess
    if _layout_sess is None:
        with _runtime_lock:
            if _layout_sess is None:
                _layout_sess = _get_session(LAYOUT_MODEL_PATH)
    return _layout_sess
|
|
|
|
def _get_char_list() -> List[str]:
    """Return the recognition vocabulary, reading DICT_PATH on first use.

    Index 0 is the placeholder ``"blank"`` (token 0 is discarded during
    decoding). It is followed by one entry per dictionary line, and a literal
    space is appended as the last class.
    """
    global _char_list
    if _char_list is not None:
        return _char_list
    vocab = DICT_PATH.read_text(encoding="utf-8").splitlines()
    _char_list = ["blank", *vocab, " "]
    return _char_list
|
|
|
|
def _get_pyclipper():
    """Return the ``pyclipper`` module, or ``False`` if it is not installed.

    The import is attempted once; both outcomes are cached in ``_pyclipper``.
    """
    global _pyclipper
    if _pyclipper is not None:
        return _pyclipper
    try:
        import pyclipper as clipper_module
    except ImportError:
        clipper_module = False
    _pyclipper = clipper_module
    return _pyclipper
|
|
|
|
@dataclass
class TextBlock:
    """One recognised text item: a single detected line, or a merged layout region."""

    text: str  # recognised string
    score: float  # recognition confidence (mean over lines for merged regions)
    bbox: List[int]  # [x1, y1, x2, y2] in image pixels
    region_type: str = field(default="unknown")  # coarse layout type, e.g. "title", "text"
|
|
|
|
class OCRTextState(str, Enum):
    """Overall verdict on the text found in an image (string-valued enum)."""

    OK = "ok"  # usable text found
    LOW_QUALITY = "low_quality"  # text found but flagged as low confidence
    MISSING_EXPECTED_TEXT = "missing_expected_text"  # no text where the scene implies some
    MISSING_UNCERTAIN = "missing_uncertain"  # no text, and the scene does not imply any
|
|
|
|
@dataclass
class OCRResult:
    """Output of ``extract``: recognised text blocks plus cropped figure regions."""

    blocks: List[TextBlock] = field(default_factory=list)  # text in reading order
    avg_score: float = field(default=0.0)  # mean score over ``blocks``
    low_confidence: bool = field(default=False)
    text_state: OCRTextState = field(default=OCRTextState.OK)
    figure_images: List[Image.Image] = field(default_factory=list)  # crops of "figure" regions
    figure_bboxes: List[List[int]] = field(default_factory=list)  # matching [x1, y1, x2, y2] boxes
|
|
|
|
def _resize_no_pad(arr_rgb: np.ndarray, target_h: int, target_w: int, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Resize an RGB image to exactly (target_h, target_w) and return a 1xCxHxW float32 batch.

    The resize uses PIL bicubic interpolation and does not preserve aspect
    ratio or pad.

    NOTE(review): ``mean`` and ``std`` are accepted but NOT applied, so pixel
    values keep their original range. Presumably the normalisation is compiled
    into the axmodel -- TODO confirm before relying on these arguments.
    """
    pil_image = Image.fromarray(arr_rgb).resize((target_w, target_h), Image.BICUBIC)
    hwc = np.array(pil_image).astype(np.float32)
    chw = hwc.transpose(2, 0, 1)
    return chw[np.newaxis]
|
|
|
|
def _get_mini_boxes(contour: np.ndarray) -> Tuple[List[np.ndarray], float]:
    """Fit a minimum-area rectangle to ``contour``.

    Returns the rectangle's corners ordered top-left, top-right, bottom-right,
    bottom-left, together with the length of its shorter side. Corners are
    split into a left and a right pair by x, then each pair is ordered by y.
    """
    rect = cv2.minAreaRect(contour)
    by_x = sorted(cv2.boxPoints(rect), key=lambda corner: corner[0])
    left_pair, right_pair = by_x[:2], by_x[2:]

    if left_pair[1][1] > left_pair[0][1]:
        top_left, bottom_left = left_pair[0], left_pair[1]
    else:
        top_left, bottom_left = left_pair[1], left_pair[0]

    if right_pair[1][1] > right_pair[0][1]:
        top_right, bottom_right = right_pair[0], right_pair[1]
    else:
        top_right, bottom_right = right_pair[1], right_pair[0]

    return [top_left, top_right, bottom_right, bottom_left], min(rect[1])
|
|
|
|
def _box_score_fast(bitmap: np.ndarray, box: np.ndarray) -> float:
    """Return the mean of ``bitmap`` inside the polygon ``box`` (N x 2, x/y in bitmap pixels).

    Only the polygon's axis-aligned bounding window, clamped to the bitmap, is
    rasterised, which keeps the mask small.
    """
    rows, cols = bitmap.shape[:2]
    local = box.copy()
    x_lo = np.clip(np.floor(local[:, 0].min()).astype("int32"), 0, cols - 1)
    x_hi = np.clip(np.ceil(local[:, 0].max()).astype("int32"), 0, cols - 1)
    y_lo = np.clip(np.floor(local[:, 1].min()).astype("int32"), 0, rows - 1)
    y_hi = np.clip(np.ceil(local[:, 1].max()).astype("int32"), 0, rows - 1)

    # Shift the polygon into the window's coordinate frame before rasterising.
    local[:, 0] = local[:, 0] - x_lo
    local[:, 1] = local[:, 1] - y_lo
    window_mask = np.zeros((y_hi - y_lo + 1, x_hi - x_lo + 1), dtype=np.uint8)
    cv2.fillPoly(window_mask, local.reshape(1, -1, 2).astype("int32"), 1)
    window = bitmap[y_lo:y_hi + 1, x_lo:x_hi + 1]
    return float(cv2.mean(window, window_mask)[0])
|
|
|
|
def _unclip(box: np.ndarray, unclip_ratio: float) -> np.ndarray:
    """Grow polygon ``box`` outward by ``area * unclip_ratio / perimeter`` using pyclipper.

    The box is returned unchanged (wrapped as a float32 array of one polygon) in
    three cases: it is degenerate (area or perimeter <= 1e-6), pyclipper is
    unavailable, or the offset yields nothing.
    """
    as_f32 = box.astype(np.float32)
    area = float(abs(cv2.contourArea(as_f32)))
    perimeter = float(cv2.arcLength(as_f32, True))
    degenerate = area <= 1e-6 or perimeter <= 1e-6
    clipper = False if degenerate else _get_pyclipper()
    if not clipper:
        return np.array([box], dtype=np.float32)

    offsetter = clipper.PyclipperOffset()
    offsetter.AddPath(box, clipper.JT_ROUND, clipper.ET_CLOSEDPOLYGON)
    grown = offsetter.Execute(area * unclip_ratio / perimeter)
    return np.array(grown) if grown else np.array([box], dtype=np.float32)
|
|
|
|
def _boxes_from_bitmap(
    pred: np.ndarray,
    bitmap: np.ndarray,
    dest_width: int,
    dest_height: int,
    *,
    box_thresh: float = 0.6,
    max_candidates: int = 1000,
    unclip_ratio: float = 1.5,
    min_size: int = 3,
) -> Tuple[np.ndarray, List[float]]:
    """Extract text quads from a binarised detector probability map.

    Args:
        pred: probability map (H x W), used to score candidate boxes.
        bitmap: boolean mask of the same shape (``pred > threshold``).
        dest_width, dest_height: size of the original image; quads are rescaled
            from bitmap coordinates to this size and clipped to it.
        box_thresh: minimum mean probability inside a candidate box.
        max_candidates: at most this many contours are examined.
        unclip_ratio: expansion factor passed to ``_unclip``.
        min_size: minimum short side of the raw box; the expanded box must
            exceed ``min_size + 2``.

    Returns:
        ``(boxes, scores)``: boxes is an int32 array of shape (N, 4, 2), or
        (0, 4, 2) if nothing survives, and scores is one float per box.
    """
    height, width = bitmap.shape
    outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    # OpenCV 3.x returns (image, contours, hierarchy); 4.x returns (contours, hierarchy).
    contours = outs[-2]

    boxes: List[np.ndarray] = []
    scores: List[float] = []
    for contour in contours[:max_candidates]:
        points, sside = _get_mini_boxes(contour)
        if sside < min_size:
            continue
        points = np.array(points)
        score = _box_score_fast(pred, points.reshape(-1, 2))
        if score < box_thresh:
            continue

        # The probability map shrinks text regions, so grow the box back out.
        box = _unclip(points, unclip_ratio).reshape(-1, 1, 2)
        box, sside = _get_mini_boxes(box)
        if sside < min_size + 2:
            continue
        box = np.array(box)
        box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
        box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
        boxes.append(box.astype("int32"))
        scores.append(float(score))
    if not boxes:
        return np.zeros((0, 4, 2), dtype="int32"), scores
    return np.array(boxes, dtype="int32"), scores
|
|
|
|
| def _order_points_clockwise(points: np.ndarray) -> np.ndarray: |
| rect = np.zeros((4, 2), dtype="float32") |
| point_sums = points.sum(axis=1) |
| rect[0] = points[np.argmin(point_sums)] |
| rect[2] = points[np.argmax(point_sums)] |
| remaining = np.delete(points, (np.argmin(point_sums), np.argmax(point_sums)), axis=0) |
| diffs = np.diff(np.array(remaining), axis=1) |
| rect[1] = remaining[np.argmin(diffs)] |
| rect[3] = remaining[np.argmax(diffs)] |
| return rect |
|
|
|
|
| def _clip_det_res(points: np.ndarray, img_height: int, img_width: int) -> np.ndarray: |
| for index in range(points.shape[0]): |
| points[index, 0] = int(min(max(points[index, 0], 0), img_width - 1)) |
| points[index, 1] = int(min(max(points[index, 1], 0), img_height - 1)) |
| return points |
|
|
|
|
def _filter_tag_det_res(dt_boxes: np.ndarray, image_shape: Sequence[int]) -> np.ndarray:
    """Normalise and filter detected quads.

    Each quad has its corners reordered (TL, TR, BR, BL) and is clipped into
    the image. Quads whose integer top-edge or left-edge length is 3 px or less
    are dropped. Returns an array of shape (N, 4, 2), or an empty float32
    (0, 4, 2) array.
    """
    img_height, img_width = image_shape[0:2]
    kept = []
    for raw in dt_boxes:
        quad = np.array(raw) if isinstance(raw, list) else raw.copy()
        quad = _clip_det_res(_order_points_clockwise(quad), img_height, img_width)
        side_w = int(np.linalg.norm(quad[0] - quad[1]))
        side_h = int(np.linalg.norm(quad[0] - quad[3]))
        if side_w > 3 and side_h > 3:
            kept.append(quad)
    return np.array(kept) if kept else np.zeros((0, 4, 2), dtype=np.float32)
|
|
|
|
| def _sorted_boxes(dt_boxes: np.ndarray) -> List[np.ndarray]: |
| num_boxes = dt_boxes.shape[0] |
| ordered = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0])) |
| boxes = list(ordered) |
| for i in range(num_boxes - 1): |
| for j in range(i, -1, -1): |
| if abs(boxes[j + 1][0][1] - boxes[j][0][1]) < 10 and boxes[j + 1][0][0] < boxes[j][0][0]: |
| boxes[j], boxes[j + 1] = boxes[j + 1], boxes[j] |
| else: |
| break |
| return boxes |
|
|
|
|
def _det_postprocess(outs_dict: dict, shape_list: np.ndarray) -> List[dict]:
    """Convert detector probability maps into per-image quad lists.

    ``outs_dict["maps"]`` is indexed as (batch, channel, H, W) and channel 0 is
    used. ``shape_list[i][0:2]`` gives the source (height, width) that image
    i's quads are rescaled to. Returns one ``{"points": quads}`` dict per batch
    item.
    """
    prob_maps = outs_dict["maps"][:, 0, :, :]
    text_masks = prob_maps > DET_BIN_THRESHOLD
    per_image = []
    for b in range(prob_maps.shape[0]):
        src_h = int(shape_list[b][0])
        src_w = int(shape_list[b][1])
        quads, _scores = _boxes_from_bitmap(prob_maps[b], text_masks[b], src_w, src_h)
        per_image.append({"points": quads})
    return per_image
|
|
|
|
def _detect_boxes(arr_bgr: np.ndarray) -> List[np.ndarray]:
    """Detect text quads in a BGR image and return them in reading order.

    The image is resized to DET_INPUT_SIZE x DET_INPUT_SIZE (aspect not kept)
    and fed as a float32 1xCxHxW tensor. Quads come back in original-image
    coordinates. Returns ``[]`` when nothing is detected.
    """
    orig_h, orig_w = arr_bgr.shape[:2]
    resized = cv2.resize(arr_bgr, (DET_INPUT_SIZE, DET_INPUT_SIZE))
    batch = resized.astype(np.float32).transpose(2, 0, 1)[np.newaxis]
    ratio_h = DET_INPUT_SIZE / max(orig_h, 1)
    ratio_w = DET_INPUT_SIZE / max(orig_w, 1)
    shape_list = np.array([[orig_h, orig_w, ratio_h, ratio_w]], dtype=np.float32)

    sess = _get_det_sess()
    input_name = sess.get_inputs()[0].name
    prob_map = sess.run(None, {input_name: batch})[0]
    quads = _det_postprocess({"maps": prob_map}, shape_list)[0]["points"]
    dt_boxes = _filter_tag_det_res(quads, arr_bgr.shape)
    return _sorted_boxes(dt_boxes) if dt_boxes.size else []
|
|
|
|
| def _get_rotate_crop_image(img: np.ndarray, points: np.ndarray) -> np.ndarray: |
| if len(points) != 4: |
| raise ValueError("shape of points must be 4*2") |
| img_crop_width = int(max(np.linalg.norm(points[0] - points[1]), np.linalg.norm(points[2] - points[3]))) |
| img_crop_height = int(max(np.linalg.norm(points[0] - points[3]), np.linalg.norm(points[1] - points[2]))) |
| pts_std = np.float32( |
| [ |
| [0, 0], |
| [img_crop_width, 0], |
| [img_crop_width, img_crop_height], |
| [0, img_crop_height], |
| ] |
| ) |
| transform = cv2.getPerspectiveTransform(points.astype(np.float32), pts_std) |
| dst_img = cv2.warpPerspective( |
| img, |
| transform, |
| (img_crop_width, img_crop_height), |
| borderMode=cv2.BORDER_REPLICATE, |
| flags=cv2.INTER_CUBIC, |
| ) |
| dst_img_height, dst_img_width = dst_img.shape[:2] |
| if dst_img_width == 0 or dst_img_height == 0: |
| return dst_img |
| if dst_img_height * 1.0 / dst_img_width >= 1.5: |
| dst_img = np.rot90(dst_img) |
| return dst_img |
|
|
|
|
def _resize_norm_img(img: np.ndarray, shape: Tuple[int, int, int]) -> np.ndarray:
    """Aspect-resize ``img`` to ``shape = (C, H, W)`` and right-pad with zeros.

    The image is scaled to height H, with width ceil(H * aspect) capped at W.
    The result is a float32 (C, H, W) array. No mean/std normalisation is
    applied here.

    NOTE(review): the padding value is raw 0. If normalisation is compiled into
    the model, the pad region will not be 0 after normalisation -- confirm this
    matches the model's training preprocessing.
    """
    channels, out_h, out_w = shape
    src_h, src_w = img.shape[:2]
    aspect = src_w / float(max(src_h, 1))
    scaled_w = math.ceil(out_h * aspect)
    new_w = out_w if scaled_w > out_w else int(scaled_w)

    scaled = cv2.resize(img, (new_w, out_h)).astype(np.float32).transpose((2, 0, 1))
    canvas = np.zeros((channels, out_h, out_w), dtype=np.float32)
    canvas[:, :, :new_w] = scaled
    return canvas
|
|
|
|
def _classify_orientation(crop_bgr: np.ndarray) -> np.ndarray:
    """Return ``crop_bgr`` rotated 180 degrees if the direction classifier says so.

    The crop is flipped only when class 1 wins with probability above
    CLS_ROTATE_THRESHOLD. Empty crops are returned unchanged.
    """
    if not crop_bgr.size:
        return crop_bgr
    batch = _resize_norm_img(crop_bgr, (3, *CLS_INPUT_SIZE))[np.newaxis].copy()
    sess = _get_cls_sess()
    scores = sess.run(None, {sess.get_inputs()[0].name: batch})[0][0]
    label = int(np.argmax(scores))
    confidence = float(scores[label])
    flip = label == 1 and confidence > CLS_ROTATE_THRESHOLD
    return cv2.rotate(crop_bgr, cv2.ROTATE_180) if flip else crop_bgr
|
|
|
|
def _preprocess_rec(crop_bgr: np.ndarray) -> np.ndarray:
    """Build the 1x3xHxW float32 recognizer input for a crop (all zeros if the crop is empty)."""
    rec_h, rec_w = REC_INPUT_SIZE
    if crop_bgr.size:
        return _resize_norm_img(crop_bgr, (3, rec_h, rec_w))[np.newaxis].copy()
    return np.zeros((1, 3, rec_h, rec_w), dtype=np.float32)
|
|
|
|
| def _decode_text_indices( |
| chars: List[str], |
| text_index: np.ndarray, |
| text_prob: np.ndarray, |
| is_remove_duplicate: bool = False, |
| ) -> List[Tuple[str, float]]: |
| result_list: List[Tuple[str, float]] = [] |
| ignored_tokens = [0] |
| batch_size = len(text_index) |
| for batch_idx in range(batch_size): |
| selection = np.ones(len(text_index[batch_idx]), dtype=bool) |
| if is_remove_duplicate: |
| selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1] |
| for ignored_token in ignored_tokens: |
| selection &= text_index[batch_idx] != ignored_token |
| char_list = [chars[text_id] for text_id in text_index[batch_idx][selection] if text_id < len(chars)] |
| conf_list = text_prob[batch_idx][selection] |
| if len(conf_list) == 0: |
| conf_list = np.array([0.0], dtype=np.float32) |
| result_list.append(("".join(char_list), float(np.mean(conf_list).tolist()))) |
| return result_list |
|
|
|
|
def _decode_rec(logits: np.ndarray) -> Tuple[str, float]:
    """Decode recognizer output into ``(text, confidence)`` for the first batch item.

    If ``logits`` is a tuple/list, its last element is used. Class ids and
    probabilities are taken as argmax/max over axis 2, then decoded with
    consecutive duplicates collapsed.
    """
    vocab = _get_char_list()
    if isinstance(logits, (tuple, list)):
        logits = logits[-1]
    scores = logits.astype(np.float32)
    best_ids = scores.argmax(axis=2)
    best_probs = scores.max(axis=2)
    decoded = _decode_text_indices(vocab, best_ids, best_probs, is_remove_duplicate=True)
    return decoded[0]
|
|
|
|
def _recognize_crop(crop_bgr: np.ndarray) -> Tuple[str, float]:
    """Orientation-correct a BGR text crop, run the recognizer, and decode ``(text, confidence)``."""
    upright = _classify_orientation(crop_bgr)
    batch = _preprocess_rec(upright)
    session = _get_rec_sess()
    input_name = session.get_inputs()[0].name
    raw_logits = session.run(None, {input_name: batch})[0]
    return _decode_rec(raw_logits)
|
|
|
|
def _enhance_crop(crop_bgr: np.ndarray) -> np.ndarray:
    """Return a copy of a BGR crop with contrast boosted 1.5x (PIL ImageEnhance), still in BGR order."""
    as_rgb = crop_bgr[:, :, ::-1]
    boosted = ImageEnhance.Contrast(Image.fromarray(as_rgb)).enhance(1.5)
    return np.array(boosted)[:, :, ::-1]
|
|
|
|
| def _quad_to_bbox(points: np.ndarray, width: int, height: int) -> List[int]: |
| min_x = max(0, int(np.floor(np.min(points[:, 0])))) |
| min_y = max(0, int(np.floor(np.min(points[:, 1])))) |
| max_x = min(width, int(np.ceil(np.max(points[:, 0])))) |
| max_y = min(height, int(np.ceil(np.max(points[:, 1])))) |
| return [min_x, min_y, max_x, max_y] |
|
|
|
|
| def _iou(a: List[int], b: List[int]) -> float: |
| ax1, ay1, ax2, ay2 = a |
| bx1, by1, bx2, by2 = b |
| ix1, iy1 = max(ax1, bx1), max(ay1, by1) |
| ix2, iy2 = min(ax2, bx2), min(ay2, by2) |
| inter = max(0, ix2 - ix1) * max(0, iy2 - iy1) |
| if inter == 0: |
| return 0.0 |
| area_a = (ax2 - ax1) * (ay2 - ay1) |
| area_b = (bx2 - bx1) * (by2 - by1) |
| return inter / min(area_a, area_b) |
|
|
|
|
| def _clip_bbox(bbox: Sequence[float], width: int, height: int) -> List[int]: |
| x1, y1, x2, y2 = bbox |
| clipped = [ |
| int(round(max(0.0, min(float(width), x1)))), |
| int(round(max(0.0, min(float(height), y1)))), |
| int(round(max(0.0, min(float(width), x2)))), |
| int(round(max(0.0, min(float(height), y2)))), |
| ] |
| return clipped |
|
|
|
|
def _dedupe_layout_regions(regions: List[dict]) -> List[dict]:
    """Greedy duplicate suppression over layout regions, highest score first.

    A region is dropped if an already-kept region has the same ``type`` and
    overlaps it by at least LAYOUT_DEDUP_IOU (per ``_iou``). For non-figure
    types, the kept region must also share the same ``raw_type``. Survivors
    are returned sorted by (y1, x1).
    """

    def _duplicates(candidate: dict, existing: dict) -> bool:
        if existing.get("type") != candidate.get("type"):
            return False
        if candidate.get("type") != "figure" and existing.get("raw_type") != candidate.get("raw_type"):
            return False
        return _iou(existing["bbox"], candidate["bbox"]) >= LAYOUT_DEDUP_IOU

    kept: List[dict] = []
    for candidate in sorted(regions, key=lambda r: r["score"], reverse=True):
        if not any(_duplicates(candidate, prev) for prev in kept):
            kept.append(candidate)
    kept.sort(key=lambda r: (r["bbox"][1], r["bbox"][0]))
    return kept
|
|
|
|
def _run_ocr_full(arr_rgb: np.ndarray, threshold: float) -> List[TextBlock]:
    """Detect and recognise every text line in an RGB image.

    Lines scoring below ``threshold`` are re-recognised on a contrast-enhanced
    crop, and the better of the two results is kept. Returns one TextBlock per
    non-empty crop, in ``_detect_boxes`` order.
    """
    arr_bgr = arr_rgb[:, :, ::-1]
    height, width = arr_rgb.shape[:2]
    results: List[TextBlock] = []
    for quad in _detect_boxes(arr_bgr):
        crop = _get_rotate_crop_image(arr_bgr, np.asarray(quad, dtype=np.float32))
        if crop.size == 0:
            continue
        text, score = _recognize_crop(crop)
        if score < threshold:
            alt_text, alt_score = _recognize_crop(_enhance_crop(crop))
            if alt_score > score:
                text, score = alt_text, alt_score
        box = _quad_to_bbox(np.asarray(quad, dtype=np.float32), width, height)
        results.append(TextBlock(text=text, score=score, bbox=box))
    return results
|
|
|
|
def _assign_region(block: TextBlock, regions: list) -> tuple:
    """Return ``(region_type, region_index)`` for the region that best covers ``block``.

    Overlap is measured with ``_iou`` (intersection / smaller area); regions
    without a ``bbox`` are skipped. The type is "unknown" unless the best
    overlap exceeds 0.1. The index is ``len(regions)`` when no region overlaps
    at all.
    """
    best_type, best_overlap, best_idx = "unknown", 0.0, len(regions)
    for idx, region in enumerate(regions):
        region_bbox = region.get("bbox")
        if region_bbox is None:
            continue
        overlap = _iou(block.bbox, region_bbox)
        if overlap > best_overlap:
            best_type, best_overlap, best_idx = region.get("type", "unknown"), overlap, idx
    if best_overlap > 0.1:
        return best_type, best_idx
    return "unknown", best_idx
|
|
|
|
def _run_structure(arr_rgb: np.ndarray) -> List[dict]:
    """Run the layout model on an RGB image and return deduplicated layout regions.

    Each region is a dict with keys ``type`` (coarse type from LAYOUT_TYPE_MAP),
    ``raw_type`` (model label), ``score`` and ``bbox`` ([x1, y1, x2, y2] clipped
    to the image). ``figure`` regions also carry ``img``, a slice of ``arr_rgb``
    covering the box. Regions are ordered by ``_dedupe_layout_regions``
    (top-to-bottom, then left-to-right).

    Two exported model layouts are supported:
      * >= 2 inputs: image + scale_factor; outputs ``(dets, num)`` with rows
        indexed as [cls_id, score, x1, y1, x2, y2], whose boxes are used as-is
        (treated as original-image coordinates).
      * 1 input ("sim-cut"): outputs boxes (N, 4) in LAYOUT_INPUT_SIZE
        coordinates plus class scores shaped (1, C, N) or (1, N, C).

    Raises:
        RuntimeError: if the single-input model's outputs do not match the
            expected count or shapes.
    """
    sess = _get_layout_sess()
    h, w = arr_rgb.shape[:2]
    # Stretched to LAYOUT_INPUT_SIZE; mean/std are not applied by _resize_no_pad.
    image_input = _resize_no_pad(arr_rgb, LAYOUT_INPUT_SIZE[0], LAYOUT_INPUT_SIZE[1], LAYOUT_MEAN, LAYOUT_STD)


    def _append_region(regions: List[dict], cls_id: int, score: float, bbox_values: Sequence[float], scale_to_image: bool):
        # Validate one detection and append it to ``regions``. It is dropped if:
        # score < LAYOUT_SCORE_THRESHOLD, label not in LAYOUT_ENABLED_LABELS, or
        # the clipped box is under 4 px wide/tall. With scale_to_image, coordinates
        # are in LAYOUT_INPUT_SIZE space and are scaled up to the image size.
        if score < LAYOUT_SCORE_THRESHOLD:
            return
        raw_label = LAYOUT_LABELS[cls_id] if 0 <= cls_id < len(LAYOUT_LABELS) else str(cls_id)
        if raw_label not in LAYOUT_ENABLED_LABELS:
            return
        region_type = LAYOUT_TYPE_MAP.get(raw_label, "unknown")
        x1, y1, x2, y2 = map(float, bbox_values)
        # Normalise corner order in case the model emits swapped coordinates.
        if x2 < x1:
            x1, x2 = x2, x1
        if y2 < y1:
            y1, y2 = y2, y1
        if scale_to_image:
            x_scale = w / float(LAYOUT_INPUT_SIZE[1])
            y_scale = h / float(LAYOUT_INPUT_SIZE[0])
            bbox = _clip_bbox([x1 * x_scale, y1 * y_scale, x2 * x_scale, y2 * y_scale], w, h)
        else:
            bbox = _clip_bbox([x1, y1, x2, y2], w, h)
        if bbox[2] - bbox[0] < 4 or bbox[3] - bbox[1] < 4:
            return
        region = {"type": region_type, "raw_type": raw_label, "score": score, "bbox": bbox}
        if region_type == "figure":
            x1c, y1c, x2c, y2c = bbox
            # Slice (not a copy) of the source image.
            crop = arr_rgb[max(0, y1c):max(0, y2c), max(0, x1c):max(0, x2c)]
            if crop.size > 0:
                region["img"] = crop
        regions.append(region)


    regions = []
    inputs = sess.get_inputs()
    if len(inputs) >= 2:
        # Variant 1: model takes (image, scale_factor) and returns (dets, num).
        scale_factor = np.array([[LAYOUT_INPUT_SIZE[0] / max(h, 1), LAYOUT_INPUT_SIZE[1] / max(w, 1)]], dtype=np.float32)
        outputs = sess.run(None, {inputs[0].name: image_input, inputs[1].name: scale_factor})
        dets, num = outputs
        num = int(num[0])
        for row in dets[:num]:
            _append_region(regions, int(row[0]), float(row[1]), row[2:].tolist(), scale_to_image=False)
    else:
        # Variant 2 ("sim-cut"): image only; raw boxes + per-class scores.
        outputs = sess.run(None, {inputs[0].name: image_input})
        if len(outputs) != 2:
            raise RuntimeError(f"Unexpected layout outputs: {len(outputs)}")
        out_names = [o.name for o in sess.get_outputs()]
        out_map = {name: value for name, value in zip(out_names, outputs)}

        # Prefer the known exported output names; otherwise pick the
        # (1, N, 4) tensor as boxes and the other as class scores.
        if "Mul.109" in out_map and "Concat.9" in out_map:
            boxes = out_map["Mul.109"][0]
            cls_blob = out_map["Concat.9"]
        else:
            first, second = outputs
            if first.ndim == 3 and first.shape[-1] == 4:
                boxes, cls_blob = first[0], second
            elif second.ndim == 3 and second.shape[-1] == 4:
                boxes, cls_blob = second[0], first
            else:
                raise RuntimeError(f"Unexpected sim-cut output shapes: {first.shape}, {second.shape}")
        if cls_blob.ndim != 3 or cls_blob.shape[0] != 1:
            raise RuntimeError(f"Unexpected cls output shape: {cls_blob.shape}")
        # Bring class scores to (N, C) regardless of whether the model emits (1, C, N) or (1, N, C).
        if cls_blob.shape[1] == len(LAYOUT_LABELS) and cls_blob.shape[2] == boxes.shape[0]:
            cls_scores = cls_blob[0].transpose(1, 0)
        elif cls_blob.shape[2] == len(LAYOUT_LABELS) and cls_blob.shape[1] == boxes.shape[0]:
            cls_scores = cls_blob[0]
        else:
            raise RuntimeError(f"Unexpected cls/box alignment: cls={cls_blob.shape}, boxes={boxes.shape}")

        # One candidate per box: its best class and that class's score.
        best_cls = np.argmax(cls_scores, axis=1)
        best_score = np.max(cls_scores, axis=1)
        for i in range(boxes.shape[0]):
            _append_region(
                regions,
                int(best_cls[i]),
                float(best_score[i]),
                boxes[i].tolist(),
                scale_to_image=True,
            )
    return _dedupe_layout_regions(regions)
|
|
|
|
def _reading_key(block: TextBlock) -> Tuple[int, int]:
    """Reading-order sort key: 20-px row bucket of the box's vertical centre, then its left edge."""
    # (y1 + y2) / 40 == centre_y / 20
    return (int((block.bbox[1] + block.bbox[3]) / 40), block.bbox[0])


def _merge_region_blocks(
    blocks: List[TextBlock], regions: List[dict]
) -> Tuple[List[TextBlock], List[TextBlock]]:
    """Group OCR lines by the layout region ``_assign_region`` maps them to.

    Returns ``(merged, unassigned)``. ``merged`` has one TextBlock per region
    that received at least one line, in region-index order. Its text is the
    lines joined with spaces in reading order, its score the mean line score,
    and its bbox the union of the line boxes. ``unassigned`` holds the lines
    mapped to "unknown".
    """
    groups: dict = {}
    unassigned: List[TextBlock] = []
    for block in blocks:
        rtype, ridx = _assign_region(block, regions)
        if rtype == "unknown":
            unassigned.append(block)
        else:
            # The region type recorded is the one seen for the first line.
            groups.setdefault(ridx, (rtype, []))[1].append(block)

    merged: List[TextBlock] = []
    for ridx in sorted(groups):
        rtype, members = groups[ridx]
        members.sort(key=_reading_key)
        merged.append(
            TextBlock(
                text=" ".join(b.text for b in members if b.text),
                score=sum(b.score for b in members) / len(members),
                bbox=[
                    min(b.bbox[0] for b in members),
                    min(b.bbox[1] for b in members),
                    max(b.bbox[2] for b in members),
                    max(b.bbox[3] for b in members),
                ],
                region_type=rtype,
            )
        )
    return merged, unassigned


def _insert_unassigned(ordered: List[TextBlock], unassigned: List[TextBlock]) -> List[TextBlock]:
    """Insert unassigned lines into ``ordered`` by reading key.

    Lines are taken in reading-key order. Each one goes in front of the first
    entry whose key is >= its own, or at the end if there is none. A linear
    scan is used because ``ordered`` follows region order, so its keys are not
    guaranteed to be sorted and ``bisect`` would give arbitrary positions.
    """
    result = list(ordered)
    keys = [_reading_key(b) for b in result]
    for block in sorted(unassigned, key=_reading_key):
        key = _reading_key(block)
        pos = next((i for i, k in enumerate(keys) if k >= key), len(keys))
        result.insert(pos, block)
        keys.insert(pos, key)
    return result


def extract(img: Image.Image, scene_result: SceneResult, debug_prefix: Optional[str] = None) -> OCRResult:
    """Extract text (and, except for screenshots, figure crops) from ``img``.

    Args:
        img: input image; converted to RGB.
        scene_result: its ``scene`` selects the path (SCREENSHOT skips layout
            analysis). Its ``ocr_threshold`` is the score below which a line is
            retried on a contrast-enhanced crop.
        debug_prefix: currently unused; kept for API compatibility.

    Returns:
        OCRResult. Blocks whose stripped text is shorter than 2 characters are
        dropped. When none remain, ``text_state`` is MISSING_EXPECTED_TEXT for
        SCREENSHOT/DOCUMENT scenes and MISSING_UNCERTAIN otherwise.
        ``low_confidence`` is set when the mean score is below 0.65 AND fewer
        than 3 blocks remain.
    """
    arr_rgb = np.array(img.convert("RGB"))
    threshold = scene_result.ocr_threshold
    figure_images: List[Image.Image] = []
    figure_bboxes: List[List[int]] = []

    if scene_result.scene == Scene.SCREENSHOT:
        # No layout analysis: lines stay in _detect_boxes order.
        blocks = _run_ocr_full(arr_rgb, threshold)
    else:
        structure_result = _run_structure(arr_rgb)
        blocks = _run_ocr_full(arr_rgb, threshold)

        for region in structure_result:
            if region.get("type") != "figure":
                continue
            sub = region.get("img")
            bbox = region.get("bbox")
            if sub is not None and bbox is not None:
                figure_images.append(Image.fromarray(sub))
                figure_bboxes.append(bbox)

        merged, unassigned = _merge_region_blocks(blocks, structure_result)
        blocks = _insert_unassigned(merged, unassigned)

    blocks = [b for b in blocks if len(b.text.strip()) >= 2]
    if not blocks:
        text_state = (
            OCRTextState.MISSING_EXPECTED_TEXT
            if scene_result.scene in {Scene.SCREENSHOT, Scene.DOCUMENT}
            else OCRTextState.MISSING_UNCERTAIN
        )
        return OCRResult(
            blocks=[],
            avg_score=0.0,
            low_confidence=False,
            text_state=text_state,
            figure_images=figure_images,
            figure_bboxes=figure_bboxes,
        )

    avg_score = sum(b.score for b in blocks) / len(blocks)
    # A low mean is only flagged when there are also few blocks to average over.
    low_confidence = avg_score < 0.65 and len(blocks) < 3
    text_state = OCRTextState.LOW_QUALITY if low_confidence else OCRTextState.OK
    return OCRResult(
        blocks=blocks,
        avg_score=avg_score,
        low_confidence=low_confidence,
        text_state=text_state,
        figure_images=figure_images,
        figure_bboxes=figure_bboxes,
    )
|
|