| import os.path |
| from functools import lru_cache |
| from typing import List, Tuple |
|
|
| import cv2 |
| import numpy as np |
| from huggingface_hub import HfApi, HfFileSystem, hf_hub_download |
| from imgutils.data import ImageTyping |
| from imgutils.utils import open_onnx_model |
|
|
# Module-wide Hugging Face Hub handles, created once at import time.
# hf_fs is used to list the models available in the repository; hf_client is kept as a
# general Hub API client (NOTE(review): not referenced in the visible code — confirm before removing).
hf_client, hf_fs = HfApi(), HfFileSystem()
|
|
|
|
@lru_cache()
def _get_available_models():
    """
    List the names of all text-detection models in the ``deepghs/text_detection`` repository.

    A model is any top-level directory of the repository that contains an ``end2end.onnx``
    file; its name is that directory's name.

    The result is a tuple rather than a generator. ``lru_cache`` caches the returned object
    itself, so a cached generator would be exhausted after the first consumer and every
    later call would silently yield no models. A tuple is also immutable, so callers cannot
    corrupt the cached value.

    :return: Tuple of model directory names, in the order returned by ``hf_fs.glob``.
    """
    # HfFileSystem paths are always '/'-separated ('deepghs/text_detection/<model>/end2end.onnx').
    # Take the parent directory name directly. os.path.relpath would use the platform
    # separator (backslashes on Windows), which breaks a subsequent split('/').
    return tuple(
        path.split('/')[-2]
        for path in hf_fs.glob('deepghs/text_detection/*/end2end.onnx')
    )
|
|
|
|
# Every model name found in the repository, resolved once when the module is imported
# (this lists files on the Hugging Face Hub via _get_available_models).
_ALL_MODELS = [*_get_available_models()]
# Model used by detect_text() when the caller does not choose one.
_DEFAULT_MODEL = 'dbnetpp_resnet50_fpnc_1200e_icdar2015'
|
|
|
|
@lru_cache()
def _get_onnx_session(model):
    """
    Fetch ``<model>/end2end.onnx`` from the ``deepghs/text_detection`` repository and open it.

    The file is obtained through ``hf_hub_download``, and the opened model comes from
    ``open_onnx_model``. Results are memoized with ``lru_cache``, so repeated calls with
    the same model name reuse the already-opened session. The cache holds up to 128 entries
    by default.

    :param model: Name of the model directory inside the repository.
    :return: Whatever ``open_onnx_model`` returns for the downloaded file.
    """
    onnx_path = hf_hub_download('deepghs/text_detection', f'{model}/end2end.onnx')
    return open_onnx_model(onnx_path)
|
|
|
|
def _get_heatmap_of_text(image: ImageTyping, model: str) -> np.ndarray:
    """
    Run the text-detection ONNX model on ``image`` and return its output map.

    Pipeline:
    1. Load and convert the image to RGB.
    2. Scale pixels to ``[0, 1]`` as a ``(1, 3, H, W)`` float32 batch.
    3. Zero-pad the bottom/right so both sides become multiples of 32.
    4. Normalize with fixed per-channel mean/std.
    5. Feed the batch to the model input named ``'input'``.
    6. Take the first batch item of the output named ``'output'`` and crop it back to the
       original image size.

    :param image: Any image accepted by ``imgutils.data.load_image`` (path, bytes, PIL image, ...).
        Grayscale, palette and alpha images are converted to RGB first. Previously anything
        other than an RGB PIL image crashed, either on ``.size``/``transpose`` or on
        broadcasting against the 3-channel mean/std.
    :param model: Name of the model directory in ``deepghs/text_detection``.
    :return: The model output for the single batch item, with its first two axes cropped to
        ``(height, width)`` of the input image. Presumably a 2-D per-pixel text score map;
        the caller feeds it to ``cv2.findContours``.
    """
    # Local import keeps this fix self-contained. load_image accepts every ImageTyping
    # variant and, with mode='RGB', yields a 3-channel PIL image.
    from imgutils.data import load_image

    image = load_image(image, mode='RGB')
    origin_width, origin_height = image.size

    # -n % align is the amount that rounds n up to the next multiple of align (0 if already aligned).
    align = 32
    pad_height = -origin_height % align
    pad_width = -origin_width % align

    # HWC uint8 -> CHW float32 in [0, 1], add a batch axis, then zero-pad bottom/right only
    # so the original pixels stay at the top-left and the output can be cropped back.
    input_ = np.array(image).transpose((2, 0, 1)).astype(np.float32) / 255.0
    input_ = np.pad(input_[None, ...], ((0, 0), (0, 0), (0, pad_height), (0, pad_width)))

    # Per-channel normalization, broadcast over the (N, C, H, W) layout.
    # Computed in float64, then cast back to float32 for the model.
    mean = np.asarray((0.48145466, 0.4578275, 0.40821073))[None, :, None, None]
    std = np.asarray((0.26862954, 0.26130258, 0.27577711))[None, :, None, None]
    input_ = ((input_ - mean) / std).astype(np.float32)

    ort = _get_onnx_session(model)
    output_, = ort.run(['output'], {'input': input_})

    # Take the single batch item and drop the padded bottom/right margin.
    heatmap = output_[0]
    return heatmap[:origin_height, :origin_width]
|
|
|
|
def _get_bounding_box_of_text(image: ImageTyping, model: str, threshold: float) \
        -> List[Tuple[Tuple[int, int, int, int], float]]:
    """
    Convert the model heatmap of ``image`` into scored axis-aligned boxes.

    The heatmap is scaled by 255 and cast to ``uint8``. Every external contour of its
    non-zero region is then enclosed in a bounding rectangle. A rectangle is kept when the
    mean heatmap value inside it is at least ``threshold``.

    :param image: Image passed through to ``_get_heatmap_of_text``.
    :param model: Model name passed through to ``_get_heatmap_of_text``.
    :param threshold: Minimum mean heatmap value for a box to be reported.
    :return: ``[((x0, y0, x1, y1), score), ...]``. ``(x1, y1)`` is exclusive, matching
        the slice used to compute ``score``.
    """
    heatmap = _get_heatmap_of_text(image, model)

    # cv2.findContours treats every non-zero pixel as foreground. For heatmap values in
    # [0, 1], anything below 1/255 truncates to 0 here and counts as background.
    mask = (heatmap * 255.0).astype(np.uint8)
    found = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # OpenCV 4.x returns (contours, hierarchy); OpenCV 3.x returns (image, contours, hierarchy).
    if len(found) == 2:
        contours = found[0]
    else:
        contours = found[1]

    results = []
    for contour in contours:
        left, top, box_w, box_h = cv2.boundingRect(contour)
        right = left + box_w
        bottom = top + box_h
        mean_score = heatmap[top:bottom, left:right].mean().item()
        if mean_score >= threshold:
            results.append(((left, top, right, bottom), mean_score))
    return results
|
|
|
|
def detect_text(image: ImageTyping, model: str = _DEFAULT_MODEL, threshold: float = 0.05):
    """
    Detect text regions in an image.

    :param image: The image to scan.
    :param model: Name of the detection model directory in ``deepghs/text_detection``.
        Defaults to ``_DEFAULT_MODEL``.
    :param threshold: Minimum mean heatmap score a region must reach to be reported.
    :return: A list of ``((x0, y0, x1, y1), 'text', score)`` tuples, one per detected region.
        ``(x0, y0)`` is the top-left corner and ``(x1, y1)`` the exclusive bottom-right
        corner, in pixels.
    """
    return [
        ((x0, y0, x1, y1), 'text', score)
        for (x0, y0, x1, y1), score in _get_bounding_box_of_text(image, model, threshold)
    ]
|
|