"""OCR text extraction.

PP-OCRv5 text detection / orientation classification / recognition plus
PP-DocLayout-S ("sim-cut" export) layout analysis. Although the pipeline
follows the ONNX exports of these models, the sessions here load compiled
``.axmodel`` files through ``axengine`` with the ``AxEngineExecutionProvider``.
"""

from __future__ import annotations

import math
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from threading import RLock
from typing import List, Optional, Sequence, Tuple

import cv2
import numpy as np
from PIL import Image, ImageEnhance

from ..preprocess.scene_classifier import Scene, SceneResult

PERCEPTION_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = PERCEPTION_DIR.parents[1]
AXMODEL_DIR = PROJECT_ROOT / "axmodel"
DET_MODEL_PATH = AXMODEL_DIR / "ppocrv5" / "det_npu1.axmodel"
CLS_MODEL_PATH = AXMODEL_DIR / "ppocrv5" / "cls_npu1.axmodel"
REC_MODEL_PATH = AXMODEL_DIR / "ppocrv5" / "rec_npu1.axmodel"
DICT_PATH = PERCEPTION_DIR / "dict" / "ppocrv5_dict.txt"
LAYOUT_MODEL_PATH = AXMODEL_DIR / "ppstructurev3" / "ppstructure_npu1.axmodel"

DET_INPUT_SIZE = 960
CLS_INPUT_SIZE = (80, 160)  # h, w
REC_INPUT_SIZE = (48, 320)  # h, w
LAYOUT_INPUT_SIZE = (480, 480)  # h, w

# NOTE(review): none of the preprocessing helpers in this module apply these
# mean/std statistics; presumably normalisation is compiled into the axmodels.
# Confirm before removing them.
DET_MEAN = np.array([123.675, 116.28, 103.53], dtype=np.float32)
DET_STD = np.array([58.395, 57.12, 57.375], dtype=np.float32)
CLS_MEAN = np.array([127.5, 127.5, 127.5], dtype=np.float32)
CLS_STD = np.array([127.5, 127.5, 127.5], dtype=np.float32)
REC_MEAN = np.array([127.5, 127.5, 127.5], dtype=np.float32)
REC_STD = np.array([127.5, 127.5, 127.5], dtype=np.float32)
LAYOUT_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
LAYOUT_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

DET_BIN_THRESHOLD = 0.30
DET_MIN_COMPONENT_AREA = 16
DET_SCORE_THRESHOLD = 0.35
CLS_ROTATE_THRESHOLD = 0.90
LAYOUT_SCORE_THRESHOLD = 0.35
LAYOUT_DEDUP_IOU = 0.85

# Class labels of the layout model, indexed by predicted class id.
LAYOUT_LABELS = [
    "paragraph_title",
    "image",
    "text",
    "number",
    "abstract",
    "content",
    "figure_title",
    "formula",
    "table",
    "table_title",
    "reference",
    "doc_title",
    "footnote",
    "header",
    "algorithm",
    "footer",
    "seal",
    "chart_title",
    "chart",
    "formula_number",
    "header_image",
    "footer_image",
    "aside_text",
]

# Raw layout label -> coarse region type used by the rest of the pipeline.
LAYOUT_TYPE_MAP = {
    "paragraph_title": "title",
    "doc_title": "title",
    "text": "text",
    "content": "text",
    "reference": "text",
    "abstract": "text",
    "number": "text",
    "aside_text": "text",
    "figure_title": "figure_caption",
    "table_title": "figure_caption",
    "chart_title": "figure_caption",
    "formula_number": "figure_caption",
    "image": "figure",
    "chart": "figure",
    "header_image": "figure",
    "footer_image": "figure",
    "table": "table",
    "formula": "figure",
    "header": "unknown",
    "footer": "unknown",
    "seal": "unknown",
    "algorithm": "text",
    "footnote": "text",
}

# Raw labels whose detections are kept; all others are dropped in _run_structure.
LAYOUT_ENABLED_LABELS = {
    "paragraph_title",
    "image",
    "text",
    "number",
    "abstract",
    "content",
    "figure_title",
    "formula",
    "table",
    "table_title",
    "reference",
    "doc_title",
    "footnote",
    "algorithm",
    "chart_title",
    "chart",
}

# Lazily-initialised module state. Session creation is serialised on
# _runtime_lock (reentrant, so getters may hold it while calling _get_session).
_ort = None
_runtime_lock = RLock()
_det_sess = None
_cls_sess = None
_rec_sess = None
_layout_sess = None
_char_list = None
_pyclipper = None


def _get_ort():
    """Import and cache the ``axengine`` runtime module.

    Raises:
        RuntimeError: if ``axengine`` cannot be imported.
    """
    global _ort
    if _ort is None:
        try:
            import axengine as ort  # type: ignore
        except ImportError as exc:  # pragma: no cover - environment issue
            raise RuntimeError(
                "axengine is required for PP-OCRv5 axmodel inference. "
                "Please install it in the runtime environment."
            ) from exc
        _ort = ort
    return _ort


def _get_session(path: Path):
    """Create a new inference session for the model file at ``path``."""
    with _runtime_lock:
        ort = _get_ort()
        return ort.InferenceSession(str(path), providers=["AxEngineExecutionProvider"])


def _get_det_sess():
    """Return the cached text-detection session, creating it on first use."""
    global _det_sess
    if _det_sess is None:
        with _runtime_lock:
            # Re-check under the lock so concurrent first calls build one session.
            if _det_sess is None:
                _det_sess = _get_session(DET_MODEL_PATH)
    return _det_sess


def _get_cls_sess():
    """Return the cached orientation-classifier session, creating it on first use."""
    global _cls_sess
    if _cls_sess is None:
        with _runtime_lock:
            if _cls_sess is None:
                _cls_sess = _get_session(CLS_MODEL_PATH)
    return _cls_sess


def _get_rec_sess():
    """Return the cached text-recognition session, creating it on first use."""
    global _rec_sess
    if _rec_sess is None:
        with _runtime_lock:
            if _rec_sess is None:
                _rec_sess = _get_session(REC_MODEL_PATH)
    return _rec_sess


def _get_layout_sess():
    """Return the cached layout-analysis session, creating it on first use."""
    global _layout_sess
    if _layout_sess is None:
        with _runtime_lock:
            if _layout_sess is None:
                _layout_sess = _get_session(LAYOUT_MODEL_PATH)
    return _layout_sess


def _get_char_list() -> List[str]:
    """Return the recognizer's class-index -> character table (cached)."""
    global _char_list
    if _char_list is None:
        chars = DICT_PATH.read_text(encoding="utf-8").splitlines()
        # PP-OCRv5 rec output class count = blank + dict + ascii space.
        _char_list = ["blank"] + chars + [" "]
    return _char_list


def _get_pyclipper():
    """Return the ``pyclipper`` module, or ``False`` if it is not installed (cached)."""
    global _pyclipper
    if _pyclipper is None:
        try:
            import pyclipper  # type: ignore
        except ImportError:
            _pyclipper = False
        else:
            _pyclipper = pyclipper
    return _pyclipper


@dataclass
class TextBlock:
    """One recognised piece of text (a line, or a merged layout region)."""

    text: str
    score: float
    bbox: List[int]  # [x1, y1, x2, y2]
    region_type: str = "unknown"  # a LAYOUT_TYPE_MAP value: text / title / figure_caption / figure / table / unknown


class OCRTextState(str, Enum):
    """Overall outcome of an OCR pass."""

    OK = "ok"
    LOW_QUALITY = "low_quality"
    MISSING_EXPECTED_TEXT = "missing_expected_text"
    MISSING_UNCERTAIN = "missing_uncertain"


@dataclass
class OCRResult:
    """Result of :func:`extract`."""

    blocks: List[TextBlock] = field(default_factory=list)
    avg_score: float = 0.0
    low_confidence: bool = False
    text_state: OCRTextState = OCRTextState.OK
    figure_images: List[Image.Image] = field(default_factory=list)
    figure_bboxes: List[List[int]] = field(default_factory=list)


def _resize_no_pad(arr_rgb: np.ndarray, target_h: int, target_w: int, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Bicubic-resize an image to (target_h, target_w) and return a 1xCxHxW float32 tensor.

    Pixel values are left in their original range. ``mean`` and ``std`` are
    accepted but deliberately NOT applied (the normalisation step was disabled
    in this port); presumably the compiled model normalises internally.
    """
    resized = Image.fromarray(arr_rgb).resize((target_w, target_h), Image.BICUBIC)
    return np.array(resized).astype(np.float32).transpose(2, 0, 1)[np.newaxis]


def _get_mini_boxes(contour: np.ndarray) -> Tuple[List[np.ndarray], float]:
    """Return the min-area rectangle of ``contour`` and its shorter side length.

    The four corners are ordered left-top, right-top, right-bottom, left-bottom
    (image coordinates, y pointing down).
    """
    bounding_box = cv2.minAreaRect(contour)
    points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
    # Of the two leftmost / two rightmost corners, the smaller y is the top one.
    if points[1][1] > points[0][1]:
        index_1, index_4 = 0, 1
    else:
        index_1, index_4 = 1, 0
    if points[3][1] > points[2][1]:
        index_2, index_3 = 2, 3
    else:
        index_2, index_3 = 3, 2
    box = [points[index_1], points[index_2], points[index_3], points[index_4]]
    return box, min(bounding_box[1])


def _box_score_fast(bitmap: np.ndarray, box: np.ndarray) -> float:
    """Mean of ``bitmap`` over the pixels inside polygon ``box`` (clipped to the map)."""
    h, w = bitmap.shape[:2]
    scored_box = box.copy()
    xmin = np.clip(np.floor(scored_box[:, 0].min()).astype("int32"), 0, w - 1)
    xmax = np.clip(np.ceil(scored_box[:, 0].max()).astype("int32"), 0, w - 1)
    ymin = np.clip(np.floor(scored_box[:, 1].min()).astype("int32"), 0, h - 1)
    ymax = np.clip(np.ceil(scored_box[:, 1].max()).astype("int32"), 0, h - 1)

    # Rasterise the polygon into a mask local to its bounding rectangle.
    mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
    scored_box[:, 0] = scored_box[:, 0] - xmin
    scored_box[:, 1] = scored_box[:, 1] - ymin
    cv2.fillPoly(mask, scored_box.reshape(1, -1, 2).astype("int32"), 1)
    return float(cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0])


def _unclip(box: np.ndarray, unclip_ratio: float) -> np.ndarray:
    """Expand polygon ``box`` outward by ``area * unclip_ratio / perimeter``.

    Returns an array holding exactly one polygon (shape ``(1, K, 2)``). Falls back
    to the unexpanded ``box`` when the polygon is degenerate, ``pyclipper`` is not
    installed, or the offset does not yield exactly one polygon.
    """
    area = float(abs(cv2.contourArea(box.astype(np.float32))))
    length = float(cv2.arcLength(box.astype(np.float32), True))
    if area <= 1e-6 or length <= 1e-6:
        return np.array([box], dtype=np.float32)
    pyclipper = _get_pyclipper()
    if not pyclipper:
        return np.array([box], dtype=np.float32)
    distance = area * unclip_ratio / length
    offset = pyclipper.PyclipperOffset()
    offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    expanded = offset.Execute(distance)
    # Several (ragged) polygons cannot be stacked into one array; keep the
    # original box instead of crashing, as for the empty result.
    if len(expanded) != 1:
        return np.array([box], dtype=np.float32)
    return np.array(expanded)


def _boxes_from_bitmap(
    pred: np.ndarray,
    bitmap: np.ndarray,
    dest_width: int,
    dest_height: int,
    *,
    box_thresh: float = 0.6,
    max_candidates: int = 1000,
    unclip_ratio: float = 1.5,
    min_size: int = 3,
) -> Tuple[np.ndarray, List[float]]:
    """DB-style postprocess: turn a binarised probability map into text quads.

    Args:
        pred: 2-D probability map used for box scoring.
        bitmap: boolean map of the same size (``pred`` thresholded).
        dest_width, dest_height: size of the image the boxes are rescaled to.
        box_thresh: minimum mean probability inside a candidate box.
        max_candidates: at most this many contours are examined.
        unclip_ratio: expansion factor passed to :func:`_unclip`.
        min_size: minimum short side before expansion (``min_size + 2`` after).

    Returns:
        ``(boxes, scores)`` where ``boxes`` is an int32 array of shape (M, 4, 2)
        in destination coordinates and ``scores`` the matching box scores.
    """
    height, width = bitmap.shape
    outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    # OpenCV 3.x returns (image, contours, hierarchy); 4.x returns (contours, hierarchy).
    contours = outs[1] if len(outs) == 3 else outs[0]

    boxes: List[np.ndarray] = []
    scores: List[float] = []
    for contour in contours[:max_candidates]:
        points, sside = _get_mini_boxes(contour)
        if sside < min_size:
            continue
        points = np.array(points)
        score = _box_score_fast(pred, points.reshape(-1, 2))
        if score < box_thresh:
            continue

        box = _unclip(points, unclip_ratio).reshape(-1, 1, 2)
        box, sside = _get_mini_boxes(box)
        if sside < min_size + 2:
            continue
        box = np.array(box)
        # Rescale from bitmap coordinates to the destination image.
        box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
        box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
        boxes.append(box.astype("int32"))
        scores.append(float(score))

    if not boxes:
        return np.zeros((0, 4, 2), dtype="int32"), scores
    return np.array(boxes, dtype="int32"), scores


def _order_points_clockwise(points: np.ndarray) -> np.ndarray:
    """Reorder 4 points as top-left, top-right, bottom-right, bottom-left (float32)."""
    rect = np.zeros((4, 2), dtype="float32")
    point_sums = points.sum(axis=1)
    # Smallest x+y is top-left, largest is bottom-right.
    rect[0] = points[np.argmin(point_sums)]
    rect[2] = points[np.argmax(point_sums)]
    # Of the remaining two, smaller y-x is top-right, larger is bottom-left.
    remaining = np.delete(points, (np.argmin(point_sums), np.argmax(point_sums)), axis=0)
    diffs = np.diff(np.array(remaining), axis=1)
    rect[1] = remaining[np.argmin(diffs)]
    rect[3] = remaining[np.argmax(diffs)]
    return rect


def _clip_det_res(points: np.ndarray, img_height: int, img_width: int) -> np.ndarray:
    """Clamp point coordinates in place to the image and truncate them to integers."""
    for index in range(points.shape[0]):
        points[index, 0] = int(min(max(points[index, 0], 0), img_width - 1))
        points[index, 1] = int(min(max(points[index, 1], 0), img_height - 1))
    return points


def _filter_tag_det_res(dt_boxes: np.ndarray, image_shape: Sequence[int]) -> np.ndarray:
    """Order, clip and drop tiny (side <= 3 px) detection quads.

    Returns a float32 array of shape (M, 4, 2); empty boxes give shape (0, 4, 2).
    """
    img_height, img_width = image_shape[0:2]
    filtered_boxes = []
    for box in dt_boxes:
        current_box = np.array(box) if isinstance(box, list) else box.copy()
        current_box = _order_points_clockwise(current_box)
        current_box = _clip_det_res(current_box, img_height, img_width)
        rect_width = int(np.linalg.norm(current_box[0] - current_box[1]))
        rect_height = int(np.linalg.norm(current_box[0] - current_box[3]))
        if rect_width <= 3 or rect_height <= 3:
            continue
        filtered_boxes.append(current_box)
    if not filtered_boxes:
        return np.zeros((0, 4, 2), dtype=np.float32)
    return np.array(filtered_boxes)


def _sorted_boxes(dt_boxes: np.ndarray) -> List[np.ndarray]:
    """Sort quads top-to-bottom, then left-to-right within a line.

    Boxes whose top-left y differ by less than 10 px are treated as the same
    line and ordered by x (insertion-style pass over the y-sorted list).
    """
    num_boxes = dt_boxes.shape[0]
    ordered = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
    boxes = list(ordered)
    for i in range(num_boxes - 1):
        for j in range(i, -1, -1):
            if abs(boxes[j + 1][0][1] - boxes[j][0][1]) < 10 and boxes[j + 1][0][0] < boxes[j][0][0]:
                boxes[j], boxes[j + 1] = boxes[j + 1], boxes[j]
            else:
                break
    return boxes


def _det_postprocess(outs_dict: dict, shape_list: np.ndarray) -> List[dict]:
    """Convert raw detector maps into per-image ``{"points": boxes}`` dicts.

    ``outs_dict["maps"]`` is indexed as (batch, channel, H, W) and channel 0 is
    used; ``shape_list[i][:2]`` holds the original (height, width) of image ``i``.
    """
    pred = outs_dict["maps"]
    pred = pred[:, 0, :, :]
    segmentation = pred > DET_BIN_THRESHOLD
    boxes_batch = []
    for batch_index in range(pred.shape[0]):
        src_h, src_w = int(shape_list[batch_index][0]), int(shape_list[batch_index][1])
        mask = segmentation[batch_index]
        boxes, _ = _boxes_from_bitmap(pred[batch_index], mask, src_w, src_h)
        boxes_batch.append({"points": boxes})
    return boxes_batch


def _detect_boxes(arr_bgr: np.ndarray) -> List[np.ndarray]:
    """Run text detection on a BGR image and return reading-ordered 4x2 quads.

    The image is stretched (no aspect preservation, no normalisation) to
    DET_INPUT_SIZE x DET_INPUT_SIZE; boxes are mapped back to original pixels.
    """
    orig_h, orig_w = arr_bgr.shape[:2]
    det_input = cv2.resize(arr_bgr, (DET_INPUT_SIZE, DET_INPUT_SIZE)).astype(np.float32)
    det_input = det_input.transpose(2, 0, 1)[np.newaxis]
    shape_list = np.array(
        [[orig_h, orig_w, DET_INPUT_SIZE / max(orig_h, 1), DET_INPUT_SIZE / max(orig_w, 1)]],
        dtype=np.float32,
    )
    sess = _get_det_sess()
    det_out = sess.run(None, {sess.get_inputs()[0].name: det_input})
    post_result = _det_postprocess({"maps": det_out[0]}, shape_list)
    dt_boxes = _filter_tag_det_res(post_result[0]["points"], arr_bgr.shape)
    if dt_boxes.size == 0:
        return []
    return _sorted_boxes(dt_boxes)


def _get_rotate_crop_image(img: np.ndarray, points: np.ndarray) -> np.ndarray:
    """Perspective-crop the quad ``points`` out of ``img``.

    Crops at least 1.5x taller than wide are rotated 90 degrees with ``np.rot90``.

    Raises:
        ValueError: if ``points`` does not contain exactly 4 points.
    """
    if len(points) != 4:
        raise ValueError("shape of points must be 4*2")
    img_crop_width = int(max(np.linalg.norm(points[0] - points[1]), np.linalg.norm(points[2] - points[3])))
    img_crop_height = int(max(np.linalg.norm(points[0] - points[3]), np.linalg.norm(points[1] - points[2])))
    pts_std = np.float32(
        [
            [0, 0],
            [img_crop_width, 0],
            [img_crop_width, img_crop_height],
            [0, img_crop_height],
        ]
    )
    transform = cv2.getPerspectiveTransform(points.astype(np.float32), pts_std)
    dst_img = cv2.warpPerspective(
        img,
        transform,
        (img_crop_width, img_crop_height),
        borderMode=cv2.BORDER_REPLICATE,
        flags=cv2.INTER_CUBIC,
    )
    dst_img_height, dst_img_width = dst_img.shape[:2]
    if dst_img_width == 0 or dst_img_height == 0:
        return dst_img
    if dst_img_height * 1.0 / dst_img_width >= 1.5:
        dst_img = np.rot90(dst_img)
    return dst_img


def _resize_norm_img(img: np.ndarray, shape: Tuple[int, int, int]) -> np.ndarray:
    """Resize to the target height keeping aspect ratio, then right-pad with zeros.

    Images wider than the target after scaling are squeezed to ``target_w``.
    Pixel values are not normalised. Returns a (C, target_h, target_w) float32 array.
    """
    img_h, img_w = img.shape[:2]
    img_c, target_h, target_w = shape
    ratio = img_w / float(max(img_h, 1))
    if math.ceil(target_h * ratio) > target_w:
        resized_w = target_w
    else:
        resized_w = int(math.ceil(target_h * ratio))
    resized_image = cv2.resize(img, (resized_w, target_h)).astype(np.float32)
    resized_image = resized_image.transpose((2, 0, 1))
    padding_im = np.zeros((img_c, target_h, target_w), dtype=np.float32)
    padding_im[:, :, 0:resized_w] = resized_image
    return padding_im


def _classify_orientation(crop_bgr: np.ndarray) -> np.ndarray:
    """Rotate a text crop 180 degrees when the classifier predicts class 1.

    The rotation only happens if class 1 wins with probability above
    CLS_ROTATE_THRESHOLD; otherwise (and for empty crops) the crop is returned as is.
    """
    if crop_bgr.size == 0:
        return crop_bgr
    cls_input = _resize_norm_img(crop_bgr, (3, CLS_INPUT_SIZE[0], CLS_INPUT_SIZE[1]))[np.newaxis].copy()
    sess = _get_cls_sess()
    probs = sess.run(None, {sess.get_inputs()[0].name: cls_input})[0]
    pred_idx = int(np.argmax(probs[0]))
    pred_score = float(probs[0, pred_idx])
    if pred_idx == 1 and pred_score > CLS_ROTATE_THRESHOLD:
        return cv2.rotate(crop_bgr, cv2.ROTATE_180)
    return crop_bgr


def _preprocess_rec(crop_bgr: np.ndarray) -> np.ndarray:
    """Build the 1x3xHxW recognizer input (REC_INPUT_SIZE); empty crops give zeros."""
    if crop_bgr.size == 0:
        return np.zeros((1, 3, REC_INPUT_SIZE[0], REC_INPUT_SIZE[1]), dtype=np.float32)
    return _resize_norm_img(crop_bgr, (3, REC_INPUT_SIZE[0], REC_INPUT_SIZE[1]))[np.newaxis].copy()


def _decode_text_indices(
    chars: List[str],
    text_index: np.ndarray,
    text_prob: np.ndarray,
    is_remove_duplicate: bool = False,
) -> List[Tuple[str, float]]:
    """Greedy CTC decode of per-timestep argmax indices.

    Index 0 (blank) is dropped; with ``is_remove_duplicate`` consecutive repeats
    are collapsed first. Indices beyond ``chars`` are skipped. Returns one
    ``(text, mean_confidence)`` pair per batch row (confidence 0.0 if nothing kept).
    """
    result_list: List[Tuple[str, float]] = []
    ignored_tokens = [0]
    batch_size = len(text_index)
    for batch_idx in range(batch_size):
        selection = np.ones(len(text_index[batch_idx]), dtype=bool)
        if is_remove_duplicate:
            selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
        for ignored_token in ignored_tokens:
            selection &= text_index[batch_idx] != ignored_token
        char_list = [chars[text_id] for text_id in text_index[batch_idx][selection] if text_id < len(chars)]
        conf_list = text_prob[batch_idx][selection]
        if len(conf_list) == 0:
            conf_list = np.array([0.0], dtype=np.float32)
        result_list.append(("".join(char_list), float(np.mean(conf_list).tolist())))
    return result_list


def _decode_rec(logits: np.ndarray) -> Tuple[str, float]:
    """Decode recognizer output (classes on axis 2) and return the first row's result."""
    chars = _get_char_list()
    preds = logits[-1] if isinstance(logits, (tuple, list)) else logits
    preds = preds.astype(np.float32)
    preds_idx = preds.argmax(axis=2)
    preds_prob = preds.max(axis=2)
    return _decode_text_indices(chars, preds_idx, preds_prob, is_remove_duplicate=True)[0]


def _recognize_crop(crop_bgr: np.ndarray) -> Tuple[str, float]:
    """Orientation-correct and recognise one BGR text crop -> (text, confidence)."""
    oriented = _classify_orientation(crop_bgr)
    rec_input = _preprocess_rec(oriented)
    sess = _get_rec_sess()
    logits = sess.run(None, {sess.get_inputs()[0].name: rec_input})[0]
    return _decode_rec(logits)


def _enhance_crop(crop_bgr: np.ndarray) -> np.ndarray:
    """Return a BGR copy of the crop with contrast boosted by 1.5x."""
    img = Image.fromarray(crop_bgr[:, :, ::-1])
    img = ImageEnhance.Contrast(img).enhance(1.5)
    return np.array(img)[:, :, ::-1]


def _quad_to_bbox(points: np.ndarray, width: int, height: int) -> List[int]:
    """Axis-aligned [x1, y1, x2, y2] hull of a quad, clipped to the image."""
    min_x = max(0, int(np.floor(np.min(points[:, 0]))))
    min_y = max(0, int(np.floor(np.min(points[:, 1]))))
    max_x = min(width, int(np.ceil(np.max(points[:, 0]))))
    max_y = min(height, int(np.ceil(np.max(points[:, 1]))))
    return [min_x, min_y, max_x, max_y]


def _iou(a: List[int], b: List[int]) -> float:
    """Overlap ratio of two [x1, y1, x2, y2] boxes.

    NOTE: despite the name this is intersection over the *smaller* box's area,
    not intersection over union, so a box fully inside another scores 1.0.
    LAYOUT_DEDUP_IOU and the thresholds in `_assign_region` are compared
    against this measure.
    """
    ax1, ay1, ax2, ay2 = a
    bx1, by1, bx2, by2 = b
    ix1, iy1 = max(ax1, bx1), max(ay1, by1)
    ix2, iy2 = min(ax2, bx2), min(ay2, by2)
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    if inter == 0:
        return 0.0
    # A positive intersection implies both boxes have positive area here.
    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    return inter / min(area_a, area_b)


def _clip_bbox(bbox: Sequence[float], width: int, height: int) -> List[int]:
    """Clamp a float [x1, y1, x2, y2] box to the image and round to ints."""
    x1, y1, x2, y2 = bbox
    clipped = [
        int(round(max(0.0, min(float(width), x1)))),
        int(round(max(0.0, min(float(height), y1)))),
        int(round(max(0.0, min(float(width), x2)))),
        int(round(max(0.0, min(float(height), y2)))),
    ]
    return clipped


def _dedupe_layout_regions(regions: List[dict]) -> List[dict]:
    """Greedy, score-ordered suppression of overlapping layout regions.

    A region is dropped if an already-kept region of the same ``type`` overlaps
    it by >= LAYOUT_DEDUP_IOU (per `_iou`). Non-figure regions additionally
    need the same ``raw_type`` to be considered duplicates. The survivors are
    returned sorted by (y1, x1).
    """
    kept: List[dict] = []
    for region in sorted(regions, key=lambda item: item["score"], reverse=True):
        duplicate = False
        for existing in kept:
            same_type = existing.get("type") == region.get("type")
            same_label = existing.get("raw_type") == region.get("raw_type")
            if not same_type:
                continue
            if region.get("type") != "figure" and not same_label:
                continue
            if _iou(existing["bbox"], region["bbox"]) >= LAYOUT_DEDUP_IOU:
                duplicate = True
                break
        if not duplicate:
            kept.append(region)
    kept.sort(key=lambda item: (item["bbox"][1], item["bbox"][0]))
    return kept


def _run_ocr_full(arr_rgb: np.ndarray, threshold: float) -> List[TextBlock]:
    """Detect and recognise every text line in an RGB image.

    Lines scoring below ``threshold`` are re-recognised on a contrast-enhanced
    crop and the better of the two results is kept.
    """
    blocks = []
    arr_bgr = arr_rgb[:, :, ::-1]
    height, width = arr_rgb.shape[:2]
    for dt_box in _detect_boxes(arr_bgr):
        crop = _get_rotate_crop_image(arr_bgr, np.asarray(dt_box, dtype=np.float32))
        if crop.size == 0:
            continue
        text, score = _recognize_crop(crop)
        if score < threshold:
            retry_text, retry_score = _recognize_crop(_enhance_crop(crop))
            if retry_score > score:
                text, score = retry_text, retry_score
        bbox = _quad_to_bbox(np.asarray(dt_box, dtype=np.float32), width, height)
        blocks.append(TextBlock(text=text, score=score, bbox=bbox))
    return blocks


def _assign_region(block: TextBlock, regions: list) -> tuple:
    """Return ``(region_type, region_index)`` for the region best covering ``block``.

    The first region reaching an overlap of 0.9 is returned immediately. Otherwise
    the best-overlapping region wins, but its type is reported as "unknown" unless
    the overlap exceeds 0.1. When nothing overlaps, the index is ``len(regions)``.
    """
    best_type, best_iou, best_idx = "unknown", 0.0, len(regions)
    for i, region in enumerate(regions):
        bbox = region.get("bbox")
        if bbox is None:
            continue
        score = _iou(block.bbox, bbox)
        if score > best_iou:
            best_iou = score
            best_type = region.get("type", "unknown")
            best_idx = i
        if best_iou >= 0.9:
            return best_type, best_idx
    return (best_type if best_iou > 0.1 else "unknown"), best_idx


def _run_structure(arr_rgb: np.ndarray) -> List[dict]:
    """Run PP-DocLayout-S layout analysis on an RGB image.

    Two exported model flavours are supported, distinguished by input count:

    * two or more inputs (image + scale_factor): the original export; its two
      outputs are rows ``[cls_id, score, x1, y1, x2, y2]`` (already in image
      coordinates) and the number of valid rows;
    * one input: the "sim-cut" export; its two outputs are boxes (batch 1, N x 4)
      and per-class scores laid out as [1, C, N] or [1, N, C], both in
      LAYOUT_INPUT_SIZE input coordinates.

    Returns:
        Deduplicated region dicts sorted by (y1, x1) with keys ``type`` (a
        LAYOUT_TYPE_MAP value), ``raw_type``, ``score``, ``bbox`` ([x1, y1, x2, y2]
        ints clipped to the image) and, for non-empty figure regions, ``img``
        (the RGB crop).

    Raises:
        RuntimeError: if sim-cut outputs do not have the expected count/shapes.
    """
    sess = _get_layout_sess()
    h, w = arr_rgb.shape[:2]
    image_input = _resize_no_pad(arr_rgb, LAYOUT_INPUT_SIZE[0], LAYOUT_INPUT_SIZE[1], LAYOUT_MEAN, LAYOUT_STD)

    def _append_region(regions: List[dict], cls_id: int, score: float, bbox_values: Sequence[float], scale_to_image: bool):
        """Filter, map, rescale and clip one detection; append it to ``regions``."""
        if score < LAYOUT_SCORE_THRESHOLD:
            return
        raw_label = LAYOUT_LABELS[cls_id] if 0 <= cls_id < len(LAYOUT_LABELS) else str(cls_id)
        if raw_label not in LAYOUT_ENABLED_LABELS:
            return
        region_type = LAYOUT_TYPE_MAP.get(raw_label, "unknown")
        x1, y1, x2, y2 = map(float, bbox_values)
        if x2 < x1:
            x1, x2 = x2, x1
        if y2 < y1:
            y1, y2 = y2, y1
        if scale_to_image:
            x_scale = w / float(LAYOUT_INPUT_SIZE[1])
            y_scale = h / float(LAYOUT_INPUT_SIZE[0])
            bbox = _clip_bbox([x1 * x_scale, y1 * y_scale, x2 * x_scale, y2 * y_scale], w, h)
        else:
            bbox = _clip_bbox([x1, y1, x2, y2], w, h)
        # Drop slivers narrower or shorter than 4 px.
        if bbox[2] - bbox[0] < 4 or bbox[3] - bbox[1] < 4:
            return
        region = {"type": region_type, "raw_type": raw_label, "score": score, "bbox": bbox}
        if region_type == "figure":
            x1c, y1c, x2c, y2c = bbox
            crop = arr_rgb[max(0, y1c):max(0, y2c), max(0, x1c):max(0, x2c)]
            if crop.size > 0:
                region["img"] = crop
        regions.append(region)

    regions: List[dict] = []
    inputs = sess.get_inputs()
    if len(inputs) >= 2:
        # Original PP-DocLayout-S export: outputs [cls_id, score, x1, y1, x2, y2] rows + a row count.
        scale_factor = np.array(
            [[LAYOUT_INPUT_SIZE[0] / max(h, 1), LAYOUT_INPUT_SIZE[1] / max(w, 1)]],
            dtype=np.float32,
        )
        outputs = sess.run(None, {inputs[0].name: image_input, inputs[1].name: scale_factor})
        dets, num = outputs
        num = int(num[0])
        for row in dets[:num]:
            _append_region(regions, int(row[0]), float(row[1]), row[2:].tolist(), scale_to_image=False)
    else:
        # sim-cut export: outputs boxes [N, 4] + class scores [C, N], both in 480x480 input coordinates.
        outputs = sess.run(None, {inputs[0].name: image_input})
        if len(outputs) != 2:
            raise RuntimeError(f"Unexpected layout outputs: {len(outputs)}")
        out_names = [o.name for o in sess.get_outputs()]
        out_map = {name: value for name, value in zip(out_names, outputs)}
        if "Mul.109" in out_map and "Concat.9" in out_map:
            boxes = out_map["Mul.109"][0]
            cls_blob = out_map["Concat.9"]
        else:
            # Unknown output names: identify the box tensor by its trailing dim of 4.
            first, second = outputs
            if first.ndim == 3 and first.shape[-1] == 4:
                boxes, cls_blob = first[0], second
            elif second.ndim == 3 and second.shape[-1] == 4:
                boxes, cls_blob = second[0], first
            else:
                raise RuntimeError(f"Unexpected sim-cut output shapes: {first.shape}, {second.shape}")
        if cls_blob.ndim != 3 or cls_blob.shape[0] != 1:
            raise RuntimeError(f"Unexpected cls output shape: {cls_blob.shape}")
        if cls_blob.shape[1] == len(LAYOUT_LABELS) and cls_blob.shape[2] == boxes.shape[0]:
            cls_scores = cls_blob[0].transpose(1, 0)  # [N, C]
        elif cls_blob.shape[2] == len(LAYOUT_LABELS) and cls_blob.shape[1] == boxes.shape[0]:
            cls_scores = cls_blob[0]  # [N, C]
        else:
            raise RuntimeError(f"Unexpected cls/box alignment: cls={cls_blob.shape}, boxes={boxes.shape}")
        best_cls = np.argmax(cls_scores, axis=1)
        best_score = np.max(cls_scores, axis=1)
        for i in range(boxes.shape[0]):
            _append_region(
                regions,
                int(best_cls[i]),
                float(best_score[i]),
                boxes[i].tolist(),
                scale_to_image=True,
            )
    return _dedupe_layout_regions(regions)


def _reading_order_key(block: TextBlock) -> Tuple[int, int]:
    """Sort key: 20-px bucket of the box's vertical centre, then left x."""
    return (int((block.bbox[1] + block.bbox[3]) / 40), block.bbox[0])


def _collect_figures(regions: List[dict]) -> Tuple[List[Image.Image], List[List[int]]]:
    """Return PIL crops and bboxes of the figure regions that carry an ``img`` crop."""
    figure_images: List[Image.Image] = []
    figure_bboxes: List[List[int]] = []
    for region in regions:
        if region.get("type") != "figure":
            continue
        sub = region.get("img")
        bbox = region.get("bbox")
        if sub is not None and bbox is not None:
            figure_images.append(Image.fromarray(sub))
            figure_bboxes.append(bbox)
    return figure_images, figure_bboxes


def _order_blocks_by_layout(blocks: List[TextBlock], regions: List[dict]) -> List[TextBlock]:
    """Merge OCR lines per layout region and interleave the unassigned lines.

    Each line is assigned with `_assign_region`. Lines sharing a region are
    sorted by `_reading_order_key`, joined with spaces and merged into one
    TextBlock (mean score, union bbox, the region's type). Merged blocks follow
    the order of ``regions``. Each unassigned line (in reading-key order) is then
    inserted before the first block whose reading key is >= its own.
    """
    region_groups: dict = {}
    unknown_blocks: List[TextBlock] = []
    for block in blocks:
        rtype, ridx = _assign_region(block, regions)
        if rtype == "unknown":
            unknown_blocks.append(block)
        else:
            region_groups.setdefault(ridx, (rtype, []))[1].append(block)

    ordered: List[TextBlock] = []
    for ridx in sorted(region_groups):
        rtype, grp = region_groups[ridx]
        grp.sort(key=_reading_order_key)
        ordered.append(
            TextBlock(
                text=" ".join(b.text for b in grp if b.text),
                score=sum(b.score for b in grp) / len(grp),
                bbox=[
                    min(b.bbox[0] for b in grp),
                    min(b.bbox[1] for b in grp),
                    max(b.bbox[2] for b in grp),
                    max(b.bbox[3] for b in grp),
                ],
                region_type=rtype,
            )
        )

    # `ordered` follows region order, so its keys are not necessarily sorted and
    # binary search (bisect) would be unreliable; scan linearly instead. For a
    # sorted key list this is exactly bisect_left.
    keys = [_reading_order_key(b) for b in ordered]
    unknown_blocks.sort(key=_reading_order_key)
    for block in unknown_blocks:
        key = _reading_order_key(block)
        pos = next((i for i, existing in enumerate(keys) if existing >= key), len(keys))
        ordered.insert(pos, block)
        keys.insert(pos, key)
    return ordered


def extract(img: Image.Image, scene_result: SceneResult, debug_prefix: Optional[str] = None) -> OCRResult:
    """Extract text (and, for non-screenshot scenes, figure crops) from ``img``.

    Screenshots get plain line-level OCR. Other scenes also run layout analysis:
    figure regions are returned as crops and OCR lines are merged per layout
    region in reading order. Blocks whose stripped text is shorter than two
    characters are discarded.

    Args:
        img: input image (converted to RGB).
        scene_result: provides ``scene`` and the recognition retry threshold
            ``ocr_threshold``.
        debug_prefix: currently unused.

    Returns:
        An OCRResult. With no surviving blocks, ``text_state`` is
        MISSING_EXPECTED_TEXT for screenshots/documents and MISSING_UNCERTAIN
        otherwise; with blocks it is LOW_QUALITY when fewer than 3 blocks average
        below 0.65 confidence, else OK.
    """
    arr_rgb = np.array(img.convert("RGB"))
    threshold = scene_result.ocr_threshold

    figure_images: List[Image.Image]
    figure_bboxes: List[List[int]]
    if scene_result.scene == Scene.SCREENSHOT:
        blocks = _run_ocr_full(arr_rgb, threshold)
        figure_images, figure_bboxes = [], []
    else:
        structure_result = _run_structure(arr_rgb)
        blocks = _run_ocr_full(arr_rgb, threshold)
        figure_images, figure_bboxes = _collect_figures(structure_result)
        blocks = _order_blocks_by_layout(blocks, structure_result)

    blocks = [b for b in blocks if len(b.text.strip()) >= 2]
    if not blocks:
        text_state = (
            OCRTextState.MISSING_EXPECTED_TEXT
            if scene_result.scene in {Scene.SCREENSHOT, Scene.DOCUMENT}
            else OCRTextState.MISSING_UNCERTAIN
        )
        return OCRResult(
            blocks=[],
            avg_score=0.0,
            low_confidence=False,
            text_state=text_state,
            figure_images=figure_images,
            figure_bboxes=figure_bboxes,
        )

    avg_score = sum(b.score for b in blocks) / len(blocks)
    low_confidence = avg_score < 0.65 and len(blocks) < 3
    text_state = OCRTextState.LOW_QUALITY if low_confidence else OCRTextState.OK
    return OCRResult(
        blocks=blocks,
        avg_score=avg_score,
        low_confidence=low_confidence,
        text_state=text_state,
        figure_images=figure_images,
        figure_bboxes=figure_bboxes,
    )