pp-nsfw_Inspector / src /perception /ocr_extractor.py
philcuriosity1024's picture
Upload folder using huggingface_hub
670cf0c verified
Raw
History Blame Contribute Delete
27.5 kB
"""OCR 文本提取:PP-OCRv5 ONNX 推理 + PP-DocLayout-S sim-cut ONNX 版面分析"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
import math
from pathlib import Path
from threading import RLock
from typing import List, Optional, Sequence, Tuple
import cv2
import numpy as np
from PIL import Image, ImageEnhance
from ..preprocess.scene_classifier import Scene, SceneResult
# --- Model / resource locations ---------------------------------------------
# PERCEPTION_DIR is the directory containing this module; PROJECT_ROOT is two
# directory levels above it. Compiled models live under <PROJECT_ROOT>/axmodel.
PERCEPTION_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = PERCEPTION_DIR.parents[1]
AXMODEL_DIR = PROJECT_ROOT / "axmodel"
# PP-OCRv5 text detection / orientation classification / recognition models.
DET_MODEL_PATH = AXMODEL_DIR / "ppocrv5" / "det_npu1.axmodel"
CLS_MODEL_PATH = AXMODEL_DIR / "ppocrv5" / "cls_npu1.axmodel"
REC_MODEL_PATH = AXMODEL_DIR / "ppocrv5" / "rec_npu1.axmodel"
# Recognition vocabulary, one entry per line (read by _get_char_list).
DICT_PATH = PERCEPTION_DIR / "dict" / "ppocrv5_dict.txt"
# Layout-analysis model (run by _run_structure).
LAYOUT_MODEL_PATH = AXMODEL_DIR / "ppstructurev3" / "ppstructure_npu1.axmodel"
# --- Model input sizes --------------------------------------------------------
# Detection input is a square DET_INPUT_SIZE x DET_INPUT_SIZE (aspect ratio is not kept).
DET_INPUT_SIZE = 960
CLS_INPUT_SIZE = (80, 160) # h, w
REC_INPUT_SIZE = (48, 320) # h, w
LAYOUT_INPUT_SIZE = (480, 480) # h, w
# --- Normalisation constants --------------------------------------------------
# NOTE(review): none of these are applied by the preprocessing in this file.
# Images are fed to every model as raw 0-255 float pixels, and _resize_no_pad
# ignores the LAYOUT_MEAN / LAYOUT_STD it receives. Presumably normalisation is
# compiled into the .axmodel files — confirm before relying on these values.
DET_MEAN = np.array([123.675, 116.28, 103.53], dtype=np.float32)
DET_STD = np.array([58.395, 57.12, 57.375], dtype=np.float32)
CLS_MEAN = np.array([127.5, 127.5, 127.5], dtype=np.float32)
CLS_STD = np.array([127.5, 127.5, 127.5], dtype=np.float32)
REC_MEAN = np.array([127.5, 127.5, 127.5], dtype=np.float32)
REC_STD = np.array([127.5, 127.5, 127.5], dtype=np.float32)
LAYOUT_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
LAYOUT_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
# --- Thresholds ---------------------------------------------------------------
# Probability cut-off used to binarise the detection map (_det_postprocess).
DET_BIN_THRESHOLD = 0.30
# NOTE(review): not referenced anywhere in this file.
DET_MIN_COMPONENT_AREA = 16
# NOTE(review): not referenced in this file; _boxes_from_bitmap hard-codes box_thresh = 0.6.
DET_SCORE_THRESHOLD = 0.35
# A crop is rotated 180 degrees only if class 1 wins with probability above this value.
CLS_ROTATE_THRESHOLD = 0.90
# Layout detections scoring below this are discarded.
LAYOUT_SCORE_THRESHOLD = 0.35
# Same-kind layout regions whose overlap (intersection / smaller area, see _iou)
# reaches this value are treated as duplicates; the lower-scored one is dropped.
LAYOUT_DEDUP_IOU = 0.85
# Layout model class names, indexed by the model's class id (the position in
# this list is the id; _run_structure also checks the class dimension against
# len(LAYOUT_LABELS)).
LAYOUT_LABELS = [
    "paragraph_title",
    "image",
    "text",
    "number",
    "abstract",
    "content",
    "figure_title",
    "formula",
    "table",
    "table_title",
    "reference",
    "doc_title",
    "footnote",
    "header",
    "algorithm",
    "footer",
    "seal",
    "chart_title",
    "chart",
    "formula_number",
    "header_image",
    "footer_image",
    "aside_text",
]
# Maps raw layout labels to the coarse region types stored in region dicts and
# TextBlock.region_type. Regions mapped to "figure" are additionally cropped
# and returned as figure images by extract().
LAYOUT_TYPE_MAP = {
    "paragraph_title": "title",
    "doc_title": "title",
    "text": "text",
    "content": "text",
    "reference": "text",
    "abstract": "text",
    "number": "text",
    "aside_text": "text",
    "figure_title": "figure_caption",
    "table_title": "figure_caption",
    "chart_title": "figure_caption",
    "formula_number": "figure_caption",
    "image": "figure",
    "chart": "figure",
    "header_image": "figure",
    "footer_image": "figure",
    "table": "table",
    "formula": "figure",
    "header": "unknown",
    "footer": "unknown",
    "seal": "unknown",
    "algorithm": "text",
    "footnote": "text",
}
# Only detections whose raw label is in this set are kept. The excluded labels
# are header, footer, seal, formula_number, header_image, footer_image and
# aside_text, so some LAYOUT_TYPE_MAP entries are never used in practice.
LAYOUT_ENABLED_LABELS = {
    "paragraph_title",
    "image",
    "text",
    "number",
    "abstract",
    "content",
    "figure_title",
    "formula",
    "table",
    "table_title",
    "reference",
    "doc_title",
    "footnote",
    "algorithm",
    "chart_title",
    "chart",
}
# Lazily initialised module-level singletons, populated by the _get_* helpers below.
_ort = None  # inference runtime module (axengine), imported on first use
_runtime_lock = RLock()  # reentrant lock serialising runtime import / session creation
_det_sess = None  # text-detection session
_cls_sess = None  # text-orientation classification session
_rec_sess = None  # text-recognition session
_layout_sess = None  # layout-analysis session
_char_list = None  # recognition vocabulary (list of tokens)
_pyclipper = None  # None = not probed yet; pyclipper module, or False if not installed
def _get_ort():
    """Return the inference-runtime module, importing it on first use.

    The runtime is ``axengine``. It is used through an onnxruntime-style API
    (``InferenceSession``, ``run``, ``get_inputs``), and the result is cached
    in the module global ``_ort``.

    Raises:
        RuntimeError: if ``axengine`` cannot be imported.
    """
    global _ort
    if _ort is None:
        try:
            # To run the ONNX exports on CPU instead, this could be swapped for
            # ``import onnxruntime as ort`` (and the provider list in _get_session adjusted).
            import axengine as ort  # type: ignore
        except ImportError as exc:  # pragma: no cover - environment issue
            # The message names the package that actually failed to import.
            raise RuntimeError(
                "axengine is required for PP-OCRv5 axmodel inference. "
                "Please install it in the runtime environment."
            ) from exc
        _ort = ort
    return _ort
def _get_session(path: Path):
    """Create a fresh inference session for the model file at ``path``.

    Sessions are not cached here; the per-model getters cache them. The runtime
    import and session construction both happen while holding ``_runtime_lock``.
    """
    with _runtime_lock:
        runtime = _get_ort()
        session = runtime.InferenceSession(str(path), providers=["AxEngineExecutionProvider"])
    return session
def _get_det_sess():
    """Return the cached text-detection session, creating it on first use.

    The None-check is repeated under ``_runtime_lock``, so concurrent first
    callers cannot each build a session. The lock is reentrant, which allows
    _get_session to take it again.
    """
    global _det_sess
    if _det_sess is None:
        with _runtime_lock:
            if _det_sess is None:
                _det_sess = _get_session(DET_MODEL_PATH)
    return _det_sess
def _get_cls_sess():
    """Return the cached orientation-classification session, creating it on first use.

    The None-check is repeated under ``_runtime_lock``, so concurrent first
    callers cannot each build a session. The lock is reentrant, which allows
    _get_session to take it again.
    """
    global _cls_sess
    if _cls_sess is None:
        with _runtime_lock:
            if _cls_sess is None:
                _cls_sess = _get_session(CLS_MODEL_PATH)
    return _cls_sess
def _get_rec_sess():
    """Return the cached text-recognition session, creating it on first use.

    The None-check is repeated under ``_runtime_lock``, so concurrent first
    callers cannot each build a session. The lock is reentrant, which allows
    _get_session to take it again.
    """
    global _rec_sess
    if _rec_sess is None:
        with _runtime_lock:
            if _rec_sess is None:
                _rec_sess = _get_session(REC_MODEL_PATH)
    return _rec_sess
def _get_layout_sess():
    """Return the cached layout-analysis session, creating it on first use.

    The None-check is repeated under ``_runtime_lock``, so concurrent first
    callers cannot each build a session. The lock is reentrant, which allows
    _get_session to take it again.
    """
    global _layout_sess
    if _layout_sess is None:
        with _runtime_lock:
            if _layout_sess is None:
                _layout_sess = _get_session(LAYOUT_MODEL_PATH)
    return _layout_sess
def _get_char_list() -> List[str]:
    """Return the recognition vocabulary, reading DICT_PATH on the first call.

    Layout: index 0 is the "blank" token, which _decode_text_indices ignores.
    It is followed by every line of the dictionary file, then a trailing " "
    (space) entry.
    """
    global _char_list
    if _char_list is not None:
        return _char_list
    dictionary = DICT_PATH.read_text(encoding="utf-8").splitlines()
    # PP-OCRv5 rec output class count = blank + dict + ascii space.
    _char_list = ["blank", *dictionary, " "]
    return _char_list
def _get_pyclipper():
    """Return the ``pyclipper`` module, or ``False`` when it is not installed.

    The probe result, including a failed import, is cached in ``_pyclipper``,
    so the import is attempted at most once.
    """
    global _pyclipper
    if _pyclipper is None:
        try:
            import pyclipper  # type: ignore
        except ImportError:
            pyclipper = False
        _pyclipper = pyclipper
    return _pyclipper
@dataclass
class TextBlock:
    """A recognised piece of text: one detected line, or (inside extract) several lines merged per layout region."""
    text: str  # recognised string
    score: float  # recognition confidence; for merged region blocks, the mean of the member scores
    bbox: List[int] # [x1, y1, x2, y2]
    region_type: str = "unknown" # coarse layout type from LAYOUT_TYPE_MAP (text / title / figure_caption / figure / table) or "unknown"
class OCRTextState(str, Enum):
    """Overall verdict on the text found in an image, as assigned by extract()."""
    OK = "ok"  # text found with acceptable confidence
    LOW_QUALITY = "low_quality"  # extract() flagged low_confidence (few blocks, low average score)
    MISSING_EXPECTED_TEXT = "missing_expected_text"  # no usable text in a SCREENSHOT or DOCUMENT scene
    MISSING_UNCERTAIN = "missing_uncertain"  # no usable text in any other scene
@dataclass
class OCRResult:
    """Result of extract(): text blocks, summary confidence and figure crops."""
    blocks: List[TextBlock] = field(default_factory=list)  # text blocks in reading order
    avg_score: float = 0.0  # mean block score (0.0 when there are no blocks)
    low_confidence: bool = False  # set by extract() when avg_score < 0.65 and there are fewer than 3 blocks
    text_state: OCRTextState = OCRTextState.OK
    figure_images: List[Image.Image] = field(default_factory=list)  # crops of layout "figure" regions
    figure_bboxes: List[List[int]] = field(default_factory=list)  # [x1, y1, x2, y2] for each figure image, same order
def _resize_no_pad(arr_rgb: np.ndarray, target_h: int, target_w: int, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Stretch an HWC image to ``(target_h, target_w)`` and return a 1 x C x H x W float32 batch.

    The image is resized with bicubic interpolation and no padding, so the
    aspect ratio is not preserved. Pixel values stay on their original 0-255
    scale: ``mean`` and ``std`` are accepted for signature compatibility but
    are NOT applied. NOTE(review): presumably the model normalises internally
    — confirm.
    """
    stretched = Image.fromarray(arr_rgb).resize((target_w, target_h), Image.BICUBIC)
    chw = np.asarray(stretched).astype(np.float32).transpose(2, 0, 1)
    return chw[np.newaxis]
def _get_mini_boxes(contour: np.ndarray) -> Tuple[List[np.ndarray], float]:
    """Fit a minimum-area rotated rectangle to ``contour``.

    Returns:
        ``(corners, short_side)``. ``corners`` holds the 4 rectangle corners in
        the order top-left, top-right, bottom-right, bottom-left (image
        coordinates, y grows downward). ``short_side`` is the length of the
        rectangle's shorter side.
    """
    rect = cv2.minAreaRect(contour)
    # After sorting by x, the first two corners form the left pair and the last two the right pair.
    by_x = sorted(list(cv2.boxPoints(rect)), key=lambda corner: corner[0])
    left_pair, right_pair = by_x[:2], by_x[2:]
    # Within a pair, the corner with the smaller y is on top (on a tie, the second corner is taken as top).
    top_left, bottom_left = (
        (left_pair[0], left_pair[1]) if left_pair[1][1] > left_pair[0][1] else (left_pair[1], left_pair[0])
    )
    top_right, bottom_right = (
        (right_pair[0], right_pair[1]) if right_pair[1][1] > right_pair[0][1] else (right_pair[1], right_pair[0])
    )
    return [top_left, top_right, bottom_right, bottom_left], min(rect[1])
def _box_score_fast(bitmap: np.ndarray, box: np.ndarray) -> float:
    """Return the mean of ``bitmap`` over the pixels inside polygon ``box`` (an N x 2 array of x, y).

    The polygon is first clipped to the bitmap via its bounding window. It is
    then rasterised into a mask local to that window, which becomes the mask
    argument of ``cv2.mean``.
    """
    rows, cols = bitmap.shape[:2]
    poly = box.copy()
    x_lo = np.clip(np.floor(poly[:, 0].min()).astype("int32"), 0, cols - 1)
    x_hi = np.clip(np.ceil(poly[:, 0].max()).astype("int32"), 0, cols - 1)
    y_lo = np.clip(np.floor(poly[:, 1].min()).astype("int32"), 0, rows - 1)
    y_hi = np.clip(np.ceil(poly[:, 1].max()).astype("int32"), 0, rows - 1)

    window_mask = np.zeros((y_hi - y_lo + 1, x_hi - x_lo + 1), dtype=np.uint8)
    # Shift the polygon into window-local coordinates before filling.
    poly[:, 0] = poly[:, 0] - x_lo
    poly[:, 1] = poly[:, 1] - y_lo
    cv2.fillPoly(window_mask, poly.reshape(1, -1, 2).astype("int32"), 1)

    window = bitmap[y_lo:y_hi + 1, x_lo:x_hi + 1]
    return float(cv2.mean(window, window_mask)[0])
def _unclip(box: np.ndarray, unclip_ratio: float) -> np.ndarray:
    """Grow polygon ``box`` outward by ``area * unclip_ratio / perimeter`` pixels.

    pyclipper performs the offset with rounded joins. The unmodified box,
    wrapped as a 1-element batch, is returned when any of these hold: the box
    is degenerate (near-zero area or perimeter), pyclipper is not installed,
    or the offset produces nothing.
    """
    as_f32 = box.astype(np.float32)
    area = float(abs(cv2.contourArea(as_f32)))
    perimeter = float(cv2.arcLength(as_f32, True))
    if area <= 1e-6 or perimeter <= 1e-6:
        return np.array([box], dtype=np.float32)

    clipper = _get_pyclipper()
    if not clipper:
        # Best-effort: without pyclipper, boxes are simply not expanded.
        return np.array([box], dtype=np.float32)

    offsetter = clipper.PyclipperOffset()
    offsetter.AddPath(box, clipper.JT_ROUND, clipper.ET_CLOSEDPOLYGON)
    grown = offsetter.Execute(area * unclip_ratio / perimeter)
    return np.array(grown) if grown else np.array([box], dtype=np.float32)
def _boxes_from_bitmap(pred: np.ndarray, bitmap: np.ndarray, dest_width: int, dest_height: int) -> Tuple[np.ndarray, List[float]]:
    """Turn a binarised detection map into quadrilateral text boxes in destination-image pixels.

    Args:
        pred: detection probability map; used to score candidate boxes.
        bitmap: boolean mask of the same height/width (``pred > threshold``); contours come from it.
        dest_width, dest_height: size of the image the boxes are rescaled into.

    Returns:
        ``(boxes, scores)``. ``boxes`` is an int32 array of shape (K, 4, 2)
        (shape (0, 4, 2) when nothing is found); ``scores`` holds the mean
        probability inside each kept box.
    """
    box_thresh = 0.6  # minimum mean probability inside a candidate box
    max_candidates = 1000  # at most this many contours are examined
    unclip_ratio = 1.5  # expansion strength passed to _unclip
    min_size = 3  # minimum short side before expansion (min_size + 2 after)

    map_h, map_w = bitmap.shape
    found = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    # OpenCV 3.x returns (image, contours, hierarchy); 4.x returns (contours, hierarchy).
    contours = found[1] if len(found) == 3 else found[0]

    kept_boxes: List[np.ndarray] = []
    kept_scores: List[float] = []
    for contour in contours[:max_candidates]:
        quad, short_side = _get_mini_boxes(contour)
        if short_side < min_size:
            continue
        quad = np.array(quad)
        score = _box_score_fast(pred, quad.reshape(-1, 2))
        if score < box_thresh:
            continue
        grown, short_side = _get_mini_boxes(_unclip(quad, unclip_ratio).reshape(-1, 1, 2))
        if short_side < min_size + 2:
            continue
        grown = np.array(grown)
        # Rescale from map coordinates to destination-image coordinates.
        grown[:, 0] = np.clip(np.round(grown[:, 0] / map_w * dest_width), 0, dest_width)
        grown[:, 1] = np.clip(np.round(grown[:, 1] / map_h * dest_height), 0, dest_height)
        kept_boxes.append(grown.astype("int32"))
        kept_scores.append(float(score))

    if kept_boxes:
        return np.array(kept_boxes, dtype="int32"), kept_scores
    return np.zeros((0, 4, 2), dtype="int32"), kept_scores
def _order_points_clockwise(points: np.ndarray) -> np.ndarray:
    """Reorder 4 (x, y) points as top-left, top-right, bottom-right, bottom-left.

    The top-left corner has the smallest x + y and the bottom-right the
    largest. Of the remaining two, top-right has the smallest y - x. The
    result is a new float32 (4, 2) array.
    """
    coord_sums = points.sum(axis=1)
    tl_index = np.argmin(coord_sums)
    br_index = np.argmax(coord_sums)

    ordered = np.zeros((4, 2), dtype="float32")
    ordered[0] = points[tl_index]
    ordered[2] = points[br_index]

    others = np.delete(points, (tl_index, br_index), axis=0)
    y_minus_x = np.diff(np.array(others), axis=1)
    ordered[1] = others[np.argmin(y_minus_x)]
    ordered[3] = others[np.argmax(y_minus_x)]
    return ordered
def _clip_det_res(points: np.ndarray, img_height: int, img_width: int) -> np.ndarray:
    """Clamp every (x, y) row of ``points`` into the image, in place, and return ``points``.

    x is clamped to [0, img_width - 1] and y to [0, img_height - 1]. The
    clamped value is then truncated with ``int()``.
    """
    x_limit = img_width - 1
    y_limit = img_height - 1
    for row in points:  # each row is a view, so assignments write back into `points`
        row[0] = int(min(max(row[0], 0), x_limit))
        row[1] = int(min(max(row[1], 0), y_limit))
    return points
def _filter_tag_det_res(dt_boxes: np.ndarray, image_shape: Sequence[int]) -> np.ndarray:
    """Normalise detected quads and drop tiny ones.

    Each quad is copied, reordered to top-left/top-right/bottom-right/bottom-left,
    and clamped into the image given by ``image_shape`` (height, width, ...).
    Quads whose top edge or left edge is 3 px or shorter are discarded.
    Returns a stacked array, or an empty float32 (0, 4, 2) array.
    """
    img_height, img_width = image_shape[0:2]
    survivors = []
    for raw in dt_boxes:
        quad = np.array(raw) if isinstance(raw, list) else raw.copy()
        quad = _clip_det_res(_order_points_clockwise(quad), img_height, img_width)
        top_edge = int(np.linalg.norm(quad[0] - quad[1]))
        left_edge = int(np.linalg.norm(quad[0] - quad[3]))
        if top_edge > 3 and left_edge > 3:
            survivors.append(quad)
    if survivors:
        return np.array(survivors)
    return np.zeros((0, 4, 2), dtype=np.float32)
def _sorted_boxes(dt_boxes: np.ndarray) -> List[np.ndarray]:
    """Order quads for reading: top-to-bottom, then left-to-right within a line.

    Quads are first sorted by their first corner's (y, x). Then, via an
    insertion pass, a quad is moved left past its predecessor when their
    first-corner y values differ by less than 10 px and it lies further left.
    """
    ordered = sorted(dt_boxes, key=lambda quad: (quad[0][1], quad[0][0]))
    for start in range(dt_boxes.shape[0] - 1):
        pos = start
        while (
            pos >= 0
            and abs(ordered[pos + 1][0][1] - ordered[pos][0][1]) < 10
            and ordered[pos + 1][0][0] < ordered[pos][0][0]
        ):
            ordered[pos], ordered[pos + 1] = ordered[pos + 1], ordered[pos]
            pos -= 1
    return ordered
def _det_postprocess(outs_dict: dict, shape_list: np.ndarray) -> List[dict]:
    """Convert the detector output into per-image box dicts.

    ``outs_dict["maps"]`` is indexed as (batch, channel, H, W), and channel 0
    is used. It is binarised at DET_BIN_THRESHOLD. ``shape_list[b][:2]``
    supplies the (height, width) of image ``b``, into which boxes are
    rescaled. Returns one ``{"points": boxes}`` dict per batch item.
    """
    prob_maps = outs_dict["maps"][:, 0, :, :]
    masks = prob_maps > DET_BIN_THRESHOLD
    per_image = []
    for b, prob in enumerate(prob_maps):
        src_h, src_w = int(shape_list[b][0]), int(shape_list[b][1])
        quads, _scores = _boxes_from_bitmap(prob, masks[b], src_w, src_h)
        per_image.append({"points": quads})
    return per_image
def _detect_boxes(arr_bgr: np.ndarray) -> List[np.ndarray]:
    """Detect text quads in a BGR image and return them in reading order (empty list if none).

    The image is stretched to DET_INPUT_SIZE x DET_INPUT_SIZE and fed as raw
    float pixels in NCHW layout. Boxes are mapped back to the original image
    size, filtered, and sorted.
    """
    src_h, src_w = arr_bgr.shape[:2]
    resized = cv2.resize(arr_bgr, (DET_INPUT_SIZE, DET_INPUT_SIZE)).astype(np.float32)
    batch = resized.transpose(2, 0, 1)[np.newaxis]
    # Row layout: [src_h, src_w, ratio_h, ratio_w]; only the first two are read downstream.
    shape_info = np.array(
        [[src_h, src_w, DET_INPUT_SIZE / max(src_h, 1), DET_INPUT_SIZE / max(src_w, 1)]],
        dtype=np.float32,
    )
    session = _get_det_sess()
    input_name = session.get_inputs()[0].name
    prob_map = session.run(None, {input_name: batch})[0]
    candidates = _det_postprocess({"maps": prob_map}, shape_info)[0]["points"]
    quads = _filter_tag_det_res(candidates, arr_bgr.shape)
    return _sorted_boxes(quads) if quads.size else []
def _get_rotate_crop_image(img: np.ndarray, points: np.ndarray) -> np.ndarray:
    """Perspective-warp the quad ``points`` out of ``img`` into an axis-aligned crop.

    ``points`` must hold 4 corners, expected in top-left, top-right,
    bottom-right, bottom-left order. The crop size is the longer of each pair
    of opposite edges. Out-of-image pixels replicate the border. Crops at
    least 1.5x taller than wide are rotated with ``np.rot90``, treating them
    as vertical text.

    Raises:
        ValueError: if ``points`` does not contain exactly 4 corners.
    """
    if len(points) != 4:
        raise ValueError("shape of points must be 4*2")
    top_len = np.linalg.norm(points[0] - points[1])
    bottom_len = np.linalg.norm(points[2] - points[3])
    left_len = np.linalg.norm(points[0] - points[3])
    right_len = np.linalg.norm(points[1] - points[2])
    crop_w = int(max(top_len, bottom_len))
    crop_h = int(max(left_len, right_len))

    target_corners = np.float32([[0, 0], [crop_w, 0], [crop_w, crop_h], [0, crop_h]])
    homography = cv2.getPerspectiveTransform(points.astype(np.float32), target_corners)
    warped = cv2.warpPerspective(
        img,
        homography,
        (crop_w, crop_h),
        borderMode=cv2.BORDER_REPLICATE,
        flags=cv2.INTER_CUBIC,
    )

    out_h, out_w = warped.shape[:2]
    if out_w and out_h and out_h / out_w >= 1.5:
        return np.rot90(warped)
    return warped
def _resize_norm_img(img: np.ndarray, shape: Tuple[int, int, int]) -> np.ndarray:
    """Resize an HWC image to height ``target_h``, keeping aspect ratio, and right-pad with zeros to ``target_w``.

    ``shape`` is ``(channels, target_h, target_w)``. The width is capped at
    ``target_w``. Despite the name, no mean/std normalisation is applied and
    values stay on the input's scale. Returns a float32 array shaped
    ``(channels, target_h, target_w)``.
    """
    channels, out_h, out_w = shape
    src_h, src_w = img.shape[:2]
    aspect = src_w / float(max(src_h, 1))
    new_w = min(out_w, math.ceil(out_h * aspect))
    scaled = cv2.resize(img, (new_w, out_h)).astype(np.float32).transpose((2, 0, 1))
    canvas = np.zeros((channels, out_h, out_w), dtype=np.float32)
    canvas[:, :, :new_w] = scaled
    return canvas
def _classify_orientation(crop_bgr: np.ndarray) -> np.ndarray:
    """Return ``crop_bgr`` rotated 180 degrees if the classifier is confident it is upside down.

    Rotation happens only when class 1 is the argmax and its score exceeds
    CLS_ROTATE_THRESHOLD. Empty crops are returned untouched.
    """
    if crop_bgr.size == 0:
        return crop_bgr
    batch = _resize_norm_img(crop_bgr, (3, *CLS_INPUT_SIZE))[np.newaxis].copy()
    session = _get_cls_sess()
    class_scores = session.run(None, {session.get_inputs()[0].name: batch})[0][0]
    label = int(np.argmax(class_scores))
    flip = label == 1 and float(class_scores[label]) > CLS_ROTATE_THRESHOLD
    return cv2.rotate(crop_bgr, cv2.ROTATE_180) if flip else crop_bgr
def _preprocess_rec(crop_bgr: np.ndarray) -> np.ndarray:
    """Build the 1 x 3 x H x W recognition input for a crop; an empty crop yields an all-zero batch."""
    rec_h, rec_w = REC_INPUT_SIZE
    if crop_bgr.size == 0:
        return np.zeros((1, 3, rec_h, rec_w), dtype=np.float32)
    return _resize_norm_img(crop_bgr, (3, rec_h, rec_w))[np.newaxis].copy()
def _decode_text_indices(
    chars: List[str],
    text_index: np.ndarray,
    text_prob: np.ndarray,
    is_remove_duplicate: bool = False,
) -> List[Tuple[str, float]]:
    """Greedy-decode per-timestep class ids into ``(text, confidence)`` pairs, one per batch row.

    Token 0 (blank) is always dropped. With ``is_remove_duplicate``, consecutive
    repeats of the same id collapse to one first. Ids beyond the vocabulary are
    skipped in the text. Confidence is the mean probability of the kept steps,
    or 0.0 when no step is kept.
    """
    blank_ids = (0,)
    vocab_size = len(chars)
    decoded: List[Tuple[str, float]] = []
    for row, ids in enumerate(text_index):
        keep = np.ones(len(ids), dtype=bool)
        if is_remove_duplicate:
            keep[1:] = ids[1:] != ids[:-1]
        for blank in blank_ids:
            keep &= ids != blank
        text = "".join(chars[i] for i in ids[keep] if i < vocab_size)
        confidences = text_prob[row][keep]
        if len(confidences) == 0:
            confidences = np.array([0.0], dtype=np.float32)
        decoded.append((text, float(np.mean(confidences).tolist())))
    return decoded
def _decode_rec(logits: np.ndarray) -> Tuple[str, float]:
    """Decode recognition-model output for the first batch item into ``(text, confidence)``.

    If ``logits`` is a tuple/list, its last element is used. Scores are
    indexed as (batch, time, class), and the argmax/max over axis 2 feed
    _decode_text_indices with duplicate removal.
    """
    vocab = _get_char_list()
    raw = logits[-1] if isinstance(logits, (tuple, list)) else logits
    scores = raw.astype(np.float32)
    best_ids = scores.argmax(axis=2)
    best_probs = scores.max(axis=2)
    return _decode_text_indices(vocab, best_ids, best_probs, is_remove_duplicate=True)[0]
def _recognize_crop(crop_bgr: np.ndarray) -> Tuple[str, float]:
    """Orientation-correct a BGR text crop, run recognition on it, and return ``(text, confidence)``."""
    rec_input = _preprocess_rec(_classify_orientation(crop_bgr))
    session = _get_rec_sess()
    outputs = session.run(None, {session.get_inputs()[0].name: rec_input})
    return _decode_rec(outputs[0])
def _enhance_crop(crop_bgr: np.ndarray) -> np.ndarray:
    """Return a BGR copy of ``crop_bgr`` with contrast boosted by a factor of 1.5 (via PIL)."""
    as_rgb = crop_bgr[:, :, ::-1]
    boosted = ImageEnhance.Contrast(Image.fromarray(as_rgb)).enhance(1.5)
    boosted_rgb = np.array(boosted)
    return boosted_rgb[:, :, ::-1]
def _quad_to_bbox(points: np.ndarray, width: int, height: int) -> List[int]:
    """Return the integer axis-aligned bounding box ``[x1, y1, x2, y2]`` of ``points``.

    Minimums are floored and limited below by 0. Maximums are ceiled and
    limited above by ``width`` / ``height``.
    """
    xs = points[:, 0]
    ys = points[:, 1]
    left = max(0, int(np.floor(np.min(xs))))
    top = max(0, int(np.floor(np.min(ys))))
    right = min(width, int(np.ceil(np.max(xs))))
    bottom = min(height, int(np.ceil(np.max(ys))))
    return [left, top, right, bottom]
def _iou(a: List[int], b: List[int]) -> float:
    """Overlap of two ``[x1, y1, x2, y2]`` boxes as intersection area / area of the SMALLER box.

    Despite the name, this is not intersection-over-union. A box lying
    entirely inside another scores 1.0, which is what region assignment and
    de-duplication rely on. Non-overlapping boxes score 0.0.
    """
    ax1, ay1, ax2, ay2 = a
    bx1, by1, bx2, by2 = b
    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    inter = overlap_w * overlap_h
    if inter == 0:
        return 0.0
    smaller_area = min((ax2 - ax1) * (ay2 - ay1), (bx2 - bx1) * (by2 - by1))
    return inter / smaller_area
def _clip_bbox(bbox: Sequence[float], width: int, height: int) -> List[int]:
    """Clamp ``[x1, y1, x2, y2]`` into ``[0, width] x [0, height]`` and round each coordinate to int.

    Rounding uses Python's ``round``, so .5 values go to the nearest even integer.
    """
    x1, y1, x2, y2 = bbox
    x_cap, y_cap = float(width), float(height)

    def _fit(value, upper):
        return int(round(max(0.0, min(upper, value))))

    return [_fit(x1, x_cap), _fit(y1, y_cap), _fit(x2, x_cap), _fit(y2, y_cap)]
def _dedupe_layout_regions(regions: List[dict]) -> List[dict]:
    """Drop duplicate layout regions and return the survivors sorted by top-left corner (y, then x).

    Regions are visited from highest to lowest score. A region is a duplicate
    of an already-kept one when three conditions hold: both share the same
    ``type``; for non-figure types, they also share the same ``raw_type``
    (figures of different raw labels may dedupe against each other); and their
    overlap (_iou) is at least LAYOUT_DEDUP_IOU.
    """
    survivors: List[dict] = []
    for cand in sorted(regions, key=lambda r: r["score"], reverse=True):
        is_duplicate = any(
            kept.get("type") == cand.get("type")
            and (cand.get("type") == "figure" or kept.get("raw_type") == cand.get("raw_type"))
            and _iou(kept["bbox"], cand["bbox"]) >= LAYOUT_DEDUP_IOU
            for kept in survivors
        )
        if not is_duplicate:
            survivors.append(cand)
    survivors.sort(key=lambda r: (r["bbox"][1], r["bbox"][0]))
    return survivors
def _run_ocr_full(arr_rgb: np.ndarray, threshold: float) -> List[TextBlock]:
    """Detect and recognise every text line in an RGB image.

    Lines are returned in the detector's reading order. When a line's
    confidence is below ``threshold``, recognition is retried once on a
    contrast-enhanced crop, and the better-scoring result is kept.
    """
    img_h, img_w = arr_rgb.shape[:2]
    bgr = arr_rgb[:, :, ::-1]
    lines: List[TextBlock] = []
    for quad in _detect_boxes(bgr):
        quad_f32 = np.asarray(quad, dtype=np.float32)
        crop = _get_rotate_crop_image(bgr, quad_f32)
        if crop.size == 0:
            continue
        text, score = _recognize_crop(crop)
        if score < threshold:
            alt_text, alt_score = _recognize_crop(_enhance_crop(crop))
            if alt_score > score:
                text, score = alt_text, alt_score
        lines.append(TextBlock(text=text, score=score, bbox=_quad_to_bbox(quad_f32, img_w, img_h)))
    return lines
def _assign_region(block: TextBlock, regions: list) -> tuple:
    """Return ``(region_type, region_index)`` for the layout region that best overlaps ``block``.

    Overlap is measured with _iou (intersection / smaller area), and the
    earliest region wins ties. ``region_type`` is that region's ``type`` if
    its overlap exceeds 0.1, otherwise "unknown". ``region_index`` is the
    best region's index, or ``len(regions)`` when no region overlaps at all.
    Regions without a ``bbox`` are skipped.
    """
    best_type, best_overlap, best_idx = "unknown", 0.0, len(regions)
    for idx, region in enumerate(regions):
        region_bbox = region.get("bbox")
        if region_bbox is None:
            continue
        overlap = _iou(block.bbox, region_bbox)
        if overlap > best_overlap:
            best_type, best_overlap, best_idx = region.get("type", "unknown"), overlap, idx
    chosen_type = best_type if best_overlap > 0.1 else "unknown"
    return chosen_type, best_idx
def _run_structure(arr_rgb: np.ndarray):
    """Run layout analysis on an RGB image and return de-duplicated regions sorted by top-left corner.

    Each region is a dict with these keys:
      - ``type``: coarse type from LAYOUT_TYPE_MAP.
      - ``raw_type``: the model label.
      - ``score``: detection score.
      - ``bbox``: ``[x1, y1, x2, y2]`` in ``arr_rgb`` pixels.
    "figure" regions additionally carry ``img``, a slice (view) of ``arr_rgb``.

    Two export layouts are supported, chosen by the model's input count:
      - >= 2 inputs: image + scale_factor, returning detection rows plus a count.
      - 1 input ("sim-cut"): returning boxes plus per-class scores in input coordinates.

    Raises:
        RuntimeError: if a single-input model returns outputs of unexpected count or shape.
    """
    sess = _get_layout_sess()
    h, w = arr_rgb.shape[:2]
    # NOTE: _resize_no_pad does not apply LAYOUT_MEAN / LAYOUT_STD; the model receives raw 0-255 pixels.
    image_input = _resize_no_pad(arr_rgb, LAYOUT_INPUT_SIZE[0], LAYOUT_INPUT_SIZE[1], LAYOUT_MEAN, LAYOUT_STD)
    def _append_region(regions: List[dict], cls_id: int, score: float, bbox_values: Sequence[float], scale_to_image: bool):
        # Filter one detection and append it to `regions` as a region dict.
        # A detection is dropped when its score is below the threshold, its
        # label is not enabled, or its clipped box is under 4 px in either
        # dimension. bbox_values are in image pixels, or in
        # LAYOUT_INPUT_SIZE coordinates when scale_to_image is True.
        if score < LAYOUT_SCORE_THRESHOLD:
            return
        raw_label = LAYOUT_LABELS[cls_id] if 0 <= cls_id < len(LAYOUT_LABELS) else str(cls_id)
        if raw_label not in LAYOUT_ENABLED_LABELS:
            return
        region_type = LAYOUT_TYPE_MAP.get(raw_label, "unknown")
        x1, y1, x2, y2 = map(float, bbox_values)
        # Normalise corners in case the model emits them swapped.
        if x2 < x1:
            x1, x2 = x2, x1
        if y2 < y1:
            y1, y2 = y2, y1
        if scale_to_image:
            x_scale = w / float(LAYOUT_INPUT_SIZE[1])
            y_scale = h / float(LAYOUT_INPUT_SIZE[0])
            bbox = _clip_bbox([x1 * x_scale, y1 * y_scale, x2 * x_scale, y2 * y_scale], w, h)
        else:
            bbox = _clip_bbox([x1, y1, x2, y2], w, h)
        if bbox[2] - bbox[0] < 4 or bbox[3] - bbox[1] < 4:
            return
        region = {"type": region_type, "raw_type": raw_label, "score": score, "bbox": bbox}
        if region_type == "figure":
            # Keep the figure pixels (a view into arr_rgb) so extract() can return them.
            x1c, y1c, x2c, y2c = bbox
            crop = arr_rgb[max(0, y1c):max(0, y2c), max(0, x1c):max(0, x2c)]
            if crop.size > 0:
                region["img"] = crop
        regions.append(region)
    regions = []
    inputs = sess.get_inputs()
    if len(inputs) >= 2:
        # Original PP-DocLayout-S export: outputs are rows of [cls_id, score, x1, y1, x2, y2] plus a row count.
        # Boxes are treated as already being in original-image pixels (scale_to_image=False).
        scale_factor = np.array([[LAYOUT_INPUT_SIZE[0] / max(h, 1), LAYOUT_INPUT_SIZE[1] / max(w, 1)]], dtype=np.float32)
        outputs = sess.run(None, {inputs[0].name: image_input, inputs[1].name: scale_factor})
        dets, num = outputs
        num = int(num[0])
        for row in dets[:num]:
            _append_region(regions, int(row[0]), float(row[1]), row[2:].tolist(), scale_to_image=False)
    else:
        # Sim-cut export: outputs are boxes[N,4] plus class scores as [C,N] or [N,C],
        # both in the 480x480 input coordinate system.
        outputs = sess.run(None, {inputs[0].name: image_input})
        if len(outputs) != 2:
            raise RuntimeError(f"Unexpected layout outputs: {len(outputs)}")
        out_names = [o.name for o in sess.get_outputs()]
        out_map = {name: value for name, value in zip(out_names, outputs)}
        # Output tensor names of one specific export; otherwise identify the box tensor by its last dim == 4.
        if "Mul.109" in out_map and "Concat.9" in out_map:
            boxes = out_map["Mul.109"][0]
            cls_blob = out_map["Concat.9"]
        else:
            first, second = outputs
            if first.ndim == 3 and first.shape[-1] == 4:
                boxes, cls_blob = first[0], second
            elif second.ndim == 3 and second.shape[-1] == 4:
                boxes, cls_blob = second[0], first
            else:
                raise RuntimeError(f"Unexpected sim-cut output shapes: {first.shape}, {second.shape}")
        if cls_blob.ndim != 3 or cls_blob.shape[0] != 1:
            raise RuntimeError(f"Unexpected cls output shape: {cls_blob.shape}")
        if cls_blob.shape[1] == len(LAYOUT_LABELS) and cls_blob.shape[2] == boxes.shape[0]:
            cls_scores = cls_blob[0].transpose(1, 0) # [N, C]
        elif cls_blob.shape[2] == len(LAYOUT_LABELS) and cls_blob.shape[1] == boxes.shape[0]:
            cls_scores = cls_blob[0] # [N, C]
        else:
            raise RuntimeError(f"Unexpected cls/box alignment: cls={cls_blob.shape}, boxes={boxes.shape}")
        # NOTE(review): raw class scores are compared directly against LAYOUT_SCORE_THRESHOLD
        # (no sigmoid/softmax here) — confirm the export already emits probabilities.
        best_cls = np.argmax(cls_scores, axis=1)
        best_score = np.max(cls_scores, axis=1)
        for i in range(boxes.shape[0]):
            _append_region(
                regions,
                int(best_cls[i]),
                float(best_score[i]),
                boxes[i].tolist(),
                scale_to_image=True,
            )
    return _dedupe_layout_regions(regions)
def _row_key(block: TextBlock) -> Tuple[int, int]:
    """Reading-order key: (20-px band of the block's vertical centre, left edge x1)."""
    return (int((block.bbox[1] + block.bbox[3]) / 40), block.bbox[0])


def _collect_figures(regions: List[dict]) -> Tuple[List[Image.Image], List[List[int]]]:
    """Return PIL images and bboxes for every "figure" region that carries both an image crop and a bbox."""
    images: List[Image.Image] = []
    bboxes: List[List[int]] = []
    for region in regions:
        if region.get("type") != "figure":
            continue
        sub, bbox = region.get("img"), region.get("bbox")
        if sub is not None and bbox is not None:
            images.append(Image.fromarray(sub))
            bboxes.append(bbox)
    return images, bboxes


def _group_region_blocks(blocks: List[TextBlock], regions: List[dict]) -> Tuple[List[TextBlock], List[TextBlock]]:
    """Merge OCR lines per layout region.

    Returns ``(merged, unassigned)``:
      - ``merged`` holds one TextBlock per region that received lines, in
        region (layout) order. Its text is the members joined by spaces in
        row order, its score is the members' mean, and its bbox is their union.
      - ``unassigned`` holds the lines whose region type resolved to "unknown".
    """
    groups: dict = {}
    unassigned: List[TextBlock] = []
    for block in blocks:
        rtype, ridx = _assign_region(block, regions)
        if rtype == "unknown":
            unassigned.append(block)
            continue
        groups.setdefault(ridx, (rtype, []))[1].append(block)

    merged: List[TextBlock] = []
    for ridx in sorted(groups):
        rtype, members = groups[ridx]
        members.sort(key=_row_key)
        merged.append(
            TextBlock(
                text=" ".join(b.text for b in members if b.text),
                score=sum(b.score for b in members) / len(members),
                bbox=[
                    min(b.bbox[0] for b in members),
                    min(b.bbox[1] for b in members),
                    max(b.bbox[2] for b in members),
                    max(b.bbox[3] for b in members),
                ],
                region_type=rtype,
            )
        )
    return merged, unassigned


def _interleave_unknown(ordered: List[TextBlock], unknown: List[TextBlock]) -> List[TextBlock]:
    """Insert unassigned lines into the region sequence by row position.

    ``ordered`` follows layout-region order, where regions are sorted by their
    top-left corner. That order is not guaranteed to be sorted by _row_key,
    which uses the merged text's vertical centre, so ``bisect`` (which
    requires a sorted list) cannot be used. Instead, each line (taken in
    _row_key order) is inserted before the first entry whose key is strictly
    greater than its own. Equal keys therefore keep insertion order.
    """
    result = list(ordered)
    keys = [_row_key(b) for b in result]
    for line in sorted(unknown, key=_row_key):
        key = _row_key(line)
        pos = next((i for i, k in enumerate(keys) if k > key), len(keys))
        result.insert(pos, line)
        keys.insert(pos, key)
    return result


def extract(img: Image.Image, scene_result: SceneResult, debug_prefix: Optional[str] = None) -> OCRResult:
    """Extract text (and, outside screenshots, figure crops) from ``img``.

    For SCREENSHOT scenes, OCR lines are returned in the detector's
    top-to-bottom, left-to-right order, without layout analysis.

    For other scenes, layout analysis runs first. OCR lines overlapping a
    layout region are merged into one block per region, in region order.
    Lines matching no region are interleaved by row position. "figure"
    regions are returned as PIL images with their bboxes.

    Blocks whose stripped text is shorter than 2 characters are dropped.
    ``text_state`` and ``low_confidence`` summarise what remains.

    Args:
        img: input image; converted to RGB.
        scene_result: supplies ``scene`` and ``ocr_threshold``. Lines scoring
            below the threshold get one contrast-enhanced recognition retry.
        debug_prefix: currently unused.
    """
    arr_rgb = np.array(img.convert("RGB"))
    threshold = scene_result.ocr_threshold
    figure_images: List[Image.Image] = []
    figure_bboxes: List[List[int]] = []
    if scene_result.scene == Scene.SCREENSHOT:
        blocks = _run_ocr_full(arr_rgb, threshold)
    else:
        structure_result = _run_structure(arr_rgb)
        blocks = _run_ocr_full(arr_rgb, threshold)
        figure_images, figure_bboxes = _collect_figures(structure_result)
        region_blocks, unknown_blocks = _group_region_blocks(blocks, structure_result)
        blocks = _interleave_unknown(region_blocks, unknown_blocks)

    # Single characters are usually detection noise.
    blocks = [b for b in blocks if len(b.text.strip()) >= 2]
    if not blocks:
        text_state = (
            OCRTextState.MISSING_EXPECTED_TEXT
            if scene_result.scene in {Scene.SCREENSHOT, Scene.DOCUMENT}
            else OCRTextState.MISSING_UNCERTAIN
        )
        return OCRResult(
            blocks=[],
            avg_score=0.0,
            low_confidence=False,
            text_state=text_state,
            figure_images=figure_images,
            figure_bboxes=figure_bboxes,
        )

    avg_score = sum(b.score for b in blocks) / len(blocks)
    # Flag only sparse, low-confidence results; many blocks are treated as usable even at low average score.
    low_confidence = avg_score < 0.65 and len(blocks) < 3
    text_state = OCRTextState.LOW_QUALITY if low_confidence else OCRTextState.OK
    return OCRResult(
        blocks=blocks,
        avg_score=avg_score,
        low_confidence=low_confidence,
        text_state=text_state,
        figure_images=figure_images,
        figure_bboxes=figure_bboxes,
    )