diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..2403af6c3697ca08cc854b84622906f050620f69 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,73 @@ +# Python bytecode and cache +__pycache__/ +*.pyc +*.pyo +*.pyd +.pytest_cache/ +.coverage + +# IDE/editor files +.idea/ +.vscode/ +*.swp +*.swo + +# Build and distribution outputs +dist/ +build/ +out/ + +# Test and debug artifacts +test/ +tests/ +debug/ +*.log + +# Version control +.git/ +.gitignore +.svn/ +.hg/ + +# Environment and secrets +.env* +*.env +*.pem +*.key +*.crt +config.local.* +*.local.yml + +# Documentation and markdown +README* +*.md +docs/ + +# Docker and compose files +Dockerfile* +docker-compose* + +# Temporary and local files +tmp/ +temp/ +*.tmp +.local/ +local/ + +# Backup files +*.bak + +# Project-specific: results and models (if not needed in build) +results/ + +# Exclude empty model placeholder, but keep actual models +models/.gitkeep + +# Allow ONNX models to be included (remove models/ if you want to exclude all models) +# models/ + +# Miscellaneous +*.DS_Store + +# Exclude self +.dockerignore diff --git a/.dockerignore.bak b/.dockerignore.bak new file mode 100644 index 0000000000000000000000000000000000000000..cce144f07546f82be499c0e5e19dcc3f94307e37 --- /dev/null +++ b/.dockerignore.bak @@ -0,0 +1,68 @@ +# Python bytecode and cache +__pycache__/ +*.pyc +*.pyo +*.pyd +.pytest_cache/ +.coverage + +# Development and IDE artifacts +.idea/ +.vscode/ +*.swp +*.swo + +# Build outputs +build/ +dist/ +out/ + +# Test and debug files +test/ +tests/ +debug/ +*.log + +# Version control +.git/ +.gitignore +.svn/ +.hg/ + +# Environment and secrets +.env* +*.env +*.pem +*.key +*.crt +config.local.* +*.local.yml + +# Documentation +README* +*.md +docs/ + +# Docker files (do not ignore Dockerfile and docker-compose.yml in root) +Dockerfile* +docker-compose* +!Dockerfile +!docker-compose.yml + +# Temporary and local files +tmp/ +temp/ +*.tmp +.local/ +local/ + +# Project-specific: results and model artifacts (keep models/ for .onnx, ignore results/) +results/ + +# Miscellaneous +*.bak +*.orig +*.old + +# Exclude .gitkeep from ignore (if needed for empty dirs) +!models/.gitkeep diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..b75232095104c782080867765c324a0277f02fce 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +tmp/uploaded_1745856378.pdf filter=lfs diff=lfs merge=lfs -text diff --git a/.gradio/certificate.pem b/.gradio/certificate.pem new file mode 100644 index 0000000000000000000000000000000000000000..b85c8037f6b60976b2546fdbae88312c5246d9a3 --- /dev/null +++ b/.gradio/certificate.pem @@ -0,0 +1,31 @@ +-----BEGIN CERTIFICATE----- +MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw +TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh +cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4 +WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu +ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY +MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc +h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+ +0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U +A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW 
+T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH +B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC +B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv +KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn +OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn +jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw +qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI +rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV +HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq +hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL +ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ +3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK +NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5 +ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur +TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC +jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc +oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq +4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA +mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d +emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc= +-----END CERTIFICATE----- diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..3f440e462e7e77aeca2ceca80fa8e44a841b0f0c --- /dev/null +++ b/Dockerfile @@ -0,0 +1,27 @@ +# syntax=docker/dockerfile:1.4 + +# Base image +FROM python:3.12-slim AS base + +# Install system dependencies for OpenCV and ffmpeg at runtime +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + libgl1-mesa-glx \ + libglib2.0-0 \ + libsm6 \ + libxext6 \ + libx11-dev \ + ffmpeg \ + && rm -rf /var/lib/apt/lists/* + +# Builder stage: install Python dependencies +WORKDIR /app +# 4. Copy requirements and install +COPY requirements.txt . +RUN pip install --no-cache-dir --upgrade pip && \ + pip install --default-timeout=100 --no-cache-dir -r requirements.txt +# 5. Copy app source +COPY . . 
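+# Note: the .dockerignore added in this change keeps caches, tests, docs and local secrets out of the copied build context.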
+ +EXPOSE 7860 +CMD ["python", "demo_v5.py"] \ No newline at end of file diff --git a/README.md b/README.md index 8dd7f27d99e963804dd32d6c2601d26042ad7ff7..7f9c502a28c0f44c19a470aa4c80303d728538c4 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,6 @@ --- -title: Rapidocr Ort -emoji: 🏆 -colorFrom: green -colorTo: blue +title: rapidocr_ort +app_file: demo_v5.py sdk: gradio -sdk_version: 5.27.1 -app_file: app.py -pinned: false +sdk_version: 5.27.0 --- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c97862094504f2c31d84ddf4911ac8789e79ff79 --- /dev/null +++ b/__init__.py @@ -0,0 +1,5 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +from .main import RapidOCR +from .utils import LoadImageError, VisRes diff --git a/__pycache__/__init__.cpython-310.pyc b/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63f86d84f345e91387ec363ba21d4b33fb5e16ab Binary files /dev/null and b/__pycache__/__init__.cpython-310.pyc differ diff --git a/__pycache__/image_enhancement.cpython-310.pyc b/__pycache__/image_enhancement.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c85bcb55b3ee8455a9e1ba88d6537365d2bcb6ae Binary files /dev/null and b/__pycache__/image_enhancement.cpython-310.pyc differ diff --git a/__pycache__/image_enhancement.cpython-311.pyc b/__pycache__/image_enhancement.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4285d58ef10ee3bca0e2e68eaa7a7acff25c1742 Binary files /dev/null and b/__pycache__/image_enhancement.cpython-311.pyc differ diff --git a/__pycache__/image_enhancement.cpython-313.pyc b/__pycache__/image_enhancement.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90f3d6ec31cf7499a40d1e611b6113b9ce3c09ef Binary files /dev/null and b/__pycache__/image_enhancement.cpython-313.pyc differ diff --git a/__pycache__/main.cpython-310.pyc b/__pycache__/main.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..680dc0e7e66f929752f63df71bf9bdae324d62dd Binary files /dev/null and b/__pycache__/main.cpython-310.pyc differ diff --git a/__pycache__/main.cpython-311.pyc b/__pycache__/main.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75c3fdfb20d1cfec7d73c335ebd3c42ea636448d Binary files /dev/null and b/__pycache__/main.cpython-311.pyc differ diff --git a/__pycache__/main.cpython-313.pyc b/__pycache__/main.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8508a07447682018edabd5b0c9907f3a2aab8fa9 Binary files /dev/null and b/__pycache__/main.cpython-313.pyc differ diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..a7f460186d219720f876e28fcea1b24ea76e6330 --- /dev/null +++ b/app.py @@ -0,0 +1,29 @@ +import gradio as gr +import numpy as np +import cv2 +from pathlib import Path + +from main import RapidOCR + + +ocr_engine = RapidOCR() + +def extract_text_from_bottom(image: np.ndarray): + h = image.shape[0] +# bottom_crop = image[int(h * 0.7):, :] + result, _ = ocr_engine(image, use_det=True, use_cls=False, use_rec=True) + if not result: + return "No text found." 
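+    # Each result entry is a (box, text, score) triple; the lines below keep only the text.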
+ + texts = [r[1] for r in result] + return "\n".join(texts) + +demo = gr.Interface( + fn=extract_text_from_bottom, + inputs=gr.Image(type="numpy"), + outputs="text", + title="", + description="", +) + +demo.launch() diff --git a/cal_rec_boxes/__init__.py b/cal_rec_boxes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5127715f9d0c134b975769fa2135eee450adeba8 --- /dev/null +++ b/cal_rec_boxes/__init__.py @@ -0,0 +1,4 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +from .main import CalRecBoxes diff --git a/cal_rec_boxes/__pycache__/__init__.cpython-310.pyc b/cal_rec_boxes/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b20584fcee36ae53110e67872892080c760a5261 Binary files /dev/null and b/cal_rec_boxes/__pycache__/__init__.cpython-310.pyc differ diff --git a/cal_rec_boxes/__pycache__/__init__.cpython-311.pyc b/cal_rec_boxes/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebc9df819752f5a75a54ee922323186a50122e74 Binary files /dev/null and b/cal_rec_boxes/__pycache__/__init__.cpython-311.pyc differ diff --git a/cal_rec_boxes/__pycache__/__init__.cpython-312.pyc b/cal_rec_boxes/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23a6e5489912e3e062f1a68dbb42678021b3b7b7 Binary files /dev/null and b/cal_rec_boxes/__pycache__/__init__.cpython-312.pyc differ diff --git a/cal_rec_boxes/__pycache__/__init__.cpython-313.pyc b/cal_rec_boxes/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e044188f026082f26139851da33cac53f976bf60 Binary files /dev/null and b/cal_rec_boxes/__pycache__/__init__.cpython-313.pyc differ diff --git a/cal_rec_boxes/__pycache__/main.cpython-310.pyc b/cal_rec_boxes/__pycache__/main.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e17f9b710c525248cafe2984967d1dd01c3a0e62 Binary files /dev/null and b/cal_rec_boxes/__pycache__/main.cpython-310.pyc differ diff --git a/cal_rec_boxes/__pycache__/main.cpython-311.pyc b/cal_rec_boxes/__pycache__/main.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e300d0072b28fec43a008bc9b311d322251034fd Binary files /dev/null and b/cal_rec_boxes/__pycache__/main.cpython-311.pyc differ diff --git a/cal_rec_boxes/__pycache__/main.cpython-312.pyc b/cal_rec_boxes/__pycache__/main.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7390b0e9ec3cc7b7f8eec03dd5428b970ce4d39 Binary files /dev/null and b/cal_rec_boxes/__pycache__/main.cpython-312.pyc differ diff --git a/cal_rec_boxes/__pycache__/main.cpython-313.pyc b/cal_rec_boxes/__pycache__/main.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbbb4ba850f01927fd6f71ce6fadf717ada13196 Binary files /dev/null and b/cal_rec_boxes/__pycache__/main.cpython-313.pyc differ diff --git a/cal_rec_boxes/main.py b/cal_rec_boxes/main.py new file mode 100644 index 0000000000000000000000000000000000000000..c8ea5da4c74ddbe95c9cd51ca227dc407428ab65 --- /dev/null +++ b/cal_rec_boxes/main.py @@ -0,0 +1,281 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL / Joker1212 +# @Contact: liekkaskono@163.com +import copy +import math +from typing import Any, List, Optional, Tuple + +import cv2 +import numpy as np + + +class CalRecBoxes: + 
"""计算识别文字的汉字单字和英文单词的坐标框。代码借鉴自PaddlePaddle/PaddleOCR和fanqie03/char-detection""" + + def __init__(self): + pass + + def __call__( + self, + imgs: Optional[List[np.ndarray]], + dt_boxes: Optional[List[np.ndarray]], + rec_res: Optional[List[Any]], + ): + res = [] + for img, box, rec_res in zip(imgs, dt_boxes, rec_res): + direction = self.get_box_direction(box) + + rec_txt, rec_conf, rec_word_info = rec_res[0], rec_res[1], rec_res[2] + h, w = img.shape[:2] + img_box = np.array([[0, 0], [w, 0], [w, h], [0, h]]) + word_box_content_list, word_box_list, conf_list = self.cal_ocr_word_box( + rec_txt, img_box, rec_word_info + ) + word_box_list = self.adjust_box_overlap(copy.deepcopy(word_box_list)) + word_box_list = self.reverse_rotate_crop_image( + copy.deepcopy(box), word_box_list, direction + ) + res.append( + [rec_txt, rec_conf, word_box_list, word_box_content_list, conf_list] + ) + return res + + @staticmethod + def get_box_direction(box: np.ndarray) -> str: + direction = "w" + img_crop_width = int( + max( + np.linalg.norm(box[0] - box[1]), + np.linalg.norm(box[2] - box[3]), + ) + ) + img_crop_height = int( + max( + np.linalg.norm(box[0] - box[3]), + np.linalg.norm(box[1] - box[2]), + ) + ) + if img_crop_height * 1.0 / img_crop_width >= 1.5: + direction = "h" + return direction + + @staticmethod + def cal_ocr_word_box( + rec_txt: str, box: np.ndarray, rec_word_info: List[Tuple[str, List[int]]] + ) -> Tuple[List[str], List[List[int]], List[float]]: + """Calculate the detection frame for each word based on the results of recognition and detection of ocr + 汉字坐标是单字的 + 英语坐标是单词级别的 + """ + + col_num, word_list, word_col_list, state_list, conf_list = rec_word_info + box = box.tolist() + bbox_x_start = box[0][0] + bbox_x_end = box[1][0] + bbox_y_start = box[0][1] + bbox_y_end = box[2][1] + + cell_width = (bbox_x_end - bbox_x_start) / col_num + word_box_list = [] + word_box_content_list = [] + cn_width_list = [] + en_width_list = [] + cn_col_list = [] + en_col_list = [] + + def cal_char_width(width_list, word_col_): + if len(word_col_) == 1: + return + char_total_length = (word_col_[-1] - word_col_[0]) * cell_width + char_width = char_total_length / (len(word_col_) - 1) + width_list.append(char_width) + + def cal_box(col_list, width_list, word_box_list_): + if len(col_list) == 0: + return + if len(width_list) != 0: + avg_char_width = np.mean(width_list) + else: + avg_char_width = (bbox_x_end - bbox_x_start) / len(rec_txt) + + for center_idx in col_list: + center_x = (center_idx + 0.5) * cell_width + cell_x_start = max(int(center_x - avg_char_width / 2), 0) + bbox_x_start + cell_x_end = ( + min(int(center_x + avg_char_width / 2), bbox_x_end - bbox_x_start) + + bbox_x_start + ) + cell = [ + [cell_x_start, bbox_y_start], + [cell_x_end, bbox_y_start], + [cell_x_end, bbox_y_end], + [cell_x_start, bbox_y_end], + ] + word_box_list_.append(cell) + + for word, word_col, state in zip(word_list, word_col_list, state_list): + if state == "cn": + cal_char_width(cn_width_list, word_col) + cn_col_list += word_col + word_box_content_list += word + else: + cal_char_width(en_width_list, word_col) + en_col_list += word_col + word_box_content_list += word + + cal_box(cn_col_list, cn_width_list, word_box_list) + cal_box(en_col_list, en_width_list, word_box_list) + sorted_word_box_list = sorted(word_box_list, key=lambda box: box[0][0]) + return word_box_content_list, sorted_word_box_list, conf_list + + @staticmethod + def adjust_box_overlap( + word_box_list: List[List[List[int]]], + ) -> List[List[List[int]]]: + # 调整bbox有重叠的地方 
+        for i in range(len(word_box_list) - 1):
+            cur, nxt = word_box_list[i], word_box_list[i + 1]
+            if cur[1][0] > nxt[0][0]:  # the two boxes overlap
+                distance = abs(cur[1][0] - nxt[0][0])
+                cur[1][0] -= distance / 2
+                cur[2][0] -= distance / 2
+                nxt[0][0] += distance - distance / 2
+                nxt[3][0] += distance - distance / 2
+        return word_box_list
+
+    def reverse_rotate_crop_image(
+        self,
+        bbox_points: np.ndarray,
+        word_points_list: List[List[List[int]]],
+        direction: str = "w",
+    ) -> List[List[List[int]]]:
+        """
+        Inverse operation of get_rotate_crop_image.
+        img is the original image; part_img is the cropped patch.
+        bbox_points is the bbox of part_img within the original image: four points, top-left, top-right, bottom-right, bottom-left.
+        part_points are points inside part_img, as [(x, y), (x, y)].
+        """
+        bbox_points = np.float32(bbox_points)
+
+        left = int(np.min(bbox_points[:, 0]))
+        top = int(np.min(bbox_points[:, 1]))
+        bbox_points[:, 0] = bbox_points[:, 0] - left
+        bbox_points[:, 1] = bbox_points[:, 1] - top
+
+        img_crop_width = int(np.linalg.norm(bbox_points[0] - bbox_points[1]))
+        img_crop_height = int(np.linalg.norm(bbox_points[0] - bbox_points[3]))
+
+        pts_std = np.array(
+            [
+                [0, 0],
+                [img_crop_width, 0],
+                [img_crop_width, img_crop_height],
+                [0, img_crop_height],
+            ]
+        ).astype(np.float32)
+        M = cv2.getPerspectiveTransform(bbox_points, pts_std)
+        _, IM = cv2.invert(M)
+
+        new_word_points_list = []
+        for word_points in word_points_list:
+            new_word_points = []
+            for point in word_points:
+                new_point = point
+                if direction == "h":
+                    new_point = self.s_rotate(
+                        math.radians(-90), new_point[0], new_point[1], 0, 0
+                    )
+                    new_point[0] = new_point[0] + img_crop_width
+
+                p = np.float32(new_point + [1])
+                x, y, z = np.dot(IM, p)
+                new_point = [x / z, y / z]
+
+                new_point = [int(new_point[0] + left), int(new_point[1] + top)]
+                new_word_points.append(new_point)
+            new_word_points = self.order_points(new_word_points)
+            new_word_points_list.append(new_word_points)
+        return new_word_points_list
+
+    @staticmethod
+    def s_rotate(angle, valuex, valuey, pointx, pointy):
+        """Rotate (valuex, valuey) clockwise around (pointx, pointy).
+        https://blog.csdn.net/qq_38826019/article/details/84233397
+        """
+        valuex = np.array(valuex)
+        valuey = np.array(valuey)
+        sRotatex = (
+            (valuex - pointx) * math.cos(angle)
+            + (valuey - pointy) * math.sin(angle)
+            + pointx
+        )
+        sRotatey = (
+            (valuey - pointy) * math.cos(angle)
+            - (valuex - pointx) * math.sin(angle)
+            + pointy
+        )
+        return [sRotatex, sRotatey]
+
+    @staticmethod
+    def order_points(box: List[List[int]]) -> List[List[int]]:
+        """Order the corner points of a rectangular box."""
+
+        def convert_to_1x2(p):
+            if p.shape == (2,):
+                return p.reshape((1, 2))
+            elif p.shape == (1, 2):
+                return p
+            else:
+                return p[:1, :]
+
+        box = np.array(box).reshape((-1, 2))
+        center_x, center_y = np.mean(box[:, 0]), np.mean(box[:, 1])
+        if np.any(box[:, 0] == center_x) and np.any(
+            box[:, 1] == center_y
+        ):  # two points share an x and two share a y: a rhombus
+            p1 = box[np.where(box[:, 0] == np.min(box[:, 0]))]
+            p2 = box[np.where(box[:, 1] == np.min(box[:, 1]))]
+            p3 = box[np.where(box[:, 0] == np.max(box[:, 0]))]
+            p4 = box[np.where(box[:, 1] == np.max(box[:, 1]))]
+        elif np.all(box[:, 0] == center_x):  # all four points share the same x
+            y_sort = np.argsort(box[:, 1])
+            p1 = box[y_sort[0]]
+            p2 = box[y_sort[1]]
+            p3 = box[y_sort[2]]
+            p4 = box[y_sort[3]]
+        elif np.any(box[:, 0] == center_x) and np.all(
+            box[:, 1] != center_y
+        ):  # only two points share an x: order top/bottom first, then left/right
+            p12, p34 = (
+                box[np.where(box[:, 1] < center_y)],
+                box[np.where(box[:, 1] > center_y)],
+            )
+            p1, p2 = (
+                p12[np.where(p12[:, 0] == np.min(p12[:, 0]))],
+                p12[np.where(p12[:, 0] == np.max(p12[:, 0]))],
+            )
+            p3, p4 = (
+                p34[np.where(p34[:, 0] == np.max(p34[:, 0]))],
+                p34[np.where(p34[:, 0] == np.min(p34[:, 0]))],
+            )
+        else:  # only two points share a y, or no coordinates coincide: order left/right first, then top/bottom
+            p14, p23 = (
+                box[np.where(box[:, 0] < center_x)],
+                box[np.where(box[:, 0] > center_x)],
+            )
+            p1, p4 = (
+                p14[np.where(p14[:, 1] == np.min(p14[:, 1]))],
+                p14[np.where(p14[:, 1] == np.max(p14[:, 1]))],
+            )
+            p2, p3 = (
+                p23[np.where(p23[:, 1] == np.min(p23[:, 1]))],
+                p23[np.where(p23[:, 1] == np.max(p23[:, 1]))],
+            )
+
+        # Avoid shape errors when single-character crops share identical x coordinates
+        p1 = convert_to_1x2(p1)
+        p2 = convert_to_1x2(p2)
+        p3 = convert_to_1x2(p3)
+        p4 = convert_to_1x2(p4)
+        return np.array([p1, p2, p3, p4]).reshape((-1, 2)).tolist()
diff --git a/ch_ppocr_cls/__init__.py b/ch_ppocr_cls/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fc854c61ce89812032d3f45eb689d04844b0020
--- /dev/null
+++ b/ch_ppocr_cls/__init__.py
@@ -0,0 +1,4 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+from .text_cls import TextClassifier
diff --git a/ch_ppocr_cls/__pycache__/__init__.cpython-310.pyc b/ch_ppocr_cls/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..10db01d729d7ad56731437d799c18b8feab57d36
Binary files /dev/null and b/ch_ppocr_cls/__pycache__/__init__.cpython-310.pyc differ
diff --git a/ch_ppocr_cls/__pycache__/__init__.cpython-311.pyc b/ch_ppocr_cls/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c7d8477d406312c52d0e6bd59133ed2887ff2a14
Binary files /dev/null and b/ch_ppocr_cls/__pycache__/__init__.cpython-311.pyc differ
diff --git a/ch_ppocr_cls/__pycache__/__init__.cpython-312.pyc b/ch_ppocr_cls/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c1522a0955baa27baf6dd31d6bc93e33608bdd74
Binary files /dev/null and b/ch_ppocr_cls/__pycache__/__init__.cpython-312.pyc differ
diff --git a/ch_ppocr_cls/__pycache__/__init__.cpython-313.pyc b/ch_ppocr_cls/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f8ef5b6ee8ca02e8070e057b27082694b2598ef6
Binary files /dev/null and b/ch_ppocr_cls/__pycache__/__init__.cpython-313.pyc differ
diff --git a/ch_ppocr_cls/__pycache__/text_cls.cpython-310.pyc b/ch_ppocr_cls/__pycache__/text_cls.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..980a8fe0feb49ab8526fbc9505a704932227e6ca
Binary files /dev/null and b/ch_ppocr_cls/__pycache__/text_cls.cpython-310.pyc differ
diff --git a/ch_ppocr_cls/__pycache__/text_cls.cpython-311.pyc b/ch_ppocr_cls/__pycache__/text_cls.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fede6855fd360532045278b4b63afdb86d1b70b7
Binary files /dev/null and b/ch_ppocr_cls/__pycache__/text_cls.cpython-311.pyc differ
diff --git a/ch_ppocr_cls/__pycache__/text_cls.cpython-312.pyc b/ch_ppocr_cls/__pycache__/text_cls.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..731d4b0d7d1581675bf061ee8c8beba3b0a1c8ae
Binary files /dev/null and b/ch_ppocr_cls/__pycache__/text_cls.cpython-312.pyc differ
diff --git a/ch_ppocr_cls/__pycache__/text_cls.cpython-313.pyc b/ch_ppocr_cls/__pycache__/text_cls.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6bbcfc23a96554d12396d8a0fca36bfbd3d2cbd7
Binary files /dev/null and b/ch_ppocr_cls/__pycache__/text_cls.cpython-313.pyc differ
diff --git a/ch_ppocr_cls/__pycache__/utils.cpython-310.pyc b/ch_ppocr_cls/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e9b9d9ef1ceb233abbbabc7f4c7d0bc7797c910b Binary files /dev/null and b/ch_ppocr_cls/__pycache__/utils.cpython-310.pyc differ diff --git a/ch_ppocr_cls/__pycache__/utils.cpython-311.pyc b/ch_ppocr_cls/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9a0cc4bb0dff3478a82c98ae8edb9e827fcefbf Binary files /dev/null and b/ch_ppocr_cls/__pycache__/utils.cpython-311.pyc differ diff --git a/ch_ppocr_cls/__pycache__/utils.cpython-312.pyc b/ch_ppocr_cls/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c9da2ca636e3c59d677e700197b0cfc626b86a7 Binary files /dev/null and b/ch_ppocr_cls/__pycache__/utils.cpython-312.pyc differ diff --git a/ch_ppocr_cls/__pycache__/utils.cpython-313.pyc b/ch_ppocr_cls/__pycache__/utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e141494a8c4ef25a79bf1f11a759cf25264e0fe Binary files /dev/null and b/ch_ppocr_cls/__pycache__/utils.cpython-313.pyc differ diff --git a/ch_ppocr_cls/text_cls.py b/ch_ppocr_cls/text_cls.py new file mode 100644 index 0000000000000000000000000000000000000000..bc974c31e204e80cdcc029d024a8879ded23ab04 --- /dev/null +++ b/ch_ppocr_cls/text_cls.py @@ -0,0 +1,114 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
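+# Text direction classifier: batches cropped text lines through an ONNX model and
+# rotates any line predicted as "180" back upright (see the rotate call below).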
+import argparse +import copy +import math +import time +from typing import Any, Dict, List, Tuple, Union + +import cv2 +import numpy as np +from rapidocr_onnxruntime.utils import OrtInferSession, read_yaml + +from .utils import ClsPostProcess + + +class TextClassifier: + def __init__(self, config: Dict[str, Any]): + self.cls_image_shape = config["cls_image_shape"] + self.cls_batch_num = config["cls_batch_num"] + self.cls_thresh = config["cls_thresh"] + self.postprocess_op = ClsPostProcess(config["label_list"]) + + self.infer = OrtInferSession(config) + + def __call__( + self, img_list: Union[np.ndarray, List[np.ndarray]] + ) -> Tuple[List[np.ndarray], List[List[Union[str, float]]], float]: + if isinstance(img_list, np.ndarray): + img_list = [img_list] + + img_list = copy.deepcopy(img_list) + + # Calculate the aspect ratio of all text bars + width_list = [img.shape[1] / float(img.shape[0]) for img in img_list] + + # Sorting can speed up the cls process + indices = np.argsort(np.array(width_list)) + + img_num = len(img_list) + cls_res = [["", 0.0]] * img_num + batch_num = self.cls_batch_num + elapse = 0 + for beg_img_no in range(0, img_num, batch_num): + end_img_no = min(img_num, beg_img_no + batch_num) + + norm_img_batch = [] + for ino in range(beg_img_no, end_img_no): + norm_img = self.resize_norm_img(img_list[indices[ino]]) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + norm_img_batch = np.concatenate(norm_img_batch).astype(np.float32) + + starttime = time.time() + prob_out = self.infer(norm_img_batch)[0] + cls_result = self.postprocess_op(prob_out) + elapse += time.time() - starttime + + for rno, (label, score) in enumerate(cls_result): + cls_res[indices[beg_img_no + rno]] = [label, score] + if "180" in label and score > self.cls_thresh: + img_list[indices[beg_img_no + rno]] = cv2.rotate( + img_list[indices[beg_img_no + rno]], 1 + ) + return img_list, cls_res, elapse + + def resize_norm_img(self, img: np.ndarray) -> np.ndarray: + img_c, img_h, img_w = self.cls_image_shape + h, w = img.shape[:2] + ratio = w / float(h) + if math.ceil(img_h * ratio) > img_w: + resized_w = img_w + else: + resized_w = int(math.ceil(img_h * ratio)) + + resized_image = cv2.resize(img, (resized_w, img_h)) + resized_image = resized_image.astype("float32") + if img_c == 1: + resized_image = resized_image / 255 + resized_image = resized_image[np.newaxis, :] + else: + resized_image = resized_image.transpose((2, 0, 1)) / 255 + + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = np.zeros((img_c, img_h, img_w), dtype=np.float32) + padding_im[:, :, :resized_w] = resized_image + return padding_im + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--image_path", type=str, help="image_dir|image_path") + parser.add_argument("--config_path", type=str, default="config.yaml") + args = parser.parse_args() + + config = read_yaml(args.config_path) + + text_classifier = TextClassifier(config) + + img = cv2.imread(args.image_path) + img_list, cls_res, predict_time = text_classifier(img) + for ino in range(len(img_list)): + print(f"cls result:{cls_res[ino]}") diff --git a/ch_ppocr_cls/utils.py b/ch_ppocr_cls/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6549fcb2b09b661a788ef839a9dc5d8fcab040ad --- /dev/null +++ b/ch_ppocr_cls/utils.py @@ -0,0 +1,28 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Tuple + +import numpy as np + + +class ClsPostProcess: + def __init__(self, label_list: List[str]): + self.label_list = label_list + + def __call__(self, preds: np.ndarray) -> List[Tuple[str, float]]: + pred_idxs = preds.argmax(axis=1) + decode_out = [ + (self.label_list[idx], preds[i, idx]) for i, idx in enumerate(pred_idxs) + ] + return decode_out diff --git a/ch_ppocr_det/__init__.py b/ch_ppocr_det/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..502f4a3dfddb6814b88c761e2d73dc49d76f11c3 --- /dev/null +++ b/ch_ppocr_det/__init__.py @@ -0,0 +1,4 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +from .text_detect import TextDetector diff --git a/ch_ppocr_det/__pycache__/__init__.cpython-310.pyc b/ch_ppocr_det/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e177012080b1e6ce9971aed15a6dbbd953df9df Binary files /dev/null and b/ch_ppocr_det/__pycache__/__init__.cpython-310.pyc differ diff --git a/ch_ppocr_det/__pycache__/__init__.cpython-311.pyc b/ch_ppocr_det/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e8925b912d773751463046fe57240376a9c76d5 Binary files /dev/null and b/ch_ppocr_det/__pycache__/__init__.cpython-311.pyc differ diff --git a/ch_ppocr_det/__pycache__/__init__.cpython-312.pyc b/ch_ppocr_det/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d41a449f89722250630406dca4af2359b44a05f9 Binary files /dev/null and b/ch_ppocr_det/__pycache__/__init__.cpython-312.pyc differ diff --git a/ch_ppocr_det/__pycache__/__init__.cpython-313.pyc b/ch_ppocr_det/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e87fd535096731b627cc206c11e997a5043400e Binary files /dev/null and b/ch_ppocr_det/__pycache__/__init__.cpython-313.pyc differ diff --git a/ch_ppocr_det/__pycache__/text_detect.cpython-310.pyc b/ch_ppocr_det/__pycache__/text_detect.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab5c1eb093d5b49a0b45fc3bc887cabea7a5d7ab Binary files /dev/null and b/ch_ppocr_det/__pycache__/text_detect.cpython-310.pyc differ diff --git a/ch_ppocr_det/__pycache__/text_detect.cpython-311.pyc b/ch_ppocr_det/__pycache__/text_detect.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a801e07682f17eab7f5f3572e877f9a4052fcc49 Binary files /dev/null and b/ch_ppocr_det/__pycache__/text_detect.cpython-311.pyc differ diff --git a/ch_ppocr_det/__pycache__/text_detect.cpython-312.pyc b/ch_ppocr_det/__pycache__/text_detect.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b586bec385278be3b126458cfa75ae21937ee8c Binary files /dev/null and b/ch_ppocr_det/__pycache__/text_detect.cpython-312.pyc differ diff --git a/ch_ppocr_det/__pycache__/text_detect.cpython-313.pyc 
b/ch_ppocr_det/__pycache__/text_detect.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad33d61df8e5730b323038606523baba091f45d6 Binary files /dev/null and b/ch_ppocr_det/__pycache__/text_detect.cpython-313.pyc differ diff --git a/ch_ppocr_det/__pycache__/utils.cpython-310.pyc b/ch_ppocr_det/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10a7ad5709149e166f42868c70289334cd5dcccf Binary files /dev/null and b/ch_ppocr_det/__pycache__/utils.cpython-310.pyc differ diff --git a/ch_ppocr_det/__pycache__/utils.cpython-311.pyc b/ch_ppocr_det/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c45009a00457ae589b1fb7757754d87aa1625140 Binary files /dev/null and b/ch_ppocr_det/__pycache__/utils.cpython-311.pyc differ diff --git a/ch_ppocr_det/__pycache__/utils.cpython-312.pyc b/ch_ppocr_det/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e5ffc4464382f63197fa72bf49eced6d8cac54a Binary files /dev/null and b/ch_ppocr_det/__pycache__/utils.cpython-312.pyc differ diff --git a/ch_ppocr_det/__pycache__/utils.cpython-313.pyc b/ch_ppocr_det/__pycache__/utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c715051f15fe2dd426e8fc581a4d5d7d098f3658 Binary files /dev/null and b/ch_ppocr_det/__pycache__/utils.cpython-313.pyc differ diff --git a/ch_ppocr_det/text_detect.py b/ch_ppocr_det/text_detect.py new file mode 100644 index 0000000000000000000000000000000000000000..5632d79084b6331d3ac274b8104b739f52ca248d --- /dev/null +++ b/ch_ppocr_det/text_detect.py @@ -0,0 +1,124 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
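+# Text detector built on DB (Differentiable Binarization): preprocess the image,
+# run the ONNX model, then turn the probability map into quadrilateral text boxes
+# via DBPostProcess (see utils.py).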
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+import time
+from typing import Any, Dict, Optional, Tuple
+
+import numpy as np
+
+from rapidocr_onnxruntime.utils import OrtInferSession
+
+from .utils import DBPostProcess, DetPreProcess
+
+
+class TextDetector:
+    def __init__(self, config: Dict[str, Any]):
+        self.limit_side_len = config.get("limit_side_len")
+        self.limit_type = config.get("limit_type")
+        self.mean = config.get("mean")
+        self.std = config.get("std")
+        self.preprocess_op = None
+
+        post_process = {
+            "thresh": config.get("thresh", 0.3),
+            "box_thresh": config.get("box_thresh", 0.5),
+            "max_candidates": config.get("max_candidates", 1000),
+            "unclip_ratio": config.get("unclip_ratio", 1.6),
+            "use_dilation": config.get("use_dilation", True),
+            "score_mode": config.get("score_mode", "fast"),
+        }
+        self.postprocess_op = DBPostProcess(**post_process)
+
+        self.infer = OrtInferSession(config)
+
+    def __call__(self, img: np.ndarray) -> Tuple[Optional[np.ndarray], float]:
+        start_time = time.perf_counter()
+
+        if img is None:
+            raise ValueError("img is None")
+
+        ori_img_shape = img.shape[0], img.shape[1]
+        self.preprocess_op = self.get_preprocess(max(img.shape[0], img.shape[1]))
+        prepro_img = self.preprocess_op(img)
+        if prepro_img is None:
+            return None, 0
+
+        preds = self.infer(prepro_img)[0]
+        dt_boxes, dt_boxes_scores = self.postprocess_op(preds, ori_img_shape)
+        dt_boxes = self.filter_tag_det_res(dt_boxes, ori_img_shape)
+        elapse = time.perf_counter() - start_time
+        return dt_boxes, elapse
+
+    def get_preprocess(self, max_wh):
+        if self.limit_type == "min":
+            limit_side_len = self.limit_side_len
+        elif max_wh < 960:
+            limit_side_len = 960
+        elif max_wh < 1500:
+            limit_side_len = 1500
+        else:
+            limit_side_len = 2000
+        return DetPreProcess(limit_side_len, self.limit_type, self.mean, self.std)
+
+    def filter_tag_det_res(
+        self, dt_boxes: np.ndarray, image_shape: Tuple[int, int]
+    ) -> np.ndarray:
+        img_height, img_width = image_shape
+        dt_boxes_new = []
+        for box in dt_boxes:
+            box = self.order_points_clockwise(box)
+            box = self.clip_det_res(box, img_height, img_width)
+
+            rect_width = int(np.linalg.norm(box[0] - box[1]))
+            rect_height = int(np.linalg.norm(box[0] - box[3]))
+            if rect_width <= 3 or rect_height <= 3:
+                continue
+
+            dt_boxes_new.append(box)
+        return np.array(dt_boxes_new)
+
+    def order_points_clockwise(self, pts: np.ndarray) -> np.ndarray:
+        """
+        reference from:
+        https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
+        sort the points based on their x-coordinates
+        """
+        xSorted = pts[np.argsort(pts[:, 0]), :]
+
+        # grab the left-most and right-most points from the sorted
+        # x-coordinate points
+        leftMost = xSorted[:2, :]
+        rightMost = xSorted[2:, :]
+
+        # now, sort the left-most coordinates according to their
+        # y-coordinates so we can grab the top-left and bottom-left
+        # points, respectively
+        leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
+        (tl, bl) = leftMost
+
+        rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
+        (tr, br) = rightMost
+
+        rect = np.array([tl, tr, br, bl], dtype="float32")
+        return rect
+
+    def clip_det_res(
+        self, points: np.ndarray, img_height: int, img_width: int
+    ) -> np.ndarray:
+        for pno in range(points.shape[0]):
+            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
+            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
+        return points
diff --git a/ch_ppocr_det/utils.py b/ch_ppocr_det/utils.py
new file mode 100644 index
0000000000000000000000000000000000000000..ef410631c7e16e279bb7f15e30e71d1ee2fd6937 --- /dev/null +++ b/ch_ppocr_det/utils.py @@ -0,0 +1,237 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +from typing import List, Optional, Tuple + +import cv2 +import numpy as np +import pyclipper +from shapely.geometry import Polygon + + +class DetPreProcess: + def __init__( + self, limit_side_len: int = 736, limit_type: str = "min", mean=None, std=None + ): + if mean is None: + mean = [0.5, 0.5, 0.5] + + if std is None: + std = [0.5, 0.5, 0.5] + + self.mean = np.array(mean) + self.std = np.array(std) + self.scale = 1 / 255.0 + + self.limit_side_len = limit_side_len + self.limit_type = limit_type + + def __call__(self, img: np.ndarray) -> Optional[np.ndarray]: + resized_img = self.resize(img) + if resized_img is None: + return None + + img = self.normalize(resized_img) + img = self.permute(img) + img = np.expand_dims(img, axis=0).astype(np.float32) + return img + + def normalize(self, img: np.ndarray) -> np.ndarray: + return (img.astype("float32") * self.scale - self.mean) / self.std + + def permute(self, img: np.ndarray) -> np.ndarray: + return img.transpose((2, 0, 1)) + + def resize(self, img: np.ndarray) -> Optional[np.ndarray]: + """resize image to a size multiple of 32 which is required by the network""" + h, w = img.shape[:2] + + if self.limit_type == "max": + if max(h, w) > self.limit_side_len: + if h > w: + ratio = float(self.limit_side_len) / h + else: + ratio = float(self.limit_side_len) / w + else: + ratio = 1.0 + else: + if min(h, w) < self.limit_side_len: + if h < w: + ratio = float(self.limit_side_len) / h + else: + ratio = float(self.limit_side_len) / w + else: + ratio = 1.0 + + resize_h = int(h * ratio) + resize_w = int(w * ratio) + + resize_h = int(round(resize_h / 32) * 32) + resize_w = int(round(resize_w / 32) * 32) + + try: + if int(resize_w) <= 0 or int(resize_h) <= 0: + return None + img = cv2.resize(img, (int(resize_w), int(resize_h))) + except Exception as exc: + raise ResizeImgError from exc + + return img + + +class ResizeImgError(Exception): + pass + + +class DBPostProcess: + """The post process for Differentiable Binarization (DB).""" + + def __init__( + self, + thresh: float = 0.3, + box_thresh: float = 0.7, + max_candidates: int = 1000, + unclip_ratio: float = 2.0, + score_mode: str = "fast", + use_dilation: bool = False, + ): + self.thresh = thresh + self.box_thresh = box_thresh + self.max_candidates = max_candidates + self.unclip_ratio = unclip_ratio + self.min_size = 3 + self.score_mode = score_mode + + self.dilation_kernel = None + if use_dilation: + self.dilation_kernel = np.array([[1, 1], [1, 1]]) + + def __call__( + self, pred: np.ndarray, ori_shape: Tuple[int, int] + ) -> Tuple[np.ndarray, List[float]]: + src_h, src_w = ori_shape + pred = pred[:, 0, :, :] + segmentation = pred > self.thresh + + mask = segmentation[0] + if self.dilation_kernel is not None: + mask = cv2.dilate( + np.array(segmentation[0]).astype(np.uint8), self.dilation_kernel + ) + boxes, scores = self.boxes_from_bitmap(pred[0], mask, src_w, src_h) + return boxes, scores + + def boxes_from_bitmap( + self, pred: np.ndarray, bitmap: np.ndarray, dest_width: int, dest_height: int + ) -> Tuple[np.ndarray, List[float]]: + """ + bitmap: single map with shape (1, H, W), + whose values are binarized as {0, 1} + """ + + height, width = bitmap.shape + + outs = cv2.findContours( + (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE + ) + if len(outs) == 3: + img, 
contours, _ = outs[0], outs[1], outs[2]
+        elif len(outs) == 2:
+            contours, _ = outs[0], outs[1]
+
+        num_contours = min(len(contours), self.max_candidates)
+
+        boxes, scores = [], []
+        for index in range(num_contours):
+            contour = contours[index]
+            points, sside = self.get_mini_boxes(contour)
+            if sside < self.min_size:
+                continue
+
+            if self.score_mode == "fast":
+                score = self.box_score_fast(pred, points.reshape(-1, 2))
+            else:
+                score = self.box_score_slow(pred, contour)
+
+            if self.box_thresh > score:
+                continue
+
+            box = self.unclip(points)
+            box, sside = self.get_mini_boxes(box)
+            if sside < self.min_size + 2:
+                continue
+
+            box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
+            box[:, 1] = np.clip(
+                np.round(box[:, 1] / height * dest_height), 0, dest_height
+            )
+            boxes.append(box.astype(np.int32))
+            scores.append(score)
+        return np.array(boxes, dtype=np.int32), scores
+
+    def get_mini_boxes(self, contour: np.ndarray) -> Tuple[np.ndarray, float]:
+        bounding_box = cv2.minAreaRect(contour)
+        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
+
+        index_1, index_2, index_3, index_4 = 0, 1, 2, 3
+        if points[1][1] > points[0][1]:
+            index_1 = 0
+            index_4 = 1
+        else:
+            index_1 = 1
+            index_4 = 0
+
+        if points[3][1] > points[2][1]:
+            index_2 = 2
+            index_3 = 3
+        else:
+            index_2 = 3
+            index_3 = 2
+
+        box = np.array(
+            [points[index_1], points[index_2], points[index_3], points[index_4]]
+        )
+        return box, min(bounding_box[1])
+
+    @staticmethod
+    def box_score_fast(bitmap: np.ndarray, _box: np.ndarray) -> float:
+        h, w = bitmap.shape[:2]
+        box = _box.copy()
+        xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int32), 0, w - 1)
+        xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int32), 0, w - 1)
+        ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int32), 0, h - 1)
+        ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int32), 0, h - 1)
+
+        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
+        box[:, 0] = box[:, 0] - xmin
+        box[:, 1] = box[:, 1] - ymin
+        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
+        return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
+
+    def box_score_slow(self, bitmap: np.ndarray, contour: np.ndarray) -> float:
+        """Use the polygon mean score as the box score."""
+        h, w = bitmap.shape[:2]
+        contour = contour.copy()
+        contour = np.reshape(contour, (-1, 2))
+
+        xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
+        xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
+        ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
+        ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
+
+        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
+
+        contour[:, 0] = contour[:, 0] - xmin
+        contour[:, 1] = contour[:, 1] - ymin
+
+        cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
+        return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
+
+    def unclip(self, box: np.ndarray) -> np.ndarray:
+        unclip_ratio = self.unclip_ratio
+        poly = Polygon(box)
+        distance = poly.area * unclip_ratio / poly.length
+        offset = pyclipper.PyclipperOffset()
+        offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+        expanded = np.array(offset.Execute(distance)).reshape((-1, 1, 2))
+        return expanded
diff --git a/ch_ppocr_rec/__init__.py b/ch_ppocr_rec/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..37eafdc7b001b5e2e6db66257c14b435d12a94fc
--- /dev/null
+++ b/ch_ppocr_rec/__init__.py
@@ -0,0 +1,4 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
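+# Exposes the CTC-based text recognizer (see CTCLabelDecode in utils.py).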
+from .text_recognize import TextRecognizer diff --git a/ch_ppocr_rec/__pycache__/__init__.cpython-310.pyc b/ch_ppocr_rec/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30c70131f52294b7aa832b87c4a96ec02b86b88b Binary files /dev/null and b/ch_ppocr_rec/__pycache__/__init__.cpython-310.pyc differ diff --git a/ch_ppocr_rec/__pycache__/__init__.cpython-311.pyc b/ch_ppocr_rec/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ca9187f95383b88c4774c174fd985d25a1e3556 Binary files /dev/null and b/ch_ppocr_rec/__pycache__/__init__.cpython-311.pyc differ diff --git a/ch_ppocr_rec/__pycache__/__init__.cpython-312.pyc b/ch_ppocr_rec/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28a96ce6d4b79795d28732304c7f9076a18f25d0 Binary files /dev/null and b/ch_ppocr_rec/__pycache__/__init__.cpython-312.pyc differ diff --git a/ch_ppocr_rec/__pycache__/__init__.cpython-313.pyc b/ch_ppocr_rec/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d02b5472c4342af8f917f0a402cc823cd021635 Binary files /dev/null and b/ch_ppocr_rec/__pycache__/__init__.cpython-313.pyc differ diff --git a/ch_ppocr_rec/__pycache__/text_recognize.cpython-310.pyc b/ch_ppocr_rec/__pycache__/text_recognize.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3d8a1152049f8002f40dadbf5bdfb95006d4415 Binary files /dev/null and b/ch_ppocr_rec/__pycache__/text_recognize.cpython-310.pyc differ diff --git a/ch_ppocr_rec/__pycache__/text_recognize.cpython-311.pyc b/ch_ppocr_rec/__pycache__/text_recognize.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ca9c7b5882a4776113bc875442fbecb9848e7ca Binary files /dev/null and b/ch_ppocr_rec/__pycache__/text_recognize.cpython-311.pyc differ diff --git a/ch_ppocr_rec/__pycache__/text_recognize.cpython-312.pyc b/ch_ppocr_rec/__pycache__/text_recognize.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77bb5b8fab209a0df42f4bfe4be0ca7608c775a3 Binary files /dev/null and b/ch_ppocr_rec/__pycache__/text_recognize.cpython-312.pyc differ diff --git a/ch_ppocr_rec/__pycache__/text_recognize.cpython-313.pyc b/ch_ppocr_rec/__pycache__/text_recognize.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed631339e175419e3f74b5cd75c9e71088537ffb Binary files /dev/null and b/ch_ppocr_rec/__pycache__/text_recognize.cpython-313.pyc differ diff --git a/ch_ppocr_rec/__pycache__/utils.cpython-310.pyc b/ch_ppocr_rec/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91d82d4da1b48038a439340b514c2184efb68873 Binary files /dev/null and b/ch_ppocr_rec/__pycache__/utils.cpython-310.pyc differ diff --git a/ch_ppocr_rec/__pycache__/utils.cpython-311.pyc b/ch_ppocr_rec/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46c998ef470ad81fffb3c677da38ea8aa20fdd8c Binary files /dev/null and b/ch_ppocr_rec/__pycache__/utils.cpython-311.pyc differ diff --git a/ch_ppocr_rec/__pycache__/utils.cpython-312.pyc b/ch_ppocr_rec/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d3a2820f07c67aee6b1b8f385d23861ab78b737 Binary files /dev/null and b/ch_ppocr_rec/__pycache__/utils.cpython-312.pyc differ diff --git 
a/ch_ppocr_rec/__pycache__/utils.cpython-313.pyc b/ch_ppocr_rec/__pycache__/utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8038ebf7d283632f9a34ad85a2a51dccd401afff Binary files /dev/null and b/ch_ppocr_rec/__pycache__/utils.cpython-313.pyc differ diff --git a/ch_ppocr_rec/text_recognize.py b/ch_ppocr_rec/text_recognize.py new file mode 100644 index 0000000000000000000000000000000000000000..e823ea65515e9cb3ef27c39f2761e462ea952708 --- /dev/null +++ b/ch_ppocr_rec/text_recognize.py @@ -0,0 +1,130 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import math +import time +from typing import Any, Dict, List, Tuple, Union + +import cv2 +import numpy as np + +from rapidocr_onnxruntime.utils import OrtInferSession, read_yaml + +from .utils import CTCLabelDecode + + +class TextRecognizer: + def __init__(self, config: Dict[str, Any]): + self.session = OrtInferSession(config) + + character = None + if self.session.have_key(): + character = self.session.get_character_list() + + character_path = config.get("rec_keys_path", None) + self.postprocess_op = CTCLabelDecode( + character=character, character_path=character_path + ) + + self.rec_batch_num = config["rec_batch_num"] + self.rec_image_shape = config["rec_img_shape"] + + def __call__( + self, + img_list: Union[np.ndarray, List[np.ndarray]], + return_word_box: bool = False, + ) -> Tuple[List[Tuple[str, float]], float]: + if isinstance(img_list, np.ndarray): + img_list = [img_list] + + # Calculate the aspect ratio of all text bars + width_list = [img.shape[1] / float(img.shape[0]) for img in img_list] + + # Sorting can speed up the recognition process + indices = np.argsort(np.array(width_list)) + + img_num = len(img_list) + rec_res = [("", 0.0)] * img_num + + batch_num = self.rec_batch_num + elapse = 0 + for beg_img_no in range(0, img_num, batch_num): + end_img_no = min(img_num, beg_img_no + batch_num) + + # Parameter Alignment for PaddleOCR + imgC, imgH, imgW = self.rec_image_shape[:3] + max_wh_ratio = imgW / imgH + wh_ratio_list = [] + for ino in range(beg_img_no, end_img_no): + h, w = img_list[indices[ino]].shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + wh_ratio_list.append(wh_ratio) + + norm_img_batch = [] + for ino in range(beg_img_no, end_img_no): + norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio) + norm_img_batch.append(norm_img[np.newaxis, :]) + norm_img_batch = np.concatenate(norm_img_batch).astype(np.float32) + + starttime = time.time() + preds = self.session(norm_img_batch)[0] + rec_result = self.postprocess_op( + preds, + return_word_box, + wh_ratio_list=wh_ratio_list, + max_wh_ratio=max_wh_ratio, + ) + + for rno, one_res in enumerate(rec_result): + rec_res[indices[beg_img_no + rno]] = one_res + elapse += time.time() - starttime + return rec_res, elapse + + def resize_norm_img(self, img: np.ndarray, max_wh_ratio: float) -> np.ndarray: + img_channel, 
img_height, img_width = self.rec_image_shape + assert img_channel == img.shape[2] + + img_width = int(img_height * max_wh_ratio) + + h, w = img.shape[:2] + ratio = w / float(h) + if math.ceil(img_height * ratio) > img_width: + resized_w = img_width + else: + resized_w = int(math.ceil(img_height * ratio)) + + resized_image = cv2.resize(img, (resized_w, img_height)) + resized_image = resized_image.astype("float32") + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + + padding_im = np.zeros((img_channel, img_height, img_width), dtype=np.float32) + padding_im[:, :, 0:resized_w] = resized_image + return padding_im + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--image_path", type=str, help="image_dir|image_path") + parser.add_argument("--config_path", type=str, default="config.yaml") + args = parser.parse_args() + + config = read_yaml(args.config_path) + text_recognizer = TextRecognizer(config) + + img = cv2.imread(args.image_path) + rec_res, predict_time = text_recognizer(img) + print(f"rec result: {rec_res}\t cost: {predict_time}s") diff --git a/ch_ppocr_rec/utils.py b/ch_ppocr_rec/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..224b2f879207c74125d1389822acfdb1fe0f507e --- /dev/null +++ b/ch_ppocr_rec/utils.py @@ -0,0 +1,189 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import numpy as np + + +class CTCLabelDecode: + def __init__( + self, + character: Optional[List[str]] = None, + character_path: Union[str, Path, None] = None, + ): + self.character = self.get_character(character, character_path) + self.dict = {char: i for i, char in enumerate(self.character)} + + def __call__( + self, preds: np.ndarray, return_word_box: bool = False, **kwargs + ) -> List[Tuple[str, float]]: + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.decode( + preds_idx, preds_prob, return_word_box, is_remove_duplicate=True + ) + if return_word_box: + for rec_idx, rec in enumerate(text): + wh_ratio = kwargs["wh_ratio_list"][rec_idx] + max_wh_ratio = kwargs["max_wh_ratio"] + rec[2][0] = rec[2][0] * (wh_ratio / max_wh_ratio) + return text + + def get_character( + self, + character: Optional[List[str]] = None, + character_path: Union[str, Path, None] = None, + ) -> List[str]: + if character is None and character_path is None: + raise ValueError("character must not be None") + + character_list = None + if character: + character_list = character + + if character_path: + character_list = self.read_character_file(character_path) + + if character_list is None: + raise ValueError("character must not be None") + + character_list = self.insert_special_char( + character_list, " ", len(character_list) + ) + character_list = self.insert_special_char(character_list, "blank", 0) + return character_list + + @staticmethod + def read_character_file(character_path: Union[str, Path]) -> List[str]: + character_list = [] + with open(character_path, "rb") as f: + lines = f.readlines() + for line in lines: + line = line.decode("utf-8").strip("\n").strip("\r\n") + character_list.append(line) + return character_list + + @staticmethod + def insert_special_char( + character_list: List[str], special_char: str, loc: int = -1 + ) -> List[str]: + character_list.insert(loc, special_char) + return character_list + + def decode( + self, + text_index: np.ndarray, + text_prob: 
Optional[np.ndarray] = None,
+        return_word_box: bool = False,
+        is_remove_duplicate: bool = False,
+    ) -> List[Tuple[str, float]]:
+        """Convert text indices into text labels."""
+        result_list = []
+        ignored_tokens = self.get_ignored_tokens()
+        batch_size = len(text_index)
+        for batch_idx in range(batch_size):
+            selection = np.ones(len(text_index[batch_idx]), dtype=bool)
+            if is_remove_duplicate:
+                selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
+
+            for ignored_token in ignored_tokens:
+                selection &= text_index[batch_idx] != ignored_token
+
+            if text_prob is not None:
+                conf_list = np.array(text_prob[batch_idx][selection]).tolist()
+            else:
+                conf_list = [1] * len(selection)
+
+            if len(conf_list) == 0:
+                conf_list = [0]
+
+            char_list = [
+                self.character[text_id] for text_id in text_index[batch_idx][selection]
+            ]
+            text = "".join(char_list)
+            if return_word_box:
+                word_list, word_col_list, state_list = self.get_word_info(
+                    text, selection
+                )
+                result_list.append(
+                    (
+                        text,
+                        np.mean(conf_list).tolist(),
+                        [
+                            len(text_index[batch_idx]),
+                            word_list,
+                            word_col_list,
+                            state_list,
+                            conf_list,
+                        ],
+                    )
+                )
+            else:
+                result_list.append((text, np.mean(conf_list).tolist()))
+        return result_list
+
+    @staticmethod
+    def get_word_info(
+        text: str, selection: np.ndarray
+    ) -> Tuple[List[List[str]], List[List[int]], List[str]]:
+        """
+        Group the decoded characters and record the corresponding decoded positions.
+        from https://github.com/PaddlePaddle/PaddleOCR/blob/fbba2178d7093f1dffca65a5b963ec277f1a6125/ppocr/postprocess/rec_postprocess.py#L70
+
+        Args:
+            text: the decoded text
+            selection: the bool array that identifies which columns of features are decoded as non-separated characters
+        Returns:
+            word_list: list of the grouped words
+            word_col_list: list of decoding positions corresponding to each character in the grouped word
+            state_list: list of markers identifying the type of each grouped word; there are two types:
+                        - 'cn': continuous Chinese characters (e.g., 你好啊)
+                        - 'en&num': continuous English characters (e.g., hello), numbers (e.g., 123, 1.123), or a mix of them connected by '-' (e.g., VGG-16)
+        """
+        state = None
+        word_content = []
+        word_col_content = []
+        word_list = []
+        word_col_list = []
+        state_list = []
+        valid_col = np.where(selection)[0]
+        col_width = np.zeros(valid_col.shape)
+        if len(valid_col) > 0:
+            col_width[1:] = valid_col[1:] - valid_col[:-1]
+            col_width[0] = min(
+                3 if "\u4e00" <= text[0] <= "\u9fff" else 2, int(valid_col[0])
+            )
+
+        for c_i, char in enumerate(text):
+            if "\u4e00" <= char <= "\u9fff":
+                c_state = "cn"
+            else:
+                c_state = "en&num"
+
+            if state is None:
+                state = c_state
+
+            if state != c_state or col_width[c_i] > 4:
+                if len(word_content) != 0:
+                    word_list.append(word_content)
+                    word_col_list.append(word_col_content)
+                    state_list.append(state)
+                    word_content = []
+                    word_col_content = []
+                state = c_state
+
+            word_content.append(char)
+            word_col_content.append(int(valid_col[c_i]))
+
+        if len(word_content) != 0:
+            word_list.append(word_content)
+            word_col_list.append(word_col_content)
+            state_list.append(state)
+
+        return word_list, word_col_list, state_list
+
+    @staticmethod
+    def get_ignored_tokens() -> List[int]:
+        return [0]  # for ctc blank
diff --git a/compose.yaml b/compose.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ebccdec5a0df49e58175900a604f9b13ca45ed8c
--- /dev/null
+++ b/compose.yaml
@@ -0,0 +1,24 @@
+services:
+  python-app:
+    build:
+      context: .
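+      # The build context is the repo root; the .dockerignore added in this change trims what is sent to the Docker daemon.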
+ dockerfile: Dockerfile + container_name: python-app + restart: unless-stopped + init: true + # env_file: ./.env # Uncomment if .env file exists + ports: + - "7860:7860" # Gradio default port, exposed in Dockerfile + # If you need to persist data, add volumes here + # volumes: + # - ./models:/app/models # Uncomment if you want to persist models + # - ./results:/app/results # Uncomment if you want to persist results + # Add more volumes as needed for persistent data + # networks: [app-net] # Uncomment if you add more services/networks + +# No external services (databases, caches, etc.) detected in the project structure or Dockerfile. +# If you add such dependencies, define them here and add to networks. + +# networks: +# app-net: +# driver: bridge diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7d00732a0b030716bc14596dc7e29660e7cba97f --- /dev/null +++ b/config.yaml @@ -0,0 +1,47 @@ +Global: + text_score: 0.5 + use_det: true + use_cls: true + use_rec: true + print_verbose: false + min_height: 30 + width_height_ratio: 8 + max_side_len: 2000 + min_side_len: 30 + return_word_box: false + + intra_op_num_threads: &intra_nums -1 + inter_op_num_threads: &inter_nums -1 + +Det: + intra_op_num_threads: *intra_nums + inter_op_num_threads: *inter_nums + + use_cuda: false + use_dml: false + + model_path: models/ch_PP-OCRv4_det_infer.onnx + + limit_side_len: 736 + limit_type: min + std: [ 0.5, 0.5, 0.5 ] + mean: [ 0.5, 0.5, 0.5 ] + + thresh: 0.3 + box_thresh: 0.5 + max_candidates: 1000 + unclip_ratio: 1.6 + use_dilation: true + score_mode: fast + +Rec: + intra_op_num_threads: *intra_nums + inter_op_num_threads: *inter_nums + + use_cuda: false + use_dml: false + + model_path: models/ch_PP-OCRv4_rec_infer.onnx + + rec_img_shape: [3, 48, 320] + rec_batch_num: 6 diff --git a/demo_code.py b/demo_code.py new file mode 100644 index 0000000000000000000000000000000000000000..d0687c8be0c45f174cb7f480f4d3289f547bebc0 --- /dev/null +++ b/demo_code.py @@ -0,0 +1,109 @@ +import os +import cv2 +import numpy as np +from pdf2image import convert_from_path + +from main import RapidOCR +ocr_engine = RapidOCR() + +dataPath = '/home/tung/Tung_Works/OCR_code/OCR-20250423T073748Z-001/OCR/OCR辨識失敗-部分樣本' + +from image_enhancement import enhance_image + +def crop_dynamic(image_rgb): + """ + Dynamically crop the blank regions (white or black) surrounding the object. + + Parameters: + image_rgb (numpy.ndarray): Input image in RGB format. + + Returns: + cropped_rgb (numpy.ndarray): Cropped RGB image. + bbox (tuple): Bounding box of the cropped region (x, y, w, h). 
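+
+    Usage sketch (assumes `page` is an RGB numpy array):
+        cropped, (x, y, w, h) = crop_dynamic(page)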
+ """ + # Convert to grayscale for easier processing + gray = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY) + + # Find non-blank rows and columns based on pixel intensity + row_mask = np.any(gray < 240, axis=1) # Detect rows with pixel intensity below the white threshold + col_mask = np.any(gray < 240, axis=0) # Detect columns with pixel intensity below the white threshold + + # Adjust logic for black regions by combining white and black detection + row_mask = row_mask | np.any(gray > 10, axis=1) # Include black regions + col_mask = col_mask | np.any(gray > 10, axis=0) # Include black regions + + # Find bounding box indices + y_min, y_max = np.where(row_mask)[0][[0, -1]] + x_min, x_max = np.where(col_mask)[0][[0, -1]] + + # Crop the region + cropped_rgb = image_rgb[y_min:y_max+1, x_min:x_max+1] + return cropped_rgb, (x_min, y_min, x_max - x_min, y_max - y_min) + +list_pdf = [] +for root, dirs, files in os.walk(dataPath): + for file in files: + if file.endswith('.pdf'): + pdf_f = os.path.join(root, file) + assert os.path.exists(pdf_f) + list_pdf.append(pdf_f) +sorted(list_pdf) + +for idx, pdf_f in enumerate(list_pdf): + bs_name = os.path.basename(pdf_f) + bs_name_0 = os.path.splitext(bs_name)[0] + + +# images = convert_from_path(pdf_f, dpi=900) + images = convert_from_path(pdf_f, dpi=500, first_page=1, last_page=3) + for i, image in enumerate(images): + #brightness = ImageEnhance.Brightness(image).enhance(1.5) + #contrast = ImageEnhance.Contrast(brightness).enhance(1.8) + #sharpness = ImageEnhance.Sharpness(contrast).enhance(2.0) + #sharpness.save("{i}_"+bs_name) + img = np.array(image) + #img = enhance_image(img) +# img, bbox = crop_dynamic(img) + + parameters = {} + parameters['local_contrast'] = 1.5 # 1.5x increase in details + parameters['mid_tones'] = 0.5 + parameters['tonal_width'] = 0.5 + parameters['areas_dark'] = 0.7 # 70% improvement in dark areas + parameters['areas_bright'] = 0.5 # 50% improvement in bright areas + parameters['saturation_degree'] = 1.2 # 1.2x increase in color saturation + parameters['brightness'] = 0.1 # slight increase in brightness + parameters['preserve_tones'] = True + parameters['color_correction'] = False + img = enhance_image(image, parameters, verbose=False) + + print(img.shape) + enhanced_img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # Save in OpenCV-compatible format + cv2.imwrite(f'{i + 1}_{bs_name_0}.jpg', enhanced_img_bgr) + print(bs_name_0, i ) + rotation_attempts = 0 # Track rotation count + + while rotation_attempts < 4: # Rotate at most 4 times (90°, 180°, 270°, and back to original orientation) + result, _ = ocr_engine(img, use_det=True, use_cls=False, use_rec=True) + detected = False # Flag to check detection status + if result: + test_list = [r[1] for r in result] + + for j in range(len(test_list) - 1): # Loop up to the second-to-last row + count1 = test_list[j].count("<") + count2 = test_list[j + 1].count("<") + if count1 > 2 and count2 > 2: + print(bs_name_0) + print(f"Consecutive rows with '<' more than 2 times each:") + print(f"Row 1: {test_list[j]} (Occurrences: {count1})") + print(f"Row 2: {test_list[j + 1]} (Occurrences: {count2})") + detected = True + break + + if detected: + break # Stop further rotation since rows are detected + + # Rotate the image by 90 degrees + img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE) + rotation_attempts += 1 + diff --git a/demo_v2.py b/demo_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..f4cacc6edd1b7998a8c15422bd3782e46c95f15f --- /dev/null +++ b/demo_v2.py @@ -0,0 +1,137 @@ 
+import os +import cv2 +import numpy as np +from pdf2image import convert_from_path + +from main import RapidOCR +ocr_engine = RapidOCR() + +dataPath = '/home/tung/Tung_Works/OCR_code/OCR-20250423T073748Z-001/OCR/OCR辨識失敗-部分樣本' + +from image_enhancement import enhance_image + +list_pdf = [] +for root, dirs, files in os.walk(dataPath): + for file in files: + if file.endswith('.pdf'): + pdf_f = os.path.join(root, file) + assert os.path.exists(pdf_f) + list_pdf.append(pdf_f) +sorted(list_pdf) + +def adaptive_threshold_to_rgb(image_rgb): + """ + Apply adaptive thresholding on the L channel of LAB color space + and reconstruct the thresholded image as RGB. + + Parameters: + image_rgb (numpy.ndarray): Input RGB image. + + Returns: + thresholded_rgb (numpy.ndarray): RGB image after thresholding the L channel. + """ + # Convert RGB to LAB color space + image_lab = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2LAB) + + # Split LAB channels + l_channel, a_channel, b_channel = cv2.split(image_lab) + + # Apply adaptive thresholding to the L channel + thresholded_l = cv2.adaptiveThreshold( + l_channel, + maxValue=255, + adaptiveMethod=cv2.ADAPTIVE_THRESH_GAUSSIAN_C, # or ADAPTIVE_THRESH_MEAN_C + thresholdType=cv2.THRESH_BINARY, + blockSize=11, + C=2 + ) + + # Merge thresholded L channel back with original A and B channels + updated_lab = cv2.merge((thresholded_l, a_channel, b_channel)) + + # Convert LAB back to RGB + thresholded_rgb = cv2.cvtColor(updated_lab, cv2.COLOR_LAB2RGB) + + return thresholded_rgb + +for idx, pdf_f in enumerate(list_pdf): + bs_name = os.path.basename(pdf_f) + bs_name_0 = os.path.splitext(bs_name)[0] + +# images = convert_from_path(pdf_f, dpi=900) + images = convert_from_path(pdf_f, dpi=300, first_page=1, last_page=3) + for i, image in enumerate(images): + img = np.array(image) + print(img.shape) + parameters = {} + parameters['local_contrast'] = 1.2 # 1.2x increase in details + parameters['mid_tones'] = 0.5 # middle of range + parameters['tonal_width'] = 0.5 # middle of range + parameters['areas_dark'] = 0.7 # 70% improvement in dark areas + parameters['areas_bright'] = 0.5 # 50% improvement in bright areas + parameters['brightness'] = 0.1 # slight increase in overall brightness + parameters['saturation_degree'] = 1.2 # 1.2x increase in color saturation + parameters['preserve_tones'] = True + parameters['color_correction'] = True + img = enhance_image(img, parameters, verbose=False) + #print(img.shape, img.dtype, img.max(), img.min()) + img = np.uint8(img*255.) 
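+        # (The rescale above assumes enhance_image returns a float image in
+        # [0, 1]; OpenCV's cvtColor/imwrite below expect uint8 in [0, 255].)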
+ + enhanced_img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # Save in OpenCV-compatible format + cv2.imwrite(f'{bs_name_0}_{i + 1}.jpg', enhanced_img_bgr) + print(bs_name_0, i ) + rotation_attempts = 0 # Track rotation count + + while rotation_attempts < 4: # Rotate at most 4 times (90°, 180°, 270°, and back to original orientation) + result, _ = ocr_engine(img, use_det=True, use_cls=False, use_rec=True) + detected = False # Flag to check detection status + if result: + test_list = [r[1] for r in result] + #print(test_list[-5:]) + + for j in range(len(test_list) - 1): # Loop up to the second-to-last row + count1 = test_list[j].count("<") + count2 = test_list[j + 1].count("<") + if count1 > 1 and count2 > 1: + print(bs_name_0) + print(f"Consecutive rows with '<' more than 2 times each:") + print(f"Row 1: {test_list[j]} (Occurrences: {count1})") + print(f"Row 2: {test_list[j + 1]} (Occurrences: {count2})") + detected = True + break + + if detected: + break # Stop further rotation since rows are detected + + # Rotate the image by 90 degrees + img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE) + rotation_attempts += 1 + + if not detected: + img = adaptive_threshold_to_rgb(img) + rotation_attempts = 0 # Track rotation count + + while rotation_attempts < 4: # Rotate at most 4 times (90°, 180°, 270°, and back to original orientation) + result, _ = ocr_engine(img, use_det=True, use_cls=False, use_rec=True) + detected = False # Flag to check detection status + if result: + test_list = [r[1] for r in result] + #print(test_list[-5:]) + + for j in range(len(test_list) - 1): # Loop up to the second-to-last row + count1 = test_list[j].count("<") + count2 = test_list[j + 1].count("<") + if count1 > 1 and count2 > 1: + print(bs_name_0) + print(f"Consecutive rows with '<' more than 2 times each:") + print(f"Row 1: {test_list[j]} (Occurrences: {count1})") + print(f"Row 2: {test_list[j + 1]} (Occurrences: {count2})") + detected = True + break + + if detected: + break # Stop further rotation since rows are detected + + # Rotate the image by 90 degrees + img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE) + rotation_attempts += 1 diff --git a/demo_v3.py b/demo_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..d9c7b7f8ac063962a0d6fa88345e6f2f4335b1e3 --- /dev/null +++ b/demo_v3.py @@ -0,0 +1,167 @@ +import os +import cv2 +import numpy as np +from pdf2image import convert_from_path + +from main import RapidOCR +from image_enhancement import enhance_image + +# Initialize OCR engine once. +ocr_engine = RapidOCR() + + +def adaptive_threshold_to_rgb(image_rgb): + """ + Apply adaptive thresholding on the L channel of LAB color space + and reconstruct the thresholded image as RGB. + + Parameters: + image_rgb (numpy.ndarray): Input RGB image. + + Returns: + thresholded_rgb (numpy.ndarray): RGB image after thresholding the L channel. + """ + # Convert RGB to LAB color space and split channels. + image_lab = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2LAB) + l_channel, a_channel, b_channel = cv2.split(image_lab) + + # Apply adaptive thresholding to the L channel. + thresholded_l = cv2.adaptiveThreshold( + l_channel, + maxValue=255, + adaptiveMethod=cv2.ADAPTIVE_THRESH_GAUSSIAN_C, # or ADAPTIVE_THRESH_MEAN_C + thresholdType=cv2.THRESH_BINARY, + blockSize=11, + C=2 + ) + + # Merge the thresholded L channel back with A and B channels. 
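+    # (Only L is binarized; A and B keep their original chroma, so the
+    # LAB -> RGB round trip below yields a black/white luminance while the
+    # colour content survives.)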
+ updated_lab = cv2.merge((thresholded_l, a_channel, b_channel)) + thresholded_rgb = cv2.cvtColor(updated_lab, cv2.COLOR_LAB2RGB) + return thresholded_rgb + + +def ocr_detect(image, ocr_engine): + """ + Run OCR on the image using the provided ocr_engine and check if consecutive + rows containing the '<' character are detected. + + Parameters: + image (numpy.ndarray): Input image. + ocr_engine: The OCR engine instance. + + Returns: + detected (bool): True if the desired pattern is detected, False otherwise. + """ + result, _ = ocr_engine(image, use_det=True, use_cls=False, use_rec=True) + if result: + test_list = [r[1] for r in result] + for j in range(len(test_list) - 1): + count1 = test_list[j].count("<") + count2 = test_list[j + 1].count("<") + if count1 > 1 and count2 > 1: + return True + return False + + +def rotate_until_detect(image, ocr_engine, max_attempts=4): + """ + Rotate the image 90° clockwise, up to max_attempts times, until the OCR + conveys the expected result. + + Parameters: + image (numpy.ndarray): Input image. + ocr_engine: The OCR engine instance. + max_attempts (int): Maximum number of rotations. + + Returns: + image (numpy.ndarray): Rotated image with detection (or final rotation if undetected). + detected (bool): Whether the expected OCR pattern was detected. + """ + attempt = 0 + detected = False + while attempt < max_attempts: + if ocr_detect(image, ocr_engine): + detected = True + break + image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE) + attempt += 1 + return image, detected + + +def process_pdf(pdf_f, ocr_engine, enhance_params): + """ + Process a single PDF file by converting pages, enhancing images, + running OCR with rotation, and using adaptive thresholding as a fallback. + + Parameters: + pdf_f (str): PDF file path. + ocr_engine: The OCR engine instance. + enhance_params (dict): Parameters for the image enhancement. + """ + # Convert specified pages of PDF into images. + images = convert_from_path(pdf_f, dpi=300, first_page=1, last_page=3) + bs_name = os.path.basename(pdf_f) + bs_name_0 = os.path.splitext(bs_name)[0] + + for i, pil_image in enumerate(images): + # Convert PIL image to a NumPy array. + img = np.array(pil_image) + print("Original image shape:", img.shape) + + # Enhance the image. + img = enhance_image(img, enhance_params, verbose=False) + img = np.uint8(img * 255.) + + # Save the enhanced image as a reference. + enhanced_img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + output_filename = f'{bs_name_0}_{i + 1}.jpg' + cv2.imwrite(output_filename, enhanced_img_bgr) + print(f"Saved enhanced image: {output_filename}") + + # First: Try OCR on the enhanced image with rotation. + processed_img, detected = rotate_until_detect(img, ocr_engine) + if detected: + print(f"OCR success on {output_filename} with enhanced image rotation.") + else: + # Second: Apply adaptive thresholding and re-run OCR with rotation. + print(f"No OCR detection from enhanced image. Applying adaptive thresholding for {output_filename}.") + adaptive_img = adaptive_threshold_to_rgb(img) + processed_img, detected = rotate_until_detect(adaptive_img, ocr_engine) + if detected: + print(f"OCR success on {output_filename} with adaptive thresholding and rotation.") + else: + print(f"OCR detection failed for {output_filename} after fallback.") + + +def main(): + # Set the data path and gather list of PDF files. 
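+    # (dataPath below is machine-specific; os.walk searches it recursively,
+    # so any folder tree containing the sample PDFs will work.)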
+ dataPath = '/home/tung/Tung_Works/OCR_code/OCR-20250423T073748Z-001/OCR/OCR辨識失敗-部分樣本' + list_pdf = [ + os.path.join(root, file) + for root, _, files in os.walk(dataPath) + for file in files if file.endswith('.pdf') + ] + # Optionally, sort the list. + # list_pdf = sorted(list_pdf) + + # Define image enhancement parameters (applied to every image). + enhance_params = { + 'local_contrast': 1.2, # 1.2x increase in details + 'mid_tones': 0.5, # middle of range + 'tonal_width': 0.5, # middle of range + 'areas_dark': 0.7, # 70% improvement in dark areas + 'areas_bright': 0.5, # 50% improvement in bright areas + 'brightness': 0.1, # slight increase in overall brightness + 'saturation_degree': 1.2, # 1.2x increase in color saturation + 'preserve_tones': True, + 'color_correction': True, + } + + # Process each PDF. + for pdf_f in list_pdf: + process_pdf(pdf_f, ocr_engine, enhance_params) + + +if __name__ == '__main__': + main() diff --git a/demo_v4.py b/demo_v4.py new file mode 100644 index 0000000000000000000000000000000000000000..06b060bdd294c02d4a5efc14d5f5c683b1e61d0a --- /dev/null +++ b/demo_v4.py @@ -0,0 +1,199 @@ +import os +import cv2 +import numpy as np +from pdf2image import convert_from_path + +from main import RapidOCR +from image_enhancement import enhance_image + +# Initialize OCR engine once. +ocr_engine = RapidOCR() + + +def adaptive_threshold_to_rgb(image_rgb): + """ + Convert an RGB image to LAB, apply adaptive thresholding only on the L channel, + then convert back to RGB. + + Parameters: + image_rgb (numpy.ndarray): Input RGB image. + + Returns: + thresholded_rgb (numpy.ndarray): RGB image after thresholding the L channel. + """ + # Convert RGB to LAB color space. + image_lab = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2LAB) + l_channel, a_channel, b_channel = cv2.split(image_lab) + + # Adaptive thresholding on the L channel. + thresholded_l = cv2.adaptiveThreshold( + l_channel, + maxValue=255, + adaptiveMethod=cv2.ADAPTIVE_THRESH_GAUSSIAN_C, # or ADAPTIVE_THRESH_MEAN_C + thresholdType=cv2.THRESH_BINARY, + blockSize=11, + C=2 + ) + + # Merge the thresholded L channel with original A and B, then convert back to RGB. + updated_lab = cv2.merge((thresholded_l, a_channel, b_channel)) + thresholded_rgb = cv2.cvtColor(updated_lab, cv2.COLOR_LAB2RGB) + return thresholded_rgb + + +def ocr_detect(image, ocr_engine): + """ + Run OCR on the image and check for two consecutive rows that contain the '<' character. + + Parameters: + image (numpy.ndarray): Input image. + ocr_engine: OCR engine instance. + + Returns: + detected (bool): True if found, else False. + row1 (str): The first detected row with '<'. + row2 (str): The second detected row with '<'. + """ + result, _ = ocr_engine(image, use_det=True, use_cls=False, use_rec=True) + if result: + # Get recognized strings + test_list = [r[1] for r in result] + for j in range(len(test_list) - 1): + count1 = test_list[j].count("<") + count2 = test_list[j + 1].count("<") + if count1 > 1 and count2 > 1: + return True, test_list[j], test_list[j + 1] + return False, None, None + + +def rotate_until_detect(image, ocr_engine, max_attempts=4): + """ + Rotate the image 90° clockwise up to max_attempts times until OCR returns + two consecutive rows that meet the specified criteria. + + Parameters: + image (numpy.ndarray): Input image. + ocr_engine: OCR engine instance. + max_attempts (int): Maximum number of rotations. + + Returns: + image (numpy.ndarray): Final rotated image. + detected (bool): True if OCR detection succeeded. 
+ row1, row2 (str, str): The two detected rows (if found; otherwise None). + """ + attempt = 0 + detected = False + row1, row2 = None, None + while attempt < max_attempts: + detected, row1, row2 = ocr_detect(image, ocr_engine) + if detected: + break + image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE) + attempt += 1 + return image, detected, row1, row2 + + +def process_pdf(pdf_f, ocr_engine, enhance_params, save_images=False): + """ + Process a single PDF file by converting a range of pages, enhancing images, + and attempting OCR detections. A PDF is considered successful if at least one page + yields two consecutive rows detected. Returns the (row1, row2) pair on success. + + Parameters: + pdf_f (str): File path of the PDF. + ocr_engine: The OCR engine instance. + enhance_params (dict): Parameters for image enhancement. + save_images (bool): If True, save intermediate enhanced images (default: False). + + Returns: + (pdf_success, detected_rows): + pdf_success (bool): True if detection succeeded in any page. + detected_rows (tuple): (row1, row2) from the successful page, or (None, None) if not. + """ + images = convert_from_path(pdf_f, dpi=300, first_page=1, last_page=3) + bs_name = os.path.basename(pdf_f) + bs_name_0 = os.path.splitext(bs_name)[0] + + pdf_success = False + detected_rows = (None, None) + + for i, pil_image in enumerate(images): + # Convert the PIL image to a NumPy array. + img = np.array(pil_image) +# print(f"Processing page {i + 1} of {bs_name}") + + # Enhance the image. + img = enhance_image(img, enhance_params, verbose=False) + img = np.uint8(img * 255.) + + # Optionally save the enhanced image. + if save_images: + enhanced_img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + cv2.imwrite(f'{bs_name_0}_{i + 1}.jpg', enhanced_img_bgr) + + # Attempt OCR on the enhanced image (with rotations). + proc_img, detected, row1, row2 = rotate_until_detect(img, ocr_engine) + if detected: +# print(f"OCR detection succeeded on page {i + 1} of {bs_name}") + pdf_success = True + detected_rows = (row1, row2) + break + else: + # Fallback: perform adaptive thresholding then try OCR. +# print(f"No detection on page {i + 1} of {bs_name}. Trying adaptive thresholding.") + adaptive_img = adaptive_threshold_to_rgb(img) + proc_img, detected, row1, row2 = rotate_until_detect(adaptive_img, ocr_engine) + if detected: +# print(f"OCR detection (via adaptive thresholding) succeeded on page {i + 1} of {bs_name}") + pdf_success = True + detected_rows = (row1, row2) + break + else: + print(f"OCR detection failed on page {i + 1} of {bs_name}.") + + if pdf_success: + print(f"PDF file {bs_name_0} processed successfully.") + else: + print(f"PDF file {bs_name_0} did NOT yield a successful OCR detection.") + + return pdf_success, detected_rows + + +def main(): + # Define the folder containing PDFs. + dataPath = '/home/tung/Tung_Works/OCR_code/OCR-20250423T073748Z-001/OCR/OCR辨識失敗-部分樣本' + list_pdf = [ + os.path.join(root, file) + for root, _, files in os.walk(dataPath) + for file in files if file.endswith('.pdf') + ] + + # Define image enhancement parameters. 
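+    # (Same values as the example_enhance_image.py demo shipped with the
+    # image_enhancement library in this repo.)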
+ enhance_params = { + 'local_contrast': 1.2, # 1.2x increase in detail + 'mid_tones': 0.5, # middle range + 'tonal_width': 0.5, # middle range + 'areas_dark': 0.7, # 70% improvement in dark areas + 'areas_bright': 0.5, # 50% improvement in bright areas + 'brightness': 0.1, # slight increase in overall brightness + 'saturation_degree': 1.2, # 1.2x increase in color saturation + 'preserve_tones': True, + 'color_correction': True, + } + + # Process each PDF and collect results. + for pdf_f in list_pdf: + print("") + print(f"--- Processing PDF: {pdf_f} ---") + success, detected_rows = process_pdf(pdf_f, ocr_engine, enhance_params, save_images=False) + + if success: +# print("\nSuccess in detecting two rows for this PDF:") + print("PDF:", os.path.basename(pdf_f)) + print("Row 1:", detected_rows[0]) + print("Row 2:", detected_rows[1]) + else: + print("No successful detection for this PDF.") + +if __name__ == '__main__': + main() diff --git a/demo_v5.py b/demo_v5.py new file mode 100644 index 0000000000000000000000000000000000000000..e95af9334aa08b74372e7087818f9cc20c0b9b04 --- /dev/null +++ b/demo_v5.py @@ -0,0 +1,195 @@ +import os +import cv2 +import numpy as np +from pdf2image import convert_from_path + +from main import RapidOCR +from image_enhancement import enhance_image +import gradio as gr +import time +# Initialize OCR engine once. +ocr_engine = RapidOCR() + + +def adaptive_threshold_to_rgb(image_rgb): + """ + Convert an RGB image to LAB, apply adaptive thresholding only on the L channel, + then convert back to RGB. + + Parameters: + image_rgb (numpy.ndarray): Input RGB image. + + Returns: + thresholded_rgb (numpy.ndarray): RGB image after thresholding the L channel. + """ + image_lab = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2LAB) + l_channel, a_channel, b_channel = cv2.split(image_lab) + thresholded_l = cv2.adaptiveThreshold( + l_channel, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2 + ) + updated_lab = cv2.merge((thresholded_l, a_channel, b_channel)) + return cv2.cvtColor(updated_lab, cv2.COLOR_LAB2RGB) + + +def ocr_detect(image, ocr_engine): + """ + Run OCR on the image and check for two consecutive rows that contain the '<' character. + + Parameters: + image (numpy.ndarray): Input image. + ocr_engine: OCR engine instance. + + Returns: + detected (bool): True if found, else False. + row1 (str): The first detected row with '<'. + row2 (str): The second detected row with '<'. + """ + result, _ = ocr_engine(image, use_det=True, use_cls=False, use_rec=True) + if result: + test_list = [r[1] for r in result] + for j in range(len(test_list) - 1): + count1 = test_list[j].count("<") + count2 = test_list[j + 1].count("<") + if count1 > 1 and count2 > 1: + return True, test_list[j], test_list[j + 1] + return False, None, None + + +def rotate_until_detect(image, ocr_engine, max_attempts=4): + """ + Rotate the image 90° clockwise up to max_attempts times until OCR returns + two consecutive rows that meet the specified criteria. + + Parameters: + image (numpy.ndarray): Input image. + ocr_engine: OCR engine instance. + max_attempts (int): Maximum number of rotations. + + Returns: + image (numpy.ndarray): Final rotated image. + detected (bool): True if OCR detection succeeded. + row1, row2 (str, str): The two detected rows (if found; otherwise None). 
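+
+    Usage sketch (names as defined in this module):
+        img, ok, row1, row2 = rotate_until_detect(img, ocr_engine)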
+ """ + for attempt in range(max_attempts): + detected, row1, row2 = ocr_detect(image, ocr_engine) + if detected: + return image, True, row1, row2 + image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE) + return image, False, None, None + + +def process_pdf(pdf_f, ocr_engine, enhance_params): + """ + Process a single PDF file by converting pages, enhancing images, + and attempting OCR detections. A PDF is considered successful if at least one page + yields two consecutive rows detected. Returns the (row1, row2) pair on success. + + Parameters: + pdf_f (str): File path of the PDF. + ocr_engine: The OCR engine instance. + enhance_params (dict): Parameters for image enhancement. + + Returns: + (pdf_success, detected_rows): + pdf_success (bool): True if detection succeeded in any page. + detected_rows (tuple): (row1, row2) from the successful page, or (None, None) if not. + """ + images = convert_from_path(pdf_f, dpi=300, first_page=1, last_page=3) + pdf_success = False + detected_rows = (None, None) + + for pil_image in images: + img = np.array(pil_image) + img = enhance_image(img, enhance_params, verbose=False) + img = np.uint8(img * 255.) + _, detected, row1, row2 = rotate_until_detect(img, ocr_engine) + if detected: + pdf_success = True + detected_rows = (row1, row2) + break + else: + adaptive_img = adaptive_threshold_to_rgb(img) + _, detected, row1, row2 = rotate_until_detect(adaptive_img, ocr_engine) + if detected: + pdf_success = True + detected_rows = (row1, row2) + break + + return pdf_success, detected_rows + + +# def main(): +# # Define the folder containing PDFs. +# # dataPath = '/home/tung/Tung_Works/OCR_code/OCR-20250423T073748Z-001/OCR/OCR辨識失敗-部分樣本' +# dataPath = 'C:/Users/Duy/Downloads/passport/' +# result_file = os.path.join(dataPath,'results.txt') + +# list_pdf = [ +# os.path.join(root, file) +# for root, _, files in os.walk(dataPath) +# for file in files if file.endswith('.pdf') +# ] + +# enhance_params = { +# 'local_contrast': 1.2, 'mid_tones': 0.5, 'tonal_width': 0.5, 'areas_dark': 0.7, +# 'areas_bright': 0.5, 'brightness': 0.1, 'saturation_degree': 1.2, +# 'preserve_tones': True, 'color_correction': True, +# } + +# # Open the result file for writing +# with open(result_file, 'w') as f: +# for pdf_f in list_pdf: +# pdf_name = os.path.basename(pdf_f) +# print(f"Processing {pdf_f}...") +# success, detected_rows = process_pdf(pdf_f, ocr_engine, enhance_params) + +# if success: +# f.write(f"--- PDF: {pdf_name} ---\n") +# f.write("Success\n") +# f.write(f"Row 1: {detected_rows[0]}\n") +# f.write(f"Row 2: {detected_rows[1]}\n\n") +# print(f"Success: {pdf_name}") +# print("Row 1:", detected_rows[0]) +# print("Row 2:", detected_rows[1]) +# else: +# f.write(f"--- PDF: {pdf_name} ---\n") +# f.write("No successful detection\n\n") +# print(f"No detection: {pdf_name}") + +# print(f"Results written to {result_file}") + +def handle_file_upload(file_bytes): + enhance_params = { + 'local_contrast': 1.2, 'mid_tones': 0.5, 'tonal_width': 0.5, 'areas_dark': 0.7, + 'areas_bright': 0.5, 'brightness': 0.1, 'saturation_degree': 1.2, + 'preserve_tones': True, 'color_correction': True, + } + # print(f"Processing uploaded file: {file_path}") + current_dir = os.path.dirname(os.path.abspath(__file__)) + + # 2. Tạo thư mục tmp nếu chưa tồn tại + tmp_dir = os.path.join(current_dir, "tmp") + os.makedirs(tmp_dir, exist_ok=True) + timestamp = int(time.time()) + save_path = os.path.join(tmp_dir, f"uploaded_{timestamp}.pdf") + # 4. 
Save binary thành file PDF + with open(save_path, "wb") as f: + f.write(file_bytes) + + pdf_success, detected_rows = process_pdf(save_path, ocr_engine, enhance_params) + return detected_rows if pdf_success else ("Error", "Error") + +if __name__ == '__main__': + demo = gr.Interface( + fn=handle_file_upload, + inputs=gr.File(type="binary", file_types=[".pdf"], label="Select your PDF"), + outputs=[ + gr.Textbox(label="Row 1"), + gr.Textbox(label="Row 2"), + ], + title="PDF Information Extractor", + description="Upload a PDF file to get basic information.", + allow_flagging="never" + ) + + demo.launch(share=True) diff --git a/example_adjust_image_brightness.py b/example_adjust_image_brightness.py new file mode 100644 index 0000000000000000000000000000000000000000..53a967e18eb7a0641b34672f04a619375fe1fe6e --- /dev/null +++ b/example_adjust_image_brightness.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Wed Jul 29 11:35:49 2020 + +Example: image enhancement +(spatial tone-mapping, local contrast enhancement, color enhancement) + +@author: Vasileios Vonikakis (bbonik@gmail.com) +""" + + + +import imageio +import matplotlib.pyplot as plt +from image_enhancement import enhance_image + + + +if __name__=="__main__": + + # select an image + filename = "../images/lisbon.jpg" + + image = imageio.imread(filename) # load image + + # setting up parameters + parameters = {} + parameters['local_contrast'] = 1.0 # no increase in details + parameters['mid_tones'] = 0.5 # middle of range + parameters['tonal_width'] = 0.5 # middle of range + parameters['areas_dark'] = 0.0 # no change in dark areas + parameters['areas_bright'] = 0.0 # no change in bright areas + parameters['saturation_degree'] = 1.0 # no change in color saturation + parameters['brightness'] = 0.5 # increase overall brightness by 50% + parameters['preserve_tones'] = True + parameters['color_correction'] = True + image_enhanced = enhance_image(image, parameters, verbose=False) + + # display results + plt.figure(figsize=(7,3)) + plt.subplot(1,2,1) + plt.imshow(image, vmin=0, vmax=255) + plt.title('Input image') + plt.axis('off') + plt.tight_layout() + + plt.subplot(1,2,2) + plt.imshow(image_enhanced, vmin=0, vmax=255) + plt.title('Enhanced image') + plt.axis('off') + plt.tight_layout() + + plt.show() + diff --git a/example_blend_exposures.py b/example_blend_exposures.py new file mode 100644 index 0000000000000000000000000000000000000000..6d8df1a3e94a9c6148f8803169e4d71967d5c18f --- /dev/null +++ b/example_blend_exposures.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Wed Jul 29 11:35:49 2020 + +Example: image enhancement +(spatial tone-mapping, local contrast enhancement, color enhancement) + +@author: Vasileios Vonikakis (bbonik@gmail.com) +""" + + +import glob +import imageio +from image_enhancement import blend_expoures + + + +if __name__=="__main__": + + # select a collection of image exposures + + # exposure_filenames = glob.glob('../images/exposures_A*.jpg') + exposure_filenames = glob.glob('../images/exposures_B*.jpg') + + # put the exposures in a list + image_list = [] + for filename in exposure_filenames: + image = imageio.imread(filename) + image_list.append(image) + + # blend exposures + exposure_blend = blend_expoures( + image_list, + threshold_dark=0.35, + threshold_bright=0.65, + verbose=True + ) + + diff --git a/example_color_correction.py b/example_color_correction.py new file mode 100644 index 
0000000000000000000000000000000000000000..f5f4ed708e94203549080cb027b6d5fa61b300f4 --- /dev/null +++ b/example_color_correction.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Wed Jul 29 11:35:49 2020 + +Example: color correction (white balance) + +@author: Vasileios Vonikakis (bbonik@gmail.com) +""" + + + +import imageio +import matplotlib.pyplot as plt +from image_enhancement import correct_colors + + + +if __name__=="__main__": + + + # select an image + filename = "../images/strawberries.jpg" + # filename = "../images/napoleon.jpg" + # filename = "../images/shark.jpg" + # filename = "../images/underwater1.jpg" + # filename = "../images/underwater2.jpg" + + image = imageio.imread(filename) # load image + + image_enhanced = correct_colors(image, verbose=False) + + # display results + plt.figure(figsize=(7,3)) + plt.subplot(1,2,1) + plt.imshow(image, vmin=0, vmax=255) + plt.title('Input image') + plt.axis('off') + plt.tight_layout() + + plt.subplot(1,2,2) + plt.imshow(image_enhanced, vmin=0, vmax=1) + plt.title('Corrected colors') + plt.axis('off') + plt.tight_layout() + + plt.show() + diff --git a/example_enhance_image.py b/example_enhance_image.py new file mode 100644 index 0000000000000000000000000000000000000000..1d40923aec18ccf9b832bce995a45363e6dd2b78 --- /dev/null +++ b/example_enhance_image.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Wed Jul 29 11:35:49 2020 + +Example: image enhancement +(spatial tone-mapping, local contrast enhancement, color enhancement) + +@author: Vasileios Vonikakis (bbonik@gmail.com) +""" + + + +import imageio +import matplotlib.pyplot as plt +from image_enhancement import enhance_image + + + +if __name__=="__main__": + + # select an image + filename = "../images/alhambra1.jpg" + # filename = "../images/alhambra2.jpg" + # filename = "../images/lisbon.jpg" + + image = imageio.imread(filename) # load image + + + # setting up parameters + parameters = {} + parameters['local_contrast'] = 1.2 # 1.2x increase in details + parameters['mid_tones'] = 0.5 # middle of range + parameters['tonal_width'] = 0.5 # middle of range + parameters['areas_dark'] = 0.7 # 70% improvement in dark areas + parameters['areas_bright'] = 0.5 # 50% improvement in bright areas + parameters['brightness'] = 0.1 # slight increase in overall brightness + parameters['saturation_degree'] = 1.2 # 1.2x increase in color saturation + parameters['preserve_tones'] = True + parameters['color_correction'] = True + image_enhanced = enhance_image(image, parameters, verbose=False) + + # display results + plt.figure(figsize=(7,3)) + plt.subplot(1,2,1) + plt.imshow(image, vmin=0, vmax=255) + plt.title('Input image') + plt.axis('off') + + plt.subplot(1,2,2) + plt.imshow(image_enhanced, vmin=0, vmax=255) + plt.title('Enhanced image') + plt.axis('off') + plt.tight_layout() + + plt.show() + diff --git a/example_local_contrast_enhancement.py b/example_local_contrast_enhancement.py new file mode 100644 index 0000000000000000000000000000000000000000..b3ae9be06e70ca6cfb237a3b3960ddb6e8023890 --- /dev/null +++ b/example_local_contrast_enhancement.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Wed Jul 29 11:35:49 2020 + +Example: enhancement of local details + +@author: Vasileios Vonikakis (bbonik@gmail.com) +""" + + + +import imageio +import matplotlib.pyplot as plt +from image_enhancement import enhance_image + + + +if __name__=="__main__": + + # select an image + filename = "../images/waves.jpg" + 
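+    # (This demo pushes local contrast to 4x with tone-mapping disabled; the
+    # enhanced result is therefore displayed with vmin=0, vmax=1 further down.)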
+ image = imageio.imread(filename) # load image + + # setting up parameters + parameters = {} + parameters['local_contrast'] = 4 # 4x increase in details + parameters['mid_tones'] = 0.5 + parameters['tonal_width'] = 0.5 + parameters['areas_dark'] = 0.0 # no change in dark areas + parameters['areas_bright'] = 0.0 # no change in bright areas + parameters['saturation_degree'] = 2 # 2x increase in color saturation + parameters['brightness'] = 0.0 # no change in brightness + parameters['preserve_tones'] = False + parameters['color_correction'] = False + image_enhanced = enhance_image(image, parameters, verbose=False) + + # display results + plt.figure(figsize=(7,3)) + plt.subplot(1,2,1) + plt.imshow(image, vmin=0, vmax=255) + plt.title('Input image') + plt.axis('off') + plt.tight_layout() + + plt.subplot(1,2,2) + plt.imshow(image_enhanced, cmap='gray', vmin=0, vmax=1) + plt.title('Increased local contrast') + plt.axis('off') + plt.tight_layout() + + plt.show() + diff --git a/example_medical_image.py b/example_medical_image.py new file mode 100644 index 0000000000000000000000000000000000000000..2061559e8fd07d5a18dc5facf7fb40c01aaaa518 --- /dev/null +++ b/example_medical_image.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Wed Jul 29 11:35:49 2020 + +Example: enhancement of local details in medical image + +@author: Vasileios Vonikakis (bbonik@gmail.com) +""" + + + +import imageio +import matplotlib.pyplot as plt +from skimage import img_as_float +from skimage.color import rgb2gray +from image_enhancement import get_photometric_mask +from image_enhancement import apply_spatial_tonemapping +from image_enhancement import apply_local_contrast_enhancement + + +if __name__=="__main__": + + # select an image + filename = "../images/xray.jpg" + + image = imageio.imread(filename) # load image + + # grayscale and float + image = rgb2gray(image) + image = img_as_float(image) + + # get estimation of the local neighborhood + image_ph_mask = get_photometric_mask( + image=image, + verbose=False + ) + + # increase the local contrast of the grayscale image + image_contrast = apply_local_contrast_enhancement( + image=image, + image_ph_mask=image_ph_mask, + degree=2, # x2 local details + verbose=False + ) + + # apply spatial tonemapping on the previous stage + image_tonemapped = apply_spatial_tonemapping( + image=image_contrast, + image_ph_mask=image_ph_mask, + mid_tone=0.5, + tonal_width=0.5, + areas_dark=0.0, # no improvement in dark areas + areas_bright=0.8, # strong improvement in bright areas + preserve_tones=False, + verbose=False + ) + + # display results + plt.figure(figsize=(7,3)) + plt.subplot(1,2,1) + plt.imshow(image, cmap='gray', vmin=0, vmax=1) + plt.title('Input image') + plt.axis('off') + plt.tight_layout() + + plt.subplot(1,2,2) + plt.imshow(image_tonemapped, cmap='gray', vmin=0, vmax=1) + plt.title('Enhanced image') + plt.axis('off') + plt.tight_layout() + + plt.show() + diff --git a/full_code.py b/full_code.py new file mode 100644 index 0000000000000000000000000000000000000000..ee591b3d7777baa592bfd1127d37c5438578e6da --- /dev/null +++ b/full_code.py @@ -0,0 +1,15 @@ +import os +from pdf2image import convert_from_path + +# Path to your PDF file +pdf_path = 'example.pdf' + +# Convert PDF to images +images = convert_from_path(pdf_path, dpi=300) + +# Save each page as an image +for i, image in enumerate(images): + image.save(f'page_{i + 1}.jpg', 'JPEG') + + + diff --git a/image_enhancement-master.zip b/image_enhancement-master.zip new file mode 100644 index 
0000000000000000000000000000000000000000..6a6a22a6095cbc2266bb7f09a2170199273649e8 --- /dev/null +++ b/image_enhancement-master.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4042829b8a4bd01b5fe3ab7d4e92be032e53bf40d6ef4554fdbc471d8ba134b8 +size 26840133 diff --git a/image_enhancement.py b/image_enhancement.py new file mode 100644 index 0000000000000000000000000000000000000000..d820ad3e8b9d0144964ee9c8d41a16267daded31 --- /dev/null +++ b/image_enhancement.py @@ -0,0 +1,1936 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Fri Jun 19 09:40:58 2020 + +Image enhancement functions + +@author: Vasileios Vonikakis (bbonik@gmail.com) +""" + +import math +import imageio +import numpy as np +import matplotlib.pyplot as plt +from skimage.color import rgb2gray +from skimage import img_as_float +from skimage.exposure import rescale_intensity, adjust_gamma + + +plt.close('all') + +#TODO: better memory management!!!! Too many copying of images. +#something like "inplace"? + + +def map_value( + value, + range_in=(0,1), + range_out=(0,1), + invert=False, + non_lin_convex=None, + non_lin_concave=None): + + ''' + --------------------------------------------------------------------------- + Map a scalar value to an output range in a linear/non-linear way + --------------------------------------------------------------------------- + + Map scalar values to a particular range, in a linear or non-linear way. + This can be helpful for adjusting the range and nonlinear response of + parameters. + + For more info on the non-linear functions check: + Vonikakis, V., Winkler, S. (2016). A center-surround framework for spatial + image processing. Proc. IS&T Human Vision & Electronic Imaging. + + + INPUTS + ------ + value: float + Input value to be mapped. + range_in: tuple (min,max) + Range of input value. The min and max values that the input value can + attain. + range_out: tuple (min,max) + Range of output value. The min and max values that the mapped input + value can attain. + invert: Bool + Invert or not the input value. If invert, then min->max and max->min. + non_lin_convex: None or float (0,inf) + If None, no non-linearity is applied. If float, then a convex + non-linearity is applied, which lowers the values, while not affecting + the min and max. non_lin_convex controls the steepness of the + non-linear mapping. Small values near zero, result in a steeper curve. + non_lin_concave: None or float (0,inf) + If None, no non-linearity is applied. If float, then a concave + non-linearity is applied, which increases the values, while not + affecting min and max. non_lin_concave controls the steepness of the + non-linear mapping. Small values near zero, result in a steeper curve. 
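+
+    As implemented below, with v already normalized to [0,1]:
+        convex  (k = non_lin_convex):   v' = v * k / (1 + k - v)
+        concave (k = non_lin_concave):  v' = (1 + k) * v / (k + v)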
+
+    OUTPUT
+    ------
+    Mapped value
+
+    '''
+
+    # truncate value to within input range limits
+    if value > range_in[1]: value = range_in[1]
+    if value < range_in[0]: value = range_in[0]
+
+    # map values linearly to [0,1]
+    value = (value - range_in[0]) / (range_in[1] - range_in[0])
+
+    # invert values
+    if invert is True: value = 1 - value
+
+    # apply convex non-linearity
+    if non_lin_convex is not None:
+        value = (value * non_lin_convex) / (1 + non_lin_convex - value)
+
+    # apply concave non-linearity
+    if non_lin_concave is not None:
+        value = ((1 + non_lin_concave) * value) / (non_lin_concave + value)
+
+    # mapping value to the output range in a linear way
+    value = value * (range_out[1] - range_out[0]) + range_out[0]
+
+    return value
+
+
+
+
+
+def get_membership_luts(
+        resolution=256,
+        lower_threshold=0.35,
+        upper_threshold=0.65,
+        verbose=False):
+
+    '''
+    ---------------------------------------------------------------------------
+    Creates 3 parametric trapezoid membership functions
+    ---------------------------------------------------------------------------
+
+    The trapezoid functions are defined as piece-wise functions between
+    0, lower_threshold, upper_threshold, and 1. These trapezoid membership
+    functions can be used to select which parts of each exposure are used
+    during exposure fusion. More details can be found in the following
+    paper:
+
+    Vonikakis, V., Bouzos, O. & Andreadis, I. (2011). Multi-Exposure Image
+    Fusion Based on Illumination Estimation, SIPA2011 (pp.135-142), Greece.
+
+
+    INPUTS
+    ------
+    resolution: int
+        The size of the LUT (how many inputs).
+    lower_threshold: float in the range [0,1]
+        The position of the lower inflection point of the trapezoid functions.
+        It should always be lower than the upper_threshold.
+    upper_threshold: float in the range [0,1]
+        The position of the upper inflection point of the trapezoid functions.
+        It should always be higher than the lower_threshold.
+    verbose: boolean
+        Display outputs.
+
+    OUTPUT
+    ------
+    lut_lower: float numpy array of size equal to resolution, values in [0,1]
+        The lower trapezoid membership function.
+    lut_mid: float numpy array of size equal to resolution, values in [0,1]
+        The middle trapezoid membership function.
+    lut_upper: float numpy array of size equal to resolution, values in [0,1]
+        The upper trapezoid membership function.
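+
+    Example (lower_threshold=0.35, upper_threshold=0.65):
+        lut_lower ramps 0 -> 1 over [0, 0.35], then stays at 1;
+        lut_mid ramps up over [0, 0.35], is 1 on [0.35, 0.65], falls to 0 at 1;
+        lut_upper is 1 up to 0.65, then falls linearly to 0 at 1.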
+
+    '''
+
+
+    lut_lower = np.zeros(resolution, dtype='float')
+    lut_mid = np.zeros(resolution, dtype='float')
+    lut_upper = np.zeros(resolution, dtype='float')
+
+    for i in range(resolution):
+
+        i_float = i / (resolution - 1)
+
+        # lower trapezoid membership function
+        if i_float <= lower_threshold:
+            lut_lower[i] = i_float / lower_threshold
+        else:
+            lut_lower[i] = 1
+
+        # middle trapezoid membership function
+        if i_float <= lower_threshold:
+            lut_mid[i] = i_float / lower_threshold
+        elif i_float <= upper_threshold:
+            lut_mid[i] = 1
+        else:
+            lut_mid[i] = (1 - i_float) / (1 - upper_threshold)
+
+        # upper trapezoid membership function
+        if i_float <= upper_threshold:
+            lut_upper[i] = 1
+        else:
+            lut_upper[i] = (1 - i_float) / (1 - upper_threshold)
+
+
+    if verbose is True:
+        plt.figure()
+
+        plt.subplot(1,3,1)
+        plt.plot(lut_lower)
+        plt.title('Lower')
+        plt.grid(True)
+
+        plt.subplot(1,3,2)
+        plt.plot(lut_mid)
+        plt.title('Middle')
+        plt.grid(True)
+
+        plt.subplot(1,3,3)
+        plt.plot(lut_upper)
+        plt.title('Upper')
+        plt.grid(True)
+
+        plt.suptitle('Trapezoid membership functions')
+        plt.show()
+
+    return lut_lower, lut_mid, lut_upper
+
+
+
+
+
+
+
+def get_sigmoid_lut(
+        resolution=256,
+        threshold=0.2,
+        non_linearirty=0.2,
+        verbose=False):
+
+    '''
+    ---------------------------------------------------------------------------
+    Creates a parametric sigmoid function and stores it in a LUT
+    ---------------------------------------------------------------------------
+
+    The sigmoid function is defined as a piece-wise function of 2 inverse
+    non-linearities. This allows full control of the inflection point
+    (threshold) and the degree of 'sharpness' of each non-linearity. The
+    non-linear curves used here are described in the paper:
+    Vonikakis, V., Winkler, S. (2016). A center-surround framework for spatial
+    image processing. Proc. IS&T Human Vision & Electronic Imaging.
+
+
+    INPUTS
+    ------
+    resolution: int
+        The size of the LUT (how many inputs).
+    threshold: float in the range [0,1]
+        The position of the inflection point of the sigmoid function (0.5 is
+        the middle of the range).
+    non_linearirty: float in range (0, inf)
+        Controls the non-linearity of the curve before and after the inflection
+        point. It should not be 0. The smaller it is (asymptotically to 0) the
+        'sharper' the non-linearity. After ~5 it asymptotically approaches
+        linearity.
+    verbose: boolean
+        Display outputs.
+
+    OUTPUT
+    ------
+    lut: float numpy array of size equal to resolution
+        The output sigmoid lut.
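+
+    With t = threshold*(resolution-1), a = non_linearirty*(resolution-1) and
+    b = (resolution-1) - t, the piece-wise form implemented below is:
+        i >= t:  lut[i] = ((a + b)*(i - t) / (a + (i - t))) / (2*b) + 0.5
+        i <  t:  lut[i] = (a*i / (a - (i - t))) / (2*t)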
+ + ''' + + max_value = resolution - 1 # the maximum attainable value + thr = threshold * max_value # threshold in the range [0,resolution-1] + alpha = non_linearirty * max_value # controls non-linearity degree + beta = max_value - thr + if beta == 0: beta = 0.001 + + lut = np.zeros(resolution, dtype='float') + + for i in range(resolution): + + i_comp = i - thr # complement of i + + # upper part of the piece-wise sigmoid function + if i >= thr: + lut[i] = (((((alpha + beta) * i_comp) / (alpha + i_comp)) * + (1 / (2 * beta))) + 0.5) + + # lower part of the piece-wise sigmoid function + else: + lut[i] = (alpha * i) / (alpha - i_comp) * (1 / (2 * thr)) + + if verbose is True: + plt.figure() + plt.plot(lut) + plt.title('Sigmoid LUT | ' + + 'thr=' + str(int(thr)) + ' (' + str(round(threshold, 3)) + + ') | nonlin=' + str(int(alpha)) + + ' (' + str(round(non_linearirty, 3)) + ')') + plt.grid(True) + plt.tight_layout() + plt.show() + + return lut + + + +def get_photometric_mask( + image, + smoothing=0.2, + grayscale_out=True, + verbose=False): + ''' + --------------------------------------------------------------------------- + Estimate the photometric mask of an image by using edge-aware blurring + --------------------------------------------------------------------------- + + Applies strong blurring while preserving the strong edges of the image in + order to avoid halo artifacts. Inspired by the paper: + Shaked, Doron & Keshet, Renato. (2004). "Robust Recursive Envelope + Operators for Fast Retinex." + + + INPUTS + ------ + image: numpy array (WxH or WxHxK of uint8 [0.255] or float [0,1]) + Input image. + smoothing: float in the interval [0,1] + Value controlling the blur's strenght. 0 indicates no blur. Values + between 0-1 increase blurring strength while preserving edges. Values + above 1 approximate very strong gaussian blurring (large sigmas) where + no edges are preserved. Practically, values above 10 result into a + uniform image. + grayscale_out: logical + Whether or not the photometric mask is going to be grayscale or not. + If the input image is already grayscale (2D) then this parameter is + irrelevant. + verbose: boolean + Display outputs. + + OUTPUT + ------ + image_ph_mask: numpy array of WxH or WxHxK of float [0,1] + Photometric mask of the input image. + + ''' + + + ''' + Intuition about the threshold and non_linearirty values of the LUTs + threshold: + The larger it is, the stronger the blurring, the better the local + contrast but also more halo artifacts (less edge preservation). + non_linearirty: + The lower it is, the more it preserves the edges, but also has more + 'bleeding' effects. 
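+
+    Mechanically, the code below makes five edge-aware recursive sweeps over
+    the image (top-down, left-right, bottom-up, right-left, top-down again).
+    In each sweep a pixel is blended with its already-processed neighbour,
+    weighted by a sigmoid LUT response to the local gradient, so averaging
+    stops at strong edges instead of bleeding across them.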
+ ''' + + + # internal parameters + THR_A = smoothing + THR_B = 0.04 # ~10/255 + NON_LIN = 0.12 # ~30/255 + LUT_RES = 256 + + # get sigmoid LUTs + lut_a = get_sigmoid_lut( + resolution=LUT_RES, + threshold=THR_A, + non_linearirty=NON_LIN, + verbose=verbose + ) + lut_a_max = len(lut_a) -1 + lut_b = get_sigmoid_lut( + resolution=LUT_RES, + threshold=THR_B, + non_linearirty=NON_LIN, + verbose=verbose + ) + lut_b_max = len(lut_b) -1 + + + # dealing with different number of channels + if len(image.shape) == 3: + if grayscale_out is True: + image_ph_mask = rgb2gray(image.copy()) # [0,1] 2D + else: + image_ph_mask = img_as_float(image.copy()) # [0,1] 3D + elif len(image.shape) == 2: + image_ph_mask = img_as_float(image.copy()) # [0,1] 2D + else: + image_ph_mask = img_as_float(image.copy()) # [0,1] ?D + + # if image is 2D, expand dimensions to 3D for code compatibility + # (filtering assumes a 3D image) + if len(image_ph_mask.shape) == 2: + image_ph_mask = np.expand_dims(image_ph_mask, axis=2) + + + # robust recursive envelope + + # up -> down + for i in range(1, image_ph_mask.shape[0]-1): + d = np.abs(image_ph_mask[i-1,:,:] - image_ph_mask[i+1,:,:]) # diff + d = lut_a[(d * lut_a_max).astype(int)] + image_ph_mask[i,:,:] = ((image_ph_mask[i,:,:] * d) + + (image_ph_mask[i-1,:,:] * (1-d))) + + # left -> right + for j in range(1, image_ph_mask.shape[1]-1): + d = np.abs(image_ph_mask[:,j-1,:] - image_ph_mask[:,j+1,:]) # diff + d = lut_a[(d * lut_a_max).astype(int)] + image_ph_mask[:,j,:] = ((image_ph_mask[:,j,:] * d) + + (image_ph_mask[:,j-1,:] * (1-d))) + + # down -> up + for i in range(image_ph_mask.shape[0]-2, 1, -1): + d = np.abs(image_ph_mask[i-1,:,:] - image_ph_mask[i+1,:,:]) # diff + d = lut_a[(d * lut_a_max).astype(int)] + image_ph_mask[i,:,:] = ((image_ph_mask[i,:,:] * d) + + (image_ph_mask[i+1,:,:] * (1-d))) + + # right -> left + for j in range(image_ph_mask.shape[1]-2, 1, -1): + d = np.abs(image_ph_mask[:,j-1,:] - image_ph_mask[:,j+1,:]) # diff + d = lut_b[(d * lut_b_max).astype(int)] + image_ph_mask[:,j,:] = ((image_ph_mask[:,j,:] * d) + + (image_ph_mask[:,j+1,:] * (1-d))) + + # up -> down + for i in range(1, image_ph_mask.shape[0]-1): + d = np.abs(image_ph_mask[i-1,:,:] - image_ph_mask[i+1,:,:]) # diff + d = lut_b[(d * lut_b_max).astype(int)] + image_ph_mask[i,:,:] = ((image_ph_mask[i,:,:] * d) + + (image_ph_mask[i-1,:,:] * (1-d))) + + + # convert back to 2D if grayscale is needed + if grayscale_out is True: + image_ph_mask = np.squeeze(image_ph_mask) + + + if verbose is True: + + plt.figure() + plt.subplot(1,2,1) + plt.imshow(image) + plt.title('Input image') + plt.axis('off') + + plt.subplot(1,2,2) + if grayscale_out is True: + plt.imshow(image_ph_mask, cmap='gray', vmin=0, vmax=1) + else: + plt.imshow(image_ph_mask, vmin=0, vmax=1) + plt.title('Photometric mask') + plt.axis('off') + + plt.tight_layout(True) + plt.suptitle('Estimation of photometric mask') + plt.show() + + + return image_ph_mask + + + + + + +def blend_expoures( + exposure_list, + threshold_dark=0.35, + threshold_bright=0.65, + verbose=False + ): + + ''' + --------------------------------------------------------------------------- + Blend a collection of exposures to a single image + --------------------------------------------------------------------------- + + Function to blend a list of image exposures, using illumination estimation + across 2 spatial scales. + + Based on the following paper: + Vonikakis, V., Bouzos, O. & Andreadis, I. (2011). 
Multi-Exposure Image + Fusion Based on Illumination Estimation, SIPA2011 (pp.135-142), Greece. + + + INPUTS + ------ + exposure_list: list of numpy image arrays + List of numpy arrays (image exposures) which will be blended. Arrays + can be either grayscale, or color (3 channels). + threshold_dark: float in the interval [0,1] + Lower threshold for the membership function which will be applied to + the brightest exposure (long exposure). See above paper for more info. + threshold_dark < threshold_bright + threshold_bright: float in the interval [0,1] + Higher threshold for the membership function which will be applied to + the darkest exposure (short exposure). See above paper for more info. + threshold_bright > threshold_dark + verbose: boolean + Display outputs. + + OUTPUT + ------ + exposure_out: numpy array, float [0,1] + Output image of the blended exposures. If input images are grayscale, + exposure_out is also grayscale. If input images are color, then + exposure_out is also color. + + ''' + + # internal constants + SCALE_COARSE = 0.6 # [0,1], 0->fine, 1->coarse + SCALE_FINE = 0.2 # [0,1], 0->fine, 1->coarse + LUMINANCE_MIDDLE = 0.5 # middle of the luminance scale in [0,1] + GAMA_MAX = 2 # max gama to be used for darkening images + GAMA_MIN = 0.2 # min gama to be used for brightening images + LUT_RESOLUTION = 256 + + total_exposures = len(exposure_list) + + # color or grayscale + if len(exposure_list[0].shape) > 2: # check the 1st image of the list + color_exposures = True + else: + color_exposures = False + + + #--- sort exposures from darkest to brightest + + exposure_list_gray = [] + mean_luminance_list = [] + + if color_exposures is True: + exposure_list_red = [] + exposure_list_green = [] + exposure_list_blue = [] + + + for image in exposure_list: + image_gray = rgb2gray(image) + exposure_list_gray.append(image_gray) # grayscale + mean_luminance_list.append(image_gray.mean()) # mean luminance + if color_exposures is True: + exposure_list_red.append(img_as_float(image[:,:,0])) # red + exposure_list_green.append(img_as_float(image[:,:,1])) # green + exposure_list_blue.append(img_as_float(image[:,:,2])) # blue + + + # sort according to mean luminance + indx_lum_ascending = sorted( + range(len(mean_luminance_list)), + key=lambda i: mean_luminance_list[i] + ) + + if verbose is True: + print('Darkest to brightest exposure sequence:', indx_lum_ascending) + + + + # convert into a numpy array of hight x width x number of exposures + # (the 3rd dimension has the separate grayscale or color exposures) + exposure_array_gray = np.array(exposure_list_gray) + exposure_array_gray = np.moveaxis(exposure_array_gray, 0, -1) + exposure_array_gray = exposure_array_gray[:,:,indx_lum_ascending] + + if color_exposures is True: + + exposure_array_red = np.array(exposure_list_red) + exposure_array_red = np.moveaxis(exposure_array_red, 0, -1) + exposure_array_red = exposure_array_red[:,:,indx_lum_ascending] + + exposure_array_green = np.array(exposure_list_green) + exposure_array_green = np.moveaxis(exposure_array_green, 0, -1) + exposure_array_green = exposure_array_green[:,:,indx_lum_ascending] + + exposure_array_blue = np.array(exposure_list_blue) + exposure_array_blue = np.moveaxis(exposure_array_blue, 0, -1) + exposure_array_blue = exposure_array_blue[:,:,indx_lum_ascending] + + + + #--- generate illumination estimation in 2 spatial scales + + illumination_coarse = get_photometric_mask( + exposure_array_gray.copy(), + smoothing=SCALE_COARSE, + grayscale_out=False, # estimaste each channel separately 
+ verbose=False) + + illumination_fine = get_photometric_mask( + exposure_array_gray.copy(), + smoothing=SCALE_FINE, + grayscale_out=False, # estimaste each channel separately + verbose=False) + + + # min max normalization for each exposure. + # make sure that each exposure has a 0 and 1 somewhere + + for i in range(total_exposures): + + illumination_coarse[:,:,i] = rescale_intensity( + illumination_coarse[:,:,i], + in_range='image', + out_range='dtype' + ) + + illumination_fine[:,:,i] = rescale_intensity( + illumination_fine[:,:,i], + in_range='image', + out_range='dtype' + ) + + + #--- Autoadjusting extreme exposures + # (This would be better if done in a data-driven way) + # if darkest exposure is too bright, darken it + # if brightest exposure is too dark, brighten it + + # darkest: if mean_lum>0.5 (too bright) + # scale gamma linearly in the interval [1, GAMA_MAX] + mean_lum = illumination_coarse[:,:,0].mean() + if mean_lum > LUMINANCE_MIDDLE: + gamma_new = map_value( + mean_lum, + range_in=(LUMINANCE_MIDDLE,1), + range_out=(1,GAMA_MAX) + ) + if verbose: + print( + 'Darkest coarse exposure too bright! Applying gamma:', + gamma_new + ) + illumination_coarse[:,:,0] = adjust_gamma( + image = illumination_coarse[:,:,0], + gamma = gamma_new + ) + + mean_lum = illumination_fine[:,:,0].mean() + if mean_lum > LUMINANCE_MIDDLE: + gamma_new = map_value( + mean_lum, + range_in=(LUMINANCE_MIDDLE,1), + range_out=(1,GAMA_MAX) + ) + if verbose: + print( + 'Darkest fine exposure too bright! Applying gamma:', + gamma_new + ) + illumination_fine[:,:,0] = adjust_gamma( + image = illumination_fine[:,:,0], + gamma = gamma_new + ) + + # brightest: if mean_lum<0.5 (too dark) + # scale gamma linearly in the interval [GAMA_MIN, 1] + mean_lum = illumination_coarse[:,:,-1].mean() + if mean_lum < LUMINANCE_MIDDLE: + gamma_new = map_value( + mean_lum, + range_in=(0,LUMINANCE_MIDDLE), + range_out=(GAMA_MIN,1) + ) + if verbose: + print( + 'Brightest coarse exposure too dark! Applying gamma:', + gamma_new + ) + illumination_coarse[:,:,-1] = adjust_gamma( + image = illumination_coarse[:,:,-1], + gamma = gamma_new + ) + + mean_lum = illumination_fine[:,:,-1].mean() + if mean_lum < LUMINANCE_MIDDLE: + gamma_new = map_value( + mean_lum, + range_in=(0,LUMINANCE_MIDDLE), + range_out=(GAMA_MIN,1) + ) + if verbose: + print( + 'Brightest fine exposure too dark! 
+
+    #--- Apply membership functions to illumination to get exposure weights
+
+    # generate membership function LUTs
+    weights_lower, weights_mid, weights_upper = get_membership_luts(
+        resolution=LUT_RESOLUTION,
+        lower_threshold=threshold_dark,  # defines the lower cutoff
+        upper_threshold=threshold_bright,  # defines the upper cutoff
+        verbose=verbose
+        )
+
+    lut_resolution = len(weights_lower) - 1
+
+    weights_coarse = np.zeros(illumination_coarse.shape, dtype=float)
+    weights_coarse[:,:,0] = (weights_lower[(illumination_coarse[:,:,0] *
+                                            lut_resolution).astype(int)])
+    weights_coarse[:,:,1:-1] = (weights_mid[(illumination_coarse[:,:,1:-1] *
+                                             lut_resolution).astype(int)])
+    weights_coarse[:,:,-1] = (weights_upper[(illumination_coarse[:,:,-1] *
+                                             lut_resolution).astype(int)])
+
+    weights_fine = np.zeros(illumination_fine.shape, dtype=float)
+    weights_fine[:,:,0] = (weights_lower[(illumination_fine[:,:,0] *
+                                          lut_resolution).astype(int)])
+    weights_fine[:,:,1:-1] = (weights_mid[(illumination_fine[:,:,1:-1] *
+                                           lut_resolution).astype(int)])
+    weights_fine[:,:,-1] = (weights_upper[(illumination_fine[:,:,-1] *
+                                           lut_resolution).astype(int)])
+
+    #TODO: apply local contrast enhancement to the exposure images, 2 times
+    # (one for each illumination scale)
+
+
+    #--- Weighted average of exposures based on the exposure weights
+
+    # grayscale
+    exposure_coarse = weights_coarse * exposure_array_gray
+    exposure_coarse = (np.sum(exposure_coarse, axis=2) /
+                       np.sum(weights_coarse, axis=2))
+    exposure_fine = weights_fine * exposure_array_gray
+    exposure_fine = (np.sum(exposure_fine, axis=2) /
+                     np.sum(weights_fine, axis=2))
+    exposure_out_gray = (exposure_coarse + exposure_fine) / 2
+    exposure_out = exposure_out_gray
+
+
+    if color_exposures is True:
+
+        # red
+        exposure_coarse_red = weights_coarse * exposure_array_red
+        exposure_coarse_red = (np.sum(exposure_coarse_red, axis=2) /
+                               np.sum(weights_coarse, axis=2))
+        exposure_fine_red = weights_fine * exposure_array_red
+        exposure_fine_red = (np.sum(exposure_fine_red, axis=2) /
+                             np.sum(weights_fine, axis=2))
+        exposure_out_red = (exposure_coarse_red + exposure_fine_red) / 2
+
+        # green
+        exposure_coarse_green = weights_coarse * exposure_array_green
+        exposure_coarse_green = (np.sum(exposure_coarse_green, axis=2) /
+                                 np.sum(weights_coarse, axis=2))
+        exposure_fine_green = weights_fine * exposure_array_green
+        exposure_fine_green = (np.sum(exposure_fine_green, axis=2) /
+                               np.sum(weights_fine, axis=2))
+        exposure_out_green = (exposure_coarse_green + exposure_fine_green) / 2
+
+        # blue
+        exposure_coarse_blue = weights_coarse * exposure_array_blue
+        exposure_coarse_blue = (np.sum(exposure_coarse_blue, axis=2) /
+                                np.sum(weights_coarse, axis=2))
+        exposure_fine_blue = weights_fine * exposure_array_blue
+        exposure_fine_blue = (np.sum(exposure_fine_blue, axis=2) /
+                              np.sum(weights_fine, axis=2))
+        exposure_out_blue = (exposure_coarse_blue + exposure_fine_blue) / 2
+
+        # combine all blended color channels to one image
+        exposure_out_color = np.zeros(
+            (exposure_out_gray.shape[0], exposure_out_gray.shape[1], 3),
+            dtype=float
+            )
+        exposure_out_color[:,:,0] = exposure_out_red
+        exposure_out_color[:,:,1] = exposure_out_green
+        exposure_out_color[:,:,2] = exposure_out_blue
+        exposure_out = exposure_out_color
+
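+    # (Illustration, not original code: each blend above is a per-pixel
+    #  weighted mean, sum_i(w_i * I_i) / sum_i(w_i), computed once per
+    #  illumination scale and then averaged across the two scales.)
+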
+
+    #--- Visualizations
+
+    if verbose is True:
+
+        # display intermediate stages of the method
+
+        plt.figure()
+
+        for i in range(total_exposures):
+
+            plt.subplot(6,total_exposures,i+1)
+            plt.imshow(exposure_array_gray[:,:,i], cmap='gray')
+            plt.title('Exposure ' + str(i))
+            plt.axis('off')
+
+            plt.subplot(6,total_exposures,i+1+total_exposures)
+            plt.imshow(illumination_coarse[:,:,i], cmap='gray')
+            plt.title('ill.coarse ' + str(i))
+            plt.axis('off')
+
+            plt.subplot(6,total_exposures,i+1+(total_exposures*2))
+            plt.imshow(illumination_fine[:,:,i], cmap='gray')
+            plt.title('ill.fine ' + str(i))
+            plt.axis('off')
+
+            plt.subplot(6,total_exposures,i+1+(total_exposures*3))
+            plt.imshow(weights_coarse[:,:,i], cmap='gray')
+            plt.title('W.coarse ' + str(i))
+            plt.axis('off')
+
+            plt.subplot(6,total_exposures,i+1+(total_exposures*4))
+            plt.imshow(weights_fine[:,:,i], cmap='gray')
+            plt.title('W.fine ' + str(i))
+            plt.axis('off')
+
+        plt.subplot(6,total_exposures,1+(total_exposures*5))
+        plt.imshow(exposure_coarse, cmap='gray')
+        plt.title('Coarse blended')
+        plt.axis('off')
+
+        plt.subplot(6,total_exposures,2+(total_exposures*5))
+        plt.imshow(exposure_fine, cmap='gray')
+        plt.title('Fine blended')
+        plt.axis('off')
+
+        plt.subplot(6,total_exposures,3+(total_exposures*5))
+        plt.imshow(exposure_out_gray, cmap='gray')
+        plt.title('Final blend')
+        plt.axis('off')
+
+        plt.suptitle('List of exposures')
+        plt.tight_layout()
+        plt.show()
+
+        # display final color result
+        plt.figure()
+        grid = plt.GridSpec(total_exposures, total_exposures)
+        if color_exposures is False:
+            cmap = 'gray'
+        else:
+            cmap = None
+
+        for i in range(total_exposures):
+            plt.subplot(grid[0,i])
+            plt.imshow(exposure_list[indx_lum_ascending[i]], cmap=cmap)
+            plt.title('Exposure ' + str(i))
+            plt.axis('off')
+
+        plt.subplot(grid[1:,:])
+        plt.imshow(exposure_out, cmap=cmap)
+        plt.title('Final blend')
+        plt.axis('off')
+        plt.tight_layout()
+        plt.suptitle('Full color blend')
+        plt.show()
+
+
+    return exposure_out
+
+
+
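+# A minimal sketch (not part of the original file) of the LUT lookup used in
+# the fusion above: illumination values in [0,1] are quantized into integer
+# bins and per-pixel weights are read from a membership-function LUT. The
+# function name is illustrative only; lut is assumed to be a 1-D numpy array,
+# e.g. np.linspace(1, 0, 256).
+def _demo_lut_weights(illumination, lut):
+    '''Map an illumination image in [0,1] to per-pixel weights via a 1-D LUT.'''
+    indices = (illumination * (len(lut) - 1)).astype(int)  # quantize to LUT bins
+    return lut[indices]  # numpy fancy indexing preserves the image shape
+
+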
+def apply_local_contrast_enhancement(
+        image,
+        image_ph_mask,
+        degree=1.5,
+        verbose=False):
+
+    '''
+    ---------------------------------------------------------------------------
+    Adjust local contrast in an image
+    ---------------------------------------------------------------------------
+
+    Increase or decrease the level of local details (local contrast) in an
+    image. Details are defined as deviations from the local neighborhood
+    provided by the photometric mask. Dark regions also receive a boost in
+    local contrast.
+
+
+    INPUTS
+    ------
+    image: numpy array of WxH of float [0,1]
+        Input grayscale image.
+    image_ph_mask: numpy array of WxH of float [0,1]
+        Grayscale image whose values represent the neighborhood of the pixels
+        of the input image. Usually, this image is the result of some type of
+        edge-aware filtering, such as bilateral filtering, robust recursive
+        envelopes etc.
+    degree: float [0,inf]
+        How to change the local contrast.
+        0: total attenuation of details
+        <1: attenuation of details
+        1: details unchanged
+        >1: increased local details
+    verbose: boolean
+        Display outputs.
+
+    OUTPUT
+    ------
+    image_out: numpy array of WxH of float [0,1]
+        Output image with adjusted local contrast.
+
+    '''
+
+    DARK_BOOST = 0.2
+    THRESHOLD_DARK_TONES = 100 / 255
+    detail_amplification_global = degree
+
+    image_details = image - image_ph_mask  # image details
+
+    # special treatment for dark regions
+    detail_amplification_local = image_ph_mask / THRESHOLD_DARK_TONES
+    detail_amplification_local[detail_amplification_local>1] = 1
+    detail_amplification_local = ((1 - detail_amplification_local) *
+                                  DARK_BOOST) + 1  # [1, 1.2]
+
+    # apply all detail adjustments
+    image_details = (image_details *
+                     detail_amplification_global *
+                     detail_amplification_local)
+
+    # add details back to the local neighborhood
+    image_out = image_ph_mask + image_details
+
+    # stay within range
+    image_out = np.clip(a=image_out, a_min=0, a_max=1, out=image_out)
+
+    if verbose is True:
+
+        plt.figure()
+        plt.subplot(1,3,1)
+        plt.imshow(image, cmap='gray', vmin=0, vmax=1)
+        plt.title('Input image')
+        plt.axis('off')
+
+        plt.subplot(1,3,2)
+        plt.imshow(image_ph_mask, cmap='gray', vmin=0, vmax=1)
+        plt.title('Ph. mask')
+        plt.axis('off')
+
+        plt.subplot(1,3,3)
+        plt.imshow(image_out, cmap='gray', vmin=0, vmax=1)
+        plt.title('Output')
+        plt.axis('off')
+
+        plt.tight_layout(True)
+        plt.suptitle('Local contrast enhancement [x' + str(degree) + ']')
+        plt.show()
+
+    return image_out
+
+
+
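+# A scalar sketch (not part of the original file) of the detail adjustment
+# above, ignoring the extra dark-region boost: a pixel's deviation from its
+# surround is scaled by degree and added back. Names are illustrative only.
+def _demo_local_contrast(value, surround, degree):
+    '''Amplify the deviation of a single value from its local surround.'''
+    return surround + (value - surround) * degree  # degree=1 leaves it unchanged
+
+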
+def apply_spatial_tonemapping(
+        image,
+        image_ph_mask,
+        mid_tone=0.5,
+        tonal_width=0.5,
+        areas_dark=0.5,
+        areas_bright=0.5,
+        preserve_tones=True,
+        verbose=True):
+    '''
+    ---------------------------------------------------------------------------
+    Apply spatially variable tone mapping based on the local neighborhood
+    ---------------------------------------------------------------------------
+
+    Applies a different tone mapping curve in each pixel based on its
+    surround. For the surround, the photometric mask is used. Alternatively,
+    other filters could be used, like a gaussian, the bilateral filter,
+    edge-avoiding wavelets etc. Dark pixels are brightened, bright pixels are
+    darkened, and pixels in the middle of the tone range are minimally
+    affected. More information about the technique can be found in the
+    following papers:
+
+    Related publications:
+    Vonikakis, V., Andreadis, I., & Gasteratos, A. (2008). Fast centre-surround
+    contrast modification. IET Image processing 2(1), 19-34.
+    Vonikakis, V., Winkler, S. (2016). A center-surround framework for spatial
+    image processing. Proc. IS&T Human Vision & Electronic Imaging.
+
+
+    INPUTS
+    ------
+    image: numpy array of WxH of float [0,1]
+        Input grayscale image with values in the interval [0,1].
+    image_ph_mask: numpy array of WxH of float [0,1]
+        Grayscale image whose values represent the neighborhood of the pixels
+        of the input image. Usually, this image is the result of some type of
+        edge-aware filtering, such as bilateral filtering, robust recursive
+        envelopes etc.
+    mid_tone: float [0,1]
+        The mid point between the 'dark' and 'bright' tones. This is
+        equivalent to a pixel value [0,255], but in the interval [0,1].
+    tonal_width: float [0,1]
+        The range of pixel values that will be affected by the correction.
+        Lower values will localize the enhancement only in a narrow range of
+        pixel values, whereas for higher values the enhancement will extend to
+        a greater range of pixel values.
+    areas_dark: float [0,1]
+        Degree of enhancement in the dark image areas (0 = no enhancement).
+    areas_bright: float [0,1]
+        Degree of enhancement in the bright image areas (0 = no enhancement).
+    preserve_tones: boolean
+        Whether or not to preserve well-exposed tones around the middle of the
+        range.
+    verbose: boolean
+        Display outputs.
+
+    OUTPUT
+    ------
+    image_tonemapped: numpy array of WxH of float [0,1]
+        Tonemapped grayscale image.
+
+    '''
+
+    # defining parameters
+    EPSILON = 1 / 256
+
+
+    # adjust range and non-linear response of parameters
+    mid_tone = map_value(
+        value=mid_tone,
+        range_in=(0,1),
+        range_out=(0,1),
+        invert=False,
+        non_lin_convex=None,
+        non_lin_concave=None
+        )
+
+    tonal_width = map_value(
+        value=tonal_width,
+        range_in=(0,1),
+        range_out=(EPSILON,1),
+        invert=False,
+        non_lin_convex=0.1,
+        non_lin_concave=None
+        )
+
+    areas_dark = map_value(
+        value=areas_dark,
+        range_in=(0,1),
+        range_out=(0,5),
+        invert=True,
+        non_lin_convex=0.05,
+        non_lin_concave=None
+        )
+
+    areas_bright = map_value(
+        value=areas_bright,
+        range_in=(0,1),
+        range_out=(0,5),
+        invert=True,
+        non_lin_convex=0.05,
+        non_lin_concave=None
+        )
+
+
+
+    # spatial tone-mapping
+
+    # lower tones (below mid_tone level)
+    image_lower = image.copy()
+    image_lower[image_lower>=mid_tone] = 0
+    alpha = (image_ph_mask ** 2) / tonal_width
+    tone_continuation_factor = mid_tone / (mid_tone + EPSILON - image_ph_mask)
+    alpha = alpha * tone_continuation_factor + areas_dark
+    image_lower = (image_lower * (alpha + 1)) / (alpha + image_lower)
+
+    # upper tones (above mid_tone level)
+    # (the remainder of this branch is reconstructed by symmetry with the
+    #  lower-tones branch above and with the darkening curve used in
+    #  adjust_brightness; treat the exact expressions as approximate)
+    image_upper = image.copy()
+    image_upper[image_upper<mid_tone] = 0
+    alpha = ((1 - image_ph_mask) ** 2) / tonal_width
+    tone_continuation_factor = ((1 - mid_tone) /
+                                ((1 - mid_tone) + EPSILON -
+                                 (1 - image_ph_mask)))
+    alpha = alpha * tone_continuation_factor + areas_bright
+    image_upper = (image_upper * alpha) / (alpha + 1 - image_upper)
+
+    # combine the two branches
+    image_tonemapped = image_lower + image_upper
+
+    if preserve_tones is True:
+        # blend back toward the input where the neighborhood is already
+        # well exposed (mid-range); the weighting below is a reconstruction
+        preservation_degree = 1 - (np.abs(image_ph_mask - 0.5) * 2)
+        image_tonemapped = (preservation_degree * image +
+                            (1 - preservation_degree) * image_tonemapped)
+
+    # stay within range
+    image_tonemapped = np.clip(image_tonemapped, 0, 1)
+
+    if verbose is True:
+
+        plt.figure()
+        plt.subplot(1,3,1)
+        plt.imshow(image, cmap='gray', vmin=0, vmax=1)
+        plt.title('Input image')
+        plt.axis('off')
+
+        plt.subplot(1,3,2)
+        plt.imshow(image_ph_mask, cmap='gray', vmin=0, vmax=1)
+        plt.title('Ph. mask')
+        plt.axis('off')
+
+        plt.subplot(1,3,3)
+        plt.imshow(image_tonemapped, cmap='gray', vmin=0, vmax=1)
+        plt.title('Tonemapped')
+        plt.axis('off')
+
+        plt.tight_layout(True)
+        plt.suptitle('Spatial tonemapping')
+        plt.show()
+
+    return image_tonemapped
+
+
+
+
+
+def srgb_to_linear(image_srgb, verbose=False):
+
+    '''
+    ---------------------------------------------------------------------------
+    Transform an image from sRGB to linear color space
+    ---------------------------------------------------------------------------
+
+    The function removes the main non-linearities associated with the sRGB
+    color space. The transformation formula can be found in the EasyRGB
+    website: https://www.easyrgb.com/en/math.php
+
+    Note that the formulas may look slightly different. This is because they
+    have been altered in order to implement them in a vectorized way, avoiding
+    for loops. As such, an image is partitioned in 2 parts, image_upper and
+    image_lower, which implement separate parts of the piece-wise color
+    transformation formula.
+
+
+    INPUTS
+    ------
+    image_srgb: numpy array of WxHx3 of uint8 [0,255] or float [0,1]
+        Input color sRGB image.
+    verbose: boolean
+        Display outputs.
+
+    OUTPUT
+    ------
+    image_linear: numpy array of WxHx3 of float [0,1]
+        Output color image in linear space, with values in the interval [0,1].
+
+    '''
+
+    # (the header and lower branch of this function are reconstructed as the
+    #  exact inverse of linear_to_srgb below)
+    image_srgb = img_as_float(image_srgb)  # [0,1]
+
+    # lower part of the piecewise formula
+    image_lower = image_srgb.copy()
+    image_lower[image_lower > 0.04045] = 0
+    image_lower = image_lower / 12.92
+
+    # upper part of the piecewise formula
+    image_upper = image_srgb.copy()
+    image_upper = image_upper + 0.055
+    image_upper[image_upper <= (0.04045+0.055)] = 0
+    image_upper = image_upper / 1.055
+    image_upper = image_upper ** 2.4
+
+    image_linear = image_lower + image_upper  # combine into the final result
+
+    if verbose is True:
+
+        plt.figure()
+        plt.subplot(1,2,1)
+        plt.imshow(image_srgb, vmin=0, vmax=1)
+        plt.title('Image sRGB')
+        plt.axis('off')
+
+        plt.subplot(1,2,2)
+        plt.imshow(image_linear, vmin=0, vmax=1)
+        plt.title('Image linear')
+        plt.axis('off')
+
+        plt.tight_layout(True)
+        plt.suptitle('sRGB -> linear space')
+        plt.show()
+
+    return image_linear
+
+
+
+
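+# A minimal self-check sketch (not part of the original file) for the
+# piecewise sRGB formulas above: decoding a few scalar values to linear and
+# re-encoding them should reproduce the inputs up to floating-point error.
+# The function name is illustrative only.
+def _demo_srgb_roundtrip():
+    '''Round-trip a few sRGB values through the standard piecewise formulas.'''
+    values = np.array([0.0, 0.02, 0.5, 1.0])
+    linear = np.where(values <= 0.04045,
+                      values / 12.92,
+                      ((values + 0.055) / 1.055) ** 2.4)
+    back = np.where(linear <= 0.0031308,
+                    linear * 12.92,
+                    1.055 * (linear ** (1 / 2.4)) - 0.055)
+    return np.allclose(values, back)  # expected: True
+
+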
+def linear_to_srgb(image_linear, verbose=False):
+
+    '''
+    ---------------------------------------------------------------------------
+    Transform an image from linear to sRGB color space
+    ---------------------------------------------------------------------------
+
+    The function re-applies the main non-linearities associated with the sRGB
+    color space. The transformation formula can be found in the EasyRGB
+    website: https://www.easyrgb.com/en/math.php
+
+    Note that the formulas may look slightly different. This is because they
+    have been altered in order to implement them in a vectorized way, avoiding
+    for loops. As such, an image is partitioned in 2 parts, image_upper and
+    image_lower, which implement separate parts of the piece-wise color
+    transformation formula.
+
+
+    INPUTS
+    ------
+    image_linear: numpy array of WxHx3 of float [0,1]
+        Input color image with values in the interval [0,1].
+    verbose: boolean
+        Display outputs.
+
+    OUTPUT
+    ------
+    image_srgb: numpy array of WxHx3 of uint8 [0,255]
+        Output color sRGB image with values in the interval [0,255].
+
+    '''
+
+
+    # dealing with different input dimensions: add a 3rd dimension to
+    # single-channel (grayscale) inputs
+    dimensions = len(image_linear.shape)
+    if dimensions == 2:  # (was '== 1', which np.expand_dims cannot handle)
+        image_linear = np.expand_dims(image_linear, axis=2)  # 3rd dimension
+
+    image_linear = img_as_float(image_linear)  # [0,1]
+
+    # lower part of the piecewise formula
+    image_lower = image_linear.copy()
+    image_lower[image_lower > 0.0031308] = 0
+    image_lower = image_lower * 12.92
+
+    # upper part of the piecewise formula
+    image_upper = image_linear.copy()
+    image_upper[image_upper <= 0.0031308] = 0
+    image_upper = image_upper ** (1/2.4)
+    image_upper = image_upper * 1.055
+    # subtract the offset only where the upper branch applies, so the pixels
+    # zeroed out above (handled by the lower branch) are not pushed negative
+    image_upper[image_upper > 0] = image_upper[image_upper > 0] - 0.055
+
+    image_srgb = image_lower + image_upper
+    image_srgb = np.clip(a=image_srgb, a_min=0, a_max=1, out=image_srgb)
+
+
+    if verbose is True:
+
+        plt.figure()
+        plt.subplot(1,2,1)
+        plt.imshow(image_linear, vmin=0, vmax=1)
+        plt.title('Image linear')
+        plt.axis('off')
+
+        plt.subplot(1,2,2)
+        plt.imshow(image_srgb, vmin=0, vmax=1)
+        plt.title('Image sRGB')
+        plt.axis('off')
+
+        plt.tight_layout(True)
+        plt.suptitle('Linear space -> sRGB')
+        plt.show()
+
+    return (image_srgb * 255).astype(np.uint8)
+
+
+
+
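+# A minimal sketch (not part of the original file) of the ratio-based tone
+# transfer implemented by transfer_graytone_to_color below: the per-pixel
+# ratio between the tone-mapped and the original gray image, computed in
+# linear space, scales all three color channels equally. The function name
+# and arguments are illustrative only.
+def _demo_tone_ratio(color_linear, gray_linear, graytone_linear):
+    '''Apply gray-image tone ratios to a linear color image.'''
+    ratio = graytone_linear / np.maximum(gray_linear, 1 / 256)  # avoid /0
+    return np.clip(color_linear * np.dstack([ratio] * 3), 0, 1)
+
+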
+def transfer_graytone_to_color(image_color, image_graytone, verbose=False):
+
+    '''
+    ---------------------------------------------------------------------------
+    Transfer grayscale tones to a color image
+    ---------------------------------------------------------------------------
+
+    Transfers the tones of a guide grayscale image to the color version of the
+    same image, by using linear color ratios. It first brings the image from
+    the sRGB color space back to the linear color space. It then estimates the
+    tone ratios between the tone-mapped grayscale guide image and the
+    grayscale version of the color image, and applies these ratios on the 3
+    color channels. Finally, it brings the image back to the sRGB color space
+    (gamma corrected). If the input image is in another color space (e.g.
+    Adobe RGB), a different transformation could be used. However, the results
+    would not be that much different.
+
+    Related publication:
+    Chengho Hsin, Zong Wei Lee, Zheng Zhan Lee, and Shaw-Jyh Shin, "Color
+    preservation for tone reproduction and image enhancement", Proc. SPIE 9015,
+    Color Imaging XIX, 2014
+
+
+    INPUTS
+    ------
+    image_color: numpy array of WxHx3 of uint8 [0,255]
+        Input color image.
+    image_graytone: numpy array of WxH of float [0,1]
+        Grayscale version of image_color which has been tonemapped and will
+        be used as a guide to transfer the same tonemapping to the color
+        image.
+    verbose: boolean
+        Display outputs.
+
+    OUTPUT
+    ------
+    image_colortone: numpy array of WxHx3 of uint8 [0,255]
+        Output color image with transferred tonemapping.
+
+    '''
+
+
+    EPSILON = 1 / 256
+
+    # bring both color and graytone to linear space
+    image_color_linear = srgb_to_linear(image_color.copy(), verbose=False)
+    image_graytone_linear = srgb_to_linear(image_graytone.copy(), verbose=False)
+    image_gray_linear = rgb2gray(image_color_linear.copy())
+    image_gray_linear[image_gray_linear==0] = EPSILON  # for the division later
+
+    # tone ratio of linear images: improved/original
+    tone_ratio = image_graytone_linear / image_gray_linear
+#    tone_ratio[np.isinf(tone_ratio)] = 0
+#    tone_ratio[np.isnan(tone_ratio)] = 0
+
+    # apply the tone ratios to the color image
+    image_colortone_linear = image_color_linear * np.dstack([tone_ratio] * 3)
+
+    # make sure it's within limits
+    image_colortone_linear = np.clip(
+        a=image_colortone_linear,
+        a_min=0,
+        a_max=1,
+        out=image_colortone_linear
+        )
+
+    # bring back to gamma-corrected sRGB space for visualization
+    image_colortone = linear_to_srgb(image_colortone_linear, verbose=False)
+
+    # display results
+    if verbose is True:
+
+        plt.figure()
+        plt.subplot(2,4,1)
+        plt.imshow(image_color, vmin=0, vmax=255)
+        plt.title('Color')
+        plt.axis('off')
+
+        plt.subplot(2,4,5)
+        plt.imshow(image_color_linear, vmin=0, vmax=1)
+        plt.title('Color linear')
+        plt.axis('off')
+
+        plt.subplot(2,4,2)
+        plt.imshow(image_graytone, cmap='gray', vmin=0, vmax=1)
+        plt.title('Graytone')
+        plt.axis('off')
+
+        plt.subplot(2,4,6)
+        plt.imshow(image_graytone_linear, cmap='gray', vmin=0, vmax=1)
+        plt.title('Graytone linear')
+        plt.axis('off')
+
+        plt.subplot(2,4,7)
+        plt.imshow(tone_ratio, cmap='gray')
+        plt.title('Tone ratios')
+        plt.axis('off')
+
+        plt.subplot(2,4,4)
+        plt.imshow(image_colortone, vmin=0, vmax=255)
+        plt.title('Colortone')
+        plt.axis('off')
+
+        plt.subplot(2,4,8)
+        plt.imshow(image_colortone_linear, vmin=0, vmax=1)
+        plt.title('Colortone linear')
+        plt.axis('off')
+
+        plt.tight_layout(True)
+        plt.suptitle('Transferring gray tones to color')
+        plt.show()
+
+
+    return image_colortone
+
+
+
+
+
+
+
+
+def change_color_saturation(
+        image_color,
+        image_ph_mask=None,
+        sat_degree=1.5,
+        verbose=False):
+
+    '''
+    ---------------------------------------------------------------------------
+    Adjust color saturation of an image
+    ---------------------------------------------------------------------------
+
+    Increase or decrease the saturation (vibrance) of colors in an image. This
+    implements a simpler approach rather than using the HSV color space to
+    adjust S. In my experiments, HSV-based saturation adjustment was not as
+    good and it exhibited some kind of 'color noise'. This approach is
+    aesthetically better. The use of the photometric mask is optional, in case
+    you would like to treat dark areas (where saturation is usually lower)
+    differently.
+
+
+    INPUTS
+    ------
+    image_color: numpy array of WxHx3 of float [0,1]
+        Input color image.
+    image_ph_mask: numpy array of WxH of float [0,1] or None
+        Grayscale image whose values represent the neighborhood of the pixels
+        of the input image. If None, saturation adjustment is applied globally
+        to all pixels. If not None, then dark regions are treated differently
+        and get an additional boost in saturation.
+    sat_degree: float [0,inf]
+        How to change the color saturation.
+        0: no color (grayscale)
+        <1: reduced color saturation
+        1: color saturation unchanged
+        >1: increased color saturation
+    verbose: boolean
+        Display outputs.
+
+    OUTPUT
+    ------
+    image_new_sat: numpy array of WxHx3 of float [0,1]
+        Output image with adjusted saturation.
+ + ''' + + LOCAL_BOOST = 0.2 + THRESHOLD_DARK_TONES = 100 / 255 + + #TODO: return the same image type + image_color = img_as_float(image_color) # [0,1] + + # define gray scale + image_gray = (image_color[:,:,0] + + image_color[:,:,1] + + image_color[:,:,2]) / 3 + image_gray = np.dstack([image_gray] * 3) # grayscale with 3 channels + + image_delta = image_color - image_gray # deviations from gray + + # defining local color amplification degree + if image_ph_mask is not None: + detail_amplification_local = image_ph_mask / THRESHOLD_DARK_TONES + detail_amplification_local[detail_amplification_local>1] = 1 + detail_amplification_local = ((1 - detail_amplification_local) * + LOCAL_BOOST) + 1 # [1, 1.2] + detail_amplification_local = np.dstack( + [detail_amplification_local] * 3) # 3 channels + else: + detail_amplification_local = 1 + + image_new_sat = (image_gray + + image_delta * sat_degree * detail_amplification_local) + + image_new_sat = np.clip( + a=image_new_sat, + a_min=0, + a_max=1, + out=image_new_sat + ) + + if verbose is True: + + plt.figure() + plt.subplot(1,2,1) + plt.imshow(image_color, vmin=0, vmax=1) + plt.title('Input image') + plt.axis('off') + + plt.subplot(1,2,2) + plt.imshow(image_new_sat, vmin=0, vmax=1) + plt.title('New saturation [x' + str(sat_degree) + ']') + plt.axis('off') + + plt.tight_layout(True) + plt.suptitle('Color saturation adjustment') + plt.show() + + + return image_new_sat + + + + + + +def correct_colors(image, verbose): + ''' + --------------------------------------------------------------------------- + Correct image colors (remove color casts) + --------------------------------------------------------------------------- + + Implements a simple color correction using the Gray World Color Assumption + and White Point Correction. + + Related publication: + Vonikakis, V., Arapakis, I. & Andreadis, I. (2011). Combining Gray-World + assumption, White-Point correction and power transformation for automatic + white balance. International Workshop on Advanced Image Technology (IWAIT), + paper number 1569353295, Jakarta Indonesia. + + INPUTS + ------ + image: numpy array of WxHx3 of uint8 [0,255] + Input color image. + verbose: boolean + Display outputs. + + OUTPUT + ------ + image_out: numpy array of WxHx3 of float [0,1] + Output image with adjusted colors. 
+
+    '''
+
+    image_out = img_as_float(image.copy())  # [0,1]
+
+#    # simple gray world color correction
+#    image_out[:,:,0] = (image_out[:,:,0] / image_out[:,:,0].mean()) * 0.5
+#    image_out[:,:,1] = (image_out[:,:,1] / image_out[:,:,1].mean()) * 0.5
+#    image_out[:,:,2] = (image_out[:,:,2] / image_out[:,:,2].mean()) * 0.5
+
+    # mean of all channels
+    image_mean = (image_out[:,:,0].mean() +
+                  image_out[:,:,1].mean() +
+                  image_out[:,:,2].mean()) / 3
+
+    # base of the logarithm used to derive each channel's correction power
+    base_r = image_out[:,:,0].mean() / image_out[:,:,0].max()
+    base_g = image_out[:,:,1].mean() / image_out[:,:,1].max()
+    base_b = image_out[:,:,2].mean() / image_out[:,:,2].max()
+
+    # the power to which each channel will be raised
+    power_r = math.log(image_mean, base_r)
+    power_g = math.log(image_mean, base_g)
+    power_b = math.log(image_mean, base_b)
+
+    # separately applying different color correction powers to each channel
+    image_out[:,:,0] = (image_out[:,:,0] / image_out[:,:,0].max()) ** power_r
+    image_out[:,:,1] = (image_out[:,:,1] / image_out[:,:,1].max()) ** power_g
+    image_out[:,:,2] = (image_out[:,:,2] / image_out[:,:,2].max()) ** power_b
+
+    if verbose is True:
+
+        plt.figure()
+        plt.subplot(1,2,1)
+        plt.imshow(image)
+        plt.title('Input image')
+        plt.axis('off')
+
+        plt.subplot(1,2,2)
+        plt.imshow(image_out, vmin=0, vmax=1)
+        plt.title('Corrected colors')
+        plt.axis('off')
+
+        plt.tight_layout(True)
+        plt.suptitle('Gray world color correction')
+        plt.show()
+
+    return image_out
+
+
+
+
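+# A minimal sketch (not part of the original file) of why correct_colors()
+# uses math.log above: to map a channel's normalized mean onto the global
+# mean with a power curve, solve base ** p = mean for p, which is
+# log(mean) / log(base). The helper name is illustrative only; math is
+# already imported at the top of this module (correct_colors relies on it).
+def _demo_gray_world_power(channel_mean, channel_max, global_mean):
+    '''Exponent that maps the normalized channel mean onto global_mean.'''
+    base = channel_mean / channel_max  # normalized channel mean, in (0,1)
+    return math.log(global_mean, base)  # solves base ** p == global_mean
+
+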
+def adjust_brightness(image, degree=0, verbose=False):
+    '''
+    ---------------------------------------------------------------------------
+    Apply global tone mapping on a grayscale image
+    ---------------------------------------------------------------------------
+
+    Applies a single tone mapping curve to all the pixels of a grayscale
+    image. Depending on the parameters, the image can be brightened or
+    darkened. The set of curves used is similar to gamma functions, but is
+    inspired by the Naka-Rushton function and exhibits symmetry and better
+    local contrast. More information about the technique can be found in the
+    following paper:
+
+    Related publications:
+    Vonikakis, V., Winkler, S. (2016). A center-surround framework for spatial
+    image processing. Proc. IS&T Human Vision & Electronic Imaging.
+
+    INPUTS
+    ------
+    image: numpy array of WxH of float [0,1]
+        Input grayscale image with values in the interval [0,1].
+    degree: float [-1,1]
+        The strength of the uniform tone mapping function.
+        [-1,0): darken image. Closer to -1 means more aggressive darkening
+        0: unchanged tones
+        (0,1]: brighten image. Closer to 1 means more aggressive brightening
+    verbose: boolean
+        Display outputs.
+
+    OUTPUT
+    ------
+    image_brightness: numpy array of WxH of float [0,1]
+        Tonemapped grayscale image.
+
+    '''
+
+
+    EPSILON = 1 / 256  # what we consider the minimum value
+
+    # adjust range and non-linear response of parameters
+    # unpack information: darken or brighten, and the degree
+    if degree > 0:
+        brighten = True
+    else:
+        brighten = False
+
+    degree = abs(degree)  # [0,1]
+
+    alpha = map_value(
+        value=degree,
+        range_in=(0,1),
+        range_out=(0,5),  # from the paper: 5x brings close to linear
+        invert=True,  # from the paper
+        non_lin_convex=0.05,  # adding linearity to the response
+        non_lin_concave=None
+        )
+
+    alpha = alpha + EPSILON  # to avoid division by zero
+
+
+    # applying global tone-mapping
+    if degree != 0:
+        image_brightness = image.copy()
+
+        if brighten is True:
+            image_brightness = ((image_brightness * (alpha + 1)) /
+                                (alpha + image_brightness))
+        else:
+            image_brightness = ((image_brightness * alpha) /
+                                (alpha + 1 - image_brightness))
+    else:
+        image_brightness = image
+
+
+
+    if verbose is True:
+
+        plt.figure()
+
+        plt.subplot(1,2,1)
+        plt.imshow(image, cmap='gray', vmin=0, vmax=1)
+        plt.title('Input image')
+        plt.axis('off')
+
+        plt.subplot(1,2,2)
+        plt.imshow(image_brightness, cmap='gray', vmin=0, vmax=1)
+        plt.title('Adjusted brightness image')
+        plt.axis('off')
+
+        plt.tight_layout(True)
+        plt.suptitle('Adjusting brightness')
+        plt.show()
+
+
+    return image_brightness
+
+
+
+
+
+
+def enhance_image(image, parameters, verbose=False):
+
+    '''
+    ---------------------------------------------------------------------------
+    Image enhancement
+    ---------------------------------------------------------------------------
+
+    Image enhancement pipeline, with spatial tone mapping, local contrast
+    enhancement and color saturation adjustment. The 3 stages are fully
+    decoupled and the user can independently define the enhancement degree of
+    each one.
+
+    Related publications:
+    Vonikakis, V., Andreadis, I., & Gasteratos, A. (2008). Fast centre-surround
+    contrast modification. IET Image processing 2(1), 19-34.
+    Vonikakis, V., Winkler, S. (2016). A center-surround framework for spatial
+    image processing. Proc. IS&T Human Vision & Electronic Imaging.
+
+
+    INPUTS
+    ------
+    image: numpy array of WxHx3 of uint8 [0,255]
+        Input color image with values in the interval [0,255].
+    parameters: dictionary
+        'local_contrast': float [0,inf]
+            0: total attenuation of details
+            <1: attenuation of details
+            1: details unchanged
+            >1: increased local details
+        'mid_tones': float [0,1]
+        'tonal_width': float [0,1]
+        'areas_dark': float [0,1]
+            0: no enhancement
+            1: strongest enhancement
+        'areas_bright': float [0,1]
+            0: no enhancement
+            1: strongest enhancement
+        'brightness': float [-1,1]
+            [-1,0): darken image
+            0: unchanged
+            (0,1]: brighten image
+        'preserve_tones': boolean
+        'color_correction': boolean
+        'saturation_degree': float [0,inf]
+            0: no color (grayscale)
+            <1: reduced color saturation
+            1: color saturation unchanged
+            >1: increased color saturation
+    verbose: boolean
+        Display outputs.
+
+    OUTPUT
+    ------
+    image_colortone_saturation: numpy array of WxHx3 of uint8 [0,255]
+        Output enhanced image.
+ + ''' + + + #TODO: add an automatic parameter estimation stage (machine learning) + + + # sanity check for type, range and defaults + + if 'local_contrast' in parameters: + parameters['local_contrast'] = float(parameters['local_contrast']) + if parameters['local_contrast'] < 0: parameters['local_contrast'] = 0 + else: parameters['local_contrast'] = 1.2 # default: slight increase + + if 'mid_tones' in parameters: + parameters['mid_tones'] = float(parameters['mid_tones']) + if parameters['mid_tones'] > 1: parameters['mid_tones'] = 1 + if parameters['mid_tones'] < 0: parameters['mid_tones'] = 0 + else: parameters['mid_tones'] = 0.5 # default: middle of the range + + if 'tonal_width' in parameters: + parameters['tonal_width'] = float(parameters['tonal_width']) + if parameters['tonal_width'] > 1: parameters['tonal_width'] = 1 + if parameters['tonal_width'] < 0: parameters['tonal_width'] = 0 + else: parameters['tonal_width'] = 0.5 # default: middle of the range + + if 'areas_dark' in parameters: + parameters['areas_dark'] = float(parameters['areas_dark']) + if parameters['areas_dark'] > 1: parameters['areas_dark'] = 1 + if parameters['areas_dark'] < 0: parameters['areas_dark'] = 0 + else: parameters['areas_dark'] = 0.2 # default: gentle increase + + if 'areas_bright' in parameters: + parameters['areas_bright'] = float(parameters['areas_bright']) + if parameters['areas_bright'] > 1: parameters['areas_bright'] = 1 + if parameters['areas_bright'] < 0: parameters['areas_bright'] = 0 + else: parameters['areas_bright'] = 0.2 # default: gentle increase + + if 'brightness' in parameters: + parameters['brightness'] = float(parameters['brightness']) + if parameters['brightness'] > 1: parameters['brightness'] = 1 + if parameters['brightness'] < -1: parameters['brightness'] = -1 + else: parameters['brightness'] = 0.1 # default: gentle increase + + if 'preserve_tones' in parameters: + parameters['preserve_tones'] = bool(parameters['preserve_tones']) + else: parameters['preserve_tones'] = True # default: preserve tones + + if 'color_correction' in parameters: + parameters['color_correction'] = bool(parameters['color_correction']) + else: parameters['color_correction'] = False # default: no correction + + if 'saturation_degree' in parameters: + parameters['saturation_degree'] = float(parameters['saturation_degree']) + if parameters['saturation_degree'] < 0: parameters['saturation_degree'] = 0 + else: parameters['saturation_degree'] = 1.2 # default: slight increase + + + + + # get photometric mask, as a guide for spatial-tone mapping + image_ph_mask = get_photometric_mask( + image=image, + verbose=verbose + ) + + # increase the local contrast of the grayscale image + image_contrast = apply_local_contrast_enhancement( + image=rgb2gray(image.copy()), + image_ph_mask=image_ph_mask, + degree=parameters['local_contrast'], + verbose=verbose + ) + + # apply spatial tonemapping on the previous stage + image_tonemapped = apply_spatial_tonemapping( + image=image_contrast, + image_ph_mask=image_ph_mask, + mid_tone=parameters['mid_tones'], + tonal_width=parameters['tonal_width'], + areas_dark=parameters['areas_dark'], + areas_bright=parameters['areas_bright'], + preserve_tones=parameters['preserve_tones'], + verbose=verbose + ) + + image_brightness = adjust_brightness( + image_tonemapped, + degree=parameters['brightness'], + verbose=verbose + ) + + # transfer the enhancement on the color image (in the linear color space) + image_colortone = transfer_graytone_to_color( + image_color=image, + 
image_graytone=image_brightness, + verbose=verbose + ) + + # apply color correction (if needed) + if parameters['color_correction'] is True: + image_colortone = correct_colors( + image=image_colortone, + verbose=verbose + ) + + # adjust the color saturation + image_colortone_saturation = change_color_saturation( + image_color=image_colortone, + image_ph_mask=image_ph_mask, + sat_degree=parameters['saturation_degree'], + verbose = verbose, + ) + + # TODO: add a denoising stage + + # display results + if verbose is True: + + plt.figure() + plt.subplot(2,3,1) + plt.imshow(image, vmin=0, vmax=255) + plt.title('Input image') + plt.axis('off') + plt.tight_layout() + + plt.subplot(2,3,4) + plt.imshow(image_ph_mask, cmap='gray', vmin=0, vmax=1) + plt.title('Photometric mask') + plt.axis('off') + plt.tight_layout() + + plt.subplot(2,3,5) + plt.imshow(image_contrast, cmap='gray', vmin=0, vmax=1) + plt.title('Local contrast enhancement') + plt.axis('off') + plt.tight_layout() + + plt.subplot(2,3,2) + plt.imshow(image_colortone, vmin=0, vmax=255) + plt.title('Spatial tone mapping') + plt.axis('off') + plt.tight_layout() + + plt.subplot(2,3,3) + plt.imshow(image_colortone_saturation, vmin=0, vmax=255) + plt.title('Increased saturation') + plt.axis('off') + plt.tight_layout() + + + return image_colortone_saturation + + + + + + + +# def fuse_exposures(ls_images): + + + + + + + + + + + + + + + +if __name__=="__main__": + + filename = "../images/lisbon.jpg" + image = imageio.imread(filename) # load image + + # setting up parameters + parameters = {} + parameters['local_contrast'] = 1.5 # 1.5x increase in details + parameters['mid_tones'] = 0.5 + parameters['tonal_width'] = 0.5 + parameters['areas_dark'] = 0.7 # 70% improvement in dark areas + parameters['areas_bright'] = 0.5 # 50% improvement in bright areas + parameters['saturation_degree'] = 1.2 # 1.2x increase in color saturation + parameters['brightness'] = 0.1 # slight increase in brightness + parameters['preserve_tones'] = True + parameters['color_correction'] = False + image_enhanced = enhance_image(image, parameters, verbose=False) + + # display results + plt.figure() + plt.subplot(1,2,1) + plt.imshow(image, vmin=0, vmax=255) + plt.title('Input image') + plt.axis('off') + plt.tight_layout() + + plt.subplot(1,2,2) + plt.imshow(image_enhanced, vmin=0, vmax=255) + plt.title('Enhanced image') + plt.axis('off') + plt.tight_layout() + + plt.show() + + + + + \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..4af7b2d884ba9ce17cde94c76335eab1c85ce618 --- /dev/null +++ b/main.py @@ -0,0 +1,385 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +import copy +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +import cv2 +import numpy as np + +import os + +from cal_rec_boxes import CalRecBoxes +from ch_ppocr_cls import TextClassifier +from ch_ppocr_det import TextDetector +from ch_ppocr_rec import TextRecognizer +from utils import ( + LoadImage, + UpdateParameters, + VisRes, + add_round_letterbox, + get_logger, + increase_min_side, + init_args, + read_yaml, + reduce_max_side, + update_model_path, +) + +root_dir = Path(__file__).resolve().parent +DEFAULT_CFG_PATH = root_dir / "config.yaml" +logger = get_logger("RapidOCR") + + +class RapidOCR: + def __init__(self, config_path: Optional[str] = None, **kwargs): + if config_path is not None and Path(config_path).exists(): + config = read_yaml(config_path) + else: + config = 
read_yaml(DEFAULT_CFG_PATH)
+        config = update_model_path(config)
+
+        if kwargs:
+            updater = UpdateParameters()
+            config = updater(config, **kwargs)
+
+        global_config = config["Global"]
+        self.print_verbose = global_config["print_verbose"]
+        self.text_score = global_config["text_score"]
+        self.min_height = global_config["min_height"]
+        self.width_height_ratio = global_config["width_height_ratio"]
+
+        self.use_det = global_config["use_det"]
+        self.text_det = TextDetector(config["Det"])
+
+        # The classifier is disabled in this build, but __call__ still reads
+        # self.use_cls, so keep the attribute defined (always False here).
+        self.use_cls = False
+        # self.use_cls = global_config["use_cls"]
+        # self.text_cls = TextClassifier(config["Cls"])
+
+        self.use_rec = global_config["use_rec"]
+        self.text_rec = TextRecognizer(config["Rec"])
+
+        self.load_img = LoadImage()
+        self.max_side_len = global_config["max_side_len"]
+        self.min_side_len = global_config["min_side_len"]
+
+        self.cal_rec_boxes = CalRecBoxes()
+
+    def __call__(
+        self,
+        img_content: Union[str, np.ndarray, bytes, Path],
+        use_det: Optional[bool] = None,
+        use_cls: Optional[bool] = None,
+        use_rec: Optional[bool] = None,
+        **kwargs,
+    ) -> Tuple[Optional[List[List[Union[Any, str]]]], Optional[List[float]]]:
+        use_det = self.use_det if use_det is None else use_det
+        use_cls = self.use_cls if use_cls is None else use_cls
+        use_rec = self.use_rec if use_rec is None else use_rec
+        return_word_box = False
+        if kwargs:
+            box_thresh = kwargs.get("box_thresh", 0.5)
+            unclip_ratio = kwargs.get("unclip_ratio", 1.6)
+            text_score = kwargs.get("text_score", 0.5)
+            return_word_box = kwargs.get("return_word_box", False)
+            self.text_det.postprocess_op.box_thresh = box_thresh
+            self.text_det.postprocess_op.unclip_ratio = unclip_ratio
+            self.text_score = text_score
+
+        img = self.load_img(img_content)
+
+        raw_h, raw_w = img.shape[:2]
+        op_record = {}
+        img, ratio_h, ratio_w = self.preprocess(img)
+        op_record["preprocess"] = {"ratio_h": ratio_h, "ratio_w": ratio_w}
+
+        dt_boxes, cls_res, rec_res = None, None, None
+        det_elapse, cls_elapse, rec_elapse = 0.0, 0.0, 0.0
+
+        if use_det:
+            img, op_record = self.maybe_add_letterbox(img, op_record)
+            dt_boxes, det_elapse = self.auto_text_det(img)
+            if dt_boxes is None:
+                return None, None
+
+            img = self.get_crop_img_list(img, dt_boxes)
+
+        # if use_cls:
+        #     img, cls_res, cls_elapse = self.text_cls(img)
+
+        if use_rec:
+            rec_res, rec_elapse = self.text_rec(img, return_word_box)
+
+        if dt_boxes is not None and rec_res is not None and return_word_box:
+            rec_res = self.cal_rec_boxes(img, dt_boxes, rec_res)
+            for rec_res_i in rec_res:
+                if rec_res_i[2]:
+                    rec_res_i[2] = (
+                        self._get_origin_points(rec_res_i[2], op_record, raw_h, raw_w)
+                        .astype(np.int32)
+                        .tolist()
+                    )
+
+        if dt_boxes is not None:
+            dt_boxes = self._get_origin_points(dt_boxes, op_record, raw_h, raw_w)
+
+        ocr_res = self.get_final_res(
+            dt_boxes, cls_res, rec_res, det_elapse, cls_elapse, rec_elapse
+        )
+        return ocr_res
+
+    def preprocess(self, img: np.ndarray) -> Tuple[np.ndarray, float, float]:
+        h, w = img.shape[:2]
+        max_value = max(h, w)
+        ratio_h = ratio_w = 1.0
+        if max_value > self.max_side_len:
+            img, ratio_h, ratio_w = reduce_max_side(img, self.max_side_len)
+
+        h, w = img.shape[:2]
+        min_value = min(h, w)
+        if min_value < self.min_side_len:
+            img, ratio_h, ratio_w = increase_min_side(img, self.min_side_len)
+        return img, ratio_h, ratio_w
+
+    def maybe_add_letterbox(
+        self, img: np.ndarray, op_record: Dict[str, Any]
+    ) -> Tuple[np.ndarray, Dict[str, Any]]:
+        h, w = img.shape[:2]
+
+        if self.width_height_ratio == -1:
+            use_limit_ratio = False
+        else:
+            use_limit_ratio = w / h > self.width_height_ratio
+
+        if h <= self.min_height or use_limit_ratio:
+            padding_h = self._get_padding_h(h, w)
+            block_img = add_round_letterbox(img, (padding_h, padding_h, 0, 0))
+            op_record["padding_1"] = {"top": padding_h, "left": 0}
+            return block_img, op_record
+
+        op_record["padding_1"] = {"top": 0, "left": 0}
+        return img, op_record
+
+    def _get_padding_h(self, h: int, w: int) -> int:
+        new_h = max(int(w / self.width_height_ratio), self.min_height) * 2
+        padding_h = int(abs(new_h - h) / 2)
+        return padding_h
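+    # (Illustration, not original code: assuming, e.g., width_height_ratio = 8
+    #  and min_height = 30 in the config, a strip of h=20, w=400 gives
+    #  new_h = max(int(400 / 8), 30) * 2 = 100 and padding_h = 40, so 40 px
+    #  of letterbox is added above and below before detection.)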
+    def auto_text_det(
+        self, img: np.ndarray
+    ) -> Tuple[Optional[List[np.ndarray]], float]:
+        dt_boxes, det_elapse = self.text_det(img)
+        if dt_boxes is None or len(dt_boxes) < 1:
+            return None, 0.0
+
+        dt_boxes = self.sorted_boxes(dt_boxes)
+        return dt_boxes, det_elapse
+
+    def get_crop_img_list(
+        self, img: np.ndarray, dt_boxes: List[np.ndarray]
+    ) -> List[np.ndarray]:
+        def get_rotate_crop_image(img: np.ndarray, points: np.ndarray) -> np.ndarray:
+            img_crop_width = int(
+                max(
+                    np.linalg.norm(points[0] - points[1]),
+                    np.linalg.norm(points[2] - points[3]),
+                )
+            )
+            img_crop_height = int(
+                max(
+                    np.linalg.norm(points[0] - points[3]),
+                    np.linalg.norm(points[1] - points[2]),
+                )
+            )
+            pts_std = np.array(
+                [
+                    [0, 0],
+                    [img_crop_width, 0],
+                    [img_crop_width, img_crop_height],
+                    [0, img_crop_height],
+                ]
+            ).astype(np.float32)
+            M = cv2.getPerspectiveTransform(points, pts_std)
+            dst_img = cv2.warpPerspective(
+                img,
+                M,
+                (img_crop_width, img_crop_height),
+                borderMode=cv2.BORDER_REPLICATE,
+                flags=cv2.INTER_CUBIC,
+            )
+            dst_img_height, dst_img_width = dst_img.shape[0:2]
+            if dst_img_height * 1.0 / dst_img_width >= 1.5:
+                dst_img = np.rot90(dst_img)
+            return dst_img
+
+        img_crop_list = []
+        for box in dt_boxes:
+            tmp_box = copy.deepcopy(box)
+            img_crop = get_rotate_crop_image(img, tmp_box)
+            img_crop_list.append(img_crop)
+        return img_crop_list
+
+    @staticmethod
+    def sorted_boxes(dt_boxes: np.ndarray) -> List[np.ndarray]:
+        """
+        Sort text boxes in order from top to bottom, left to right
+        args:
+            dt_boxes(array): detected text boxes, each with shape [4, 2]
+        return:
+            sorted boxes(list of arrays), each with shape [4, 2]
+        """
+        num_boxes = dt_boxes.shape[0]
+        sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
+        _boxes = list(sorted_boxes)
+
+        for i in range(num_boxes - 1):
+            for j in range(i, -1, -1):
+                if (
+                    abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10
+                    and _boxes[j + 1][0][0] < _boxes[j][0][0]
+                ):
+                    tmp = _boxes[j]
+                    _boxes[j] = _boxes[j + 1]
+                    _boxes[j + 1] = tmp
+                else:
+                    break
+        return _boxes
+
+    def _get_origin_points(
+        self,
+        dt_boxes: List[np.ndarray],
+        op_record: Dict[str, Any],
+        raw_h: int,
+        raw_w: int,
+    ) -> np.ndarray:
+        dt_boxes_array = np.array(dt_boxes).astype(np.float32)
+        for op in reversed(list(op_record.keys())):
+            v = op_record[op]
+            if "padding" in op:
+                top, left = v.get("top"), v.get("left")
+                dt_boxes_array[:, :, 0] -= left
+                dt_boxes_array[:, :, 1] -= top
+            elif "preprocess" in op:
+                ratio_h = v.get("ratio_h")
+                ratio_w = v.get("ratio_w")
+                dt_boxes_array[:, :, 0] *= ratio_w
+                dt_boxes_array[:, :, 1] *= ratio_h
+
+        dt_boxes_array = np.where(dt_boxes_array < 0, 0, dt_boxes_array)
+        dt_boxes_array[..., 0] = np.where(
+            dt_boxes_array[..., 0] > raw_w, raw_w, dt_boxes_array[..., 0]
+        )
+        dt_boxes_array[..., 1] = np.where(
+            dt_boxes_array[..., 1] > raw_h, raw_h, dt_boxes_array[..., 1]
+        )
+        return dt_boxes_array
+
+    def get_final_res(
+        self,
+        dt_boxes: Optional[List[np.ndarray]],
+        cls_res: Optional[List[List[Union[str, float]]]],
+        rec_res: Optional[List[Tuple[str, float, List[Union[str, float]]]]],
+        det_elapse: float,
+        cls_elapse: float,
+        rec_elapse: float,
+    ) -> Tuple[Optional[List[List[Union[Any, str]]]], Optional[List[float]]]:
+        if dt_boxes is None and rec_res is None and cls_res is not None:
+            return cls_res, [cls_elapse]
+
+        if dt_boxes is None and rec_res is None:
+            return None, None
+
+        if dt_boxes is None and rec_res is not None:
+            return [[res[0], res[1]] for res in rec_res], [rec_elapse]
+
+        if dt_boxes is not None and rec_res is None:
+            return [box.tolist() for box in dt_boxes], [det_elapse]
+
+        dt_boxes, rec_res = self.filter_result(dt_boxes, rec_res)
+        if not dt_boxes or not rec_res or len(dt_boxes) <= 0:
+            return None, None
+
+        ocr_res = [[box.tolist(), *res] for box, res in zip(dt_boxes, rec_res)], [
+            det_elapse,
+            cls_elapse,
+            rec_elapse,
+        ]
+        return ocr_res
+
+    def filter_result(
+        self,
+        dt_boxes: Optional[List[np.ndarray]],
+        rec_res: Optional[List[Tuple[str, float]]],
+    ) -> Tuple[Optional[List[np.ndarray]], Optional[List[Tuple[str, float]]]]:
+        if dt_boxes is None or rec_res is None:
+            return None, None
+
+        filter_boxes, filter_rec_res = [], []
+        for box, rec_result in zip(dt_boxes, rec_res):
+            text, score = rec_result[0], rec_result[1]
+            if float(score) >= self.text_score:
+                filter_boxes.append(box)
+                filter_rec_res.append(rec_result)
+
+        return filter_boxes, filter_rec_res
+
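+# A toy sketch (not part of the original file) of the reading-order sort
+# implemented by RapidOCR.sorted_boxes above: boxes on roughly the same line
+# (top edges within 10 px) are ordered left to right. The function name and
+# coordinates below are illustrative only.
+def _demo_sorted_boxes():
+    boxes = np.array([
+        [[50, 12], [90, 12], [90, 30], [50, 30]],  # right box, same line
+        [[10, 10], [40, 10], [40, 30], [10, 30]],  # left box, same line
+        [[10, 60], [40, 60], [40, 80], [10, 80]],  # next line down
+    ], dtype=np.float32)
+    ordered = RapidOCR.sorted_boxes(boxes)
+    return [b[0].tolist() for b in ordered]  # [[10,10], [50,12], [10,60]]
+
+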
+def main():
+    args = init_args()
+    ocr_engine = RapidOCR(**vars(args))
+
+    use_det = not args.no_det
+    use_cls = not args.no_cls
+    use_rec = not args.no_rec
+    result, elapse_list = ocr_engine(
+        args.img_path, use_det=use_det, use_cls=use_cls, use_rec=use_rec, **vars(args)
+    )
+    logger.info(result)
+
+    # Save the recognized text to a text file in the 'results' folder
+    if use_det and use_rec:
+        boxes, txts, scores = list(zip(*result))
+
+        # Create the 'results' folder if it doesn't exist
+        results_folder = Path("results")
+        results_folder.mkdir(parents=True, exist_ok=True)
+
+        # Create the file path for saving the text in the 'results' folder
+        img_name = os.path.splitext(os.path.basename(args.img_path))[0]  # image name without extension
+        txt_file_path = results_folder / f"{img_name}.txt"  # save in the 'results' folder
+
+        # Write the recognized text to the text file
+        with open(txt_file_path, 'w', encoding='utf-8') as f:
+            for txt in txts:
+                f.write(txt + '\n')
+
+        logger.info("The recognized text has been saved in %s", txt_file_path)
+
+    if args.print_cost:
+        logger.info(elapse_list)
+
+    if args.vis_res:
+        vis = VisRes()
+        Path(args.vis_save_path).mkdir(parents=True, exist_ok=True)
+        save_path = Path(args.vis_save_path) / f"{Path(args.img_path).stem}_vis.png"
+
+        if use_det and not use_cls and not use_rec:
+            boxes, *_ = list(zip(*result))
+            vis_img = vis(args.img_path, boxes)
+            cv2.imwrite(str(save_path), vis_img)
+            logger.info("The vis result has been saved in %s", save_path)
+
+        elif use_det and use_rec:
+            font_path = Path(args.vis_font_path)
+            if not font_path.exists():
+                raise FileNotFoundError(f"{font_path} does not exist!")
+
+            boxes, txts, scores = list(zip(*result))
+            vis_img = vis(args.img_path, boxes, txts, scores, font_path=font_path)
+            cv2.imwrite(str(save_path), vis_img)
+            logger.info("The vis result has been saved in %s", save_path)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/models/.gitkeep b/models/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/ch_PP-OCRv4_det_infer.onnx 
b/models/ch_PP-OCRv4_det_infer.onnx new file mode 100644 index 0000000000000000000000000000000000000000..9a01aff60e08a67805951aca5fe96a87dfa78169 --- /dev/null +++ b/models/ch_PP-OCRv4_det_infer.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d2a7720d45a54257208b1e13e36a8479894cb74155a5efe29462512d42f49da9 +size 4745517 diff --git a/models/ch_PP-OCRv4_rec_infer.onnx b/models/ch_PP-OCRv4_rec_infer.onnx new file mode 100644 index 0000000000000000000000000000000000000000..949d2365a1b3713b88938865cb9099401a9873c9 --- /dev/null +++ b/models/ch_PP-OCRv4_rec_infer.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:48fc40f24f6d2a207a2b1091d3437eb3cc3eb6b676dc3ef9c37384005483683b +size 10857958 diff --git a/rapidocr_ort/.gitattributes b/rapidocr_ort/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..a6344aac8c09253b3b630fb776ae94478aa0275b --- /dev/null +++ b/rapidocr_ort/.gitattributes @@ -0,0 +1,35 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/rapidocr_ort/README.md b/rapidocr_ort/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8dd7f27d99e963804dd32d6c2601d26042ad7ff7 --- /dev/null +++ b/rapidocr_ort/README.md @@ -0,0 +1,12 @@ +--- +title: Rapidocr Ort +emoji: 🏆 +colorFrom: green +colorTo: blue +sdk: gradio +sdk_version: 5.27.1 +app_file: app.py +pinned: false +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..0dcf2d2bf20ba9571d12f409e426e710ed7eecf7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,15 @@ +pyclipper>=1.2.0 +opencv-python-headless +numpy>=1.19.5,<3.0.0 +six>=1.15.0 +Shapely>=1.7.1,!=2.0.4 # python3.12 2.0.4 bug +PyYAML +Pillow +onnxruntime>=1.7.0 +tqdm +pdf2image +rapidocr-onnxruntime +imageio +matplotlib +scikit-image +gradio \ No newline at end of file diff --git a/results/1.txt b/results/1.txt new 
file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/results/10.txt b/results/10.txt new file mode 100644 index 0000000000000000000000000000000000000000..6a3acd0495448649d5b382150fef780a900562d7 --- /dev/null +++ b/results/10.txt @@ -0,0 +1 @@ +123 diff --git a/results/12.txt b/results/12.txt new file mode 100644 index 0000000000000000000000000000000000000000..7fb567ab599c4b736648cb5e078d487f09bdb899 --- /dev/null +++ b/results/12.txt @@ -0,0 +1,9 @@ +Don'tDoThis! +HereIamwritingdowneveryword thatIamgoingto +sayaloudwhenIgivethepresentation.Allthetext +takingupspace on thescreen isgoing tomake itreally +difficultforpeopletolistenandtoreadeverythingat +once,even though theyarethesame thing.It'd be +betterif Ijustwentandaddedacouplekeypointson +theslideandelaborateonthembyreferringtosome +flash cards ormemorizingwhat I'mgoing tosay. diff --git a/results/2.txt b/results/2.txt new file mode 100644 index 0000000000000000000000000000000000000000..bce5a0a25003c60fa85ec826f587dac3193b4a4b --- /dev/null +++ b/results/2.txt @@ -0,0 +1 @@ +All The Best diff --git a/results/3.txt b/results/3.txt new file mode 100644 index 0000000000000000000000000000000000000000..9aec658c0b8438072702330b174ed552fe4d2b3a --- /dev/null +++ b/results/3.txt @@ -0,0 +1,17 @@ +XAVanban +Hinh anh +Tailieu +Trang web +Phat hien ngon ngir +Viet +AnhTrung(Gianthe) +Tieng Phap(Canada)Igbo +Trung(Gian the) +xin chao +你好 +Nihao +8/5.000 +口 +Guriykienphanhoi +Cacban dichdathuchien +Daluu diff --git a/results/4.txt b/results/4.txt new file mode 100644 index 0000000000000000000000000000000000000000..94a0b2782d44f915929901f84a17bc84324fcbec --- /dev/null +++ b/results/4.txt @@ -0,0 +1,6 @@ +Uyghur +Tieng Phap(Canada) +Trung(Gian the) +jgi +yakshimusiz +Guiykienphanhoi diff --git a/results/5.txt b/results/5.txt new file mode 100644 index 0000000000000000000000000000000000000000..124fa270a2a12706755e5447fe49b54f07f32bd0 --- /dev/null +++ b/results/5.txt @@ -0,0 +1,5 @@ +Thai +Th6NhiKy +Kalaallisut +Swasdi +Guiykienphan hoi diff --git a/results/6.txt b/results/6.txt new file mode 100644 index 0000000000000000000000000000000000000000..39b7c798bcb2abdda42dbc57f3c557b463de8186 --- /dev/null +++ b/results/6.txt @@ -0,0 +1,16 @@ +xAvanban +Hinh anh +Tailieu +Trangweb +Phathien ngonngur +Viet Quang Dong Trung(Phon the) +Thai Th6NhiKy Kalaallisut +012345678910 +012345678910 +012345678910 +0 +D +22/5.000 +Guini y kien phan hoi +Cacbandichda thuchien +Daluu diff --git a/results/8.txt b/results/8.txt new file mode 100644 index 0000000000000000000000000000000000000000..129a793c804b51e9f4066c58fd9af567233fd706 --- /dev/null +++ b/results/8.txt @@ -0,0 +1 @@ +8 diff --git a/results/9.txt b/results/9.txt new file mode 100644 index 0000000000000000000000000000000000000000..1ab5d0e4e143979538afbaeb7049501e80fa559e --- /dev/null +++ b/results/9.txt @@ -0,0 +1 @@ +O diff --git a/results/test.txt b/results/test.txt new file mode 100644 index 0000000000000000000000000000000000000000..22d5e2db6282e6139ec211ce25996aec5b906332 --- /dev/null +++ b/results/test.txt @@ -0,0 +1,7 @@ +格格打 +这时我看见他的背影, +我的泪很快地流下来了。 +我赶紧拭干了泪, +怕他看见, +也怕别人看见。 +《背影》 diff --git a/tmp/uploaded_1745856378.pdf b/tmp/uploaded_1745856378.pdf new file mode 100644 index 0000000000000000000000000000000000000000..21828a8732d2733c41b9d99daef1958d5c1a1ac3 --- /dev/null +++ b/tmp/uploaded_1745856378.pdf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:fa4276cf070aa6c95f550d81ce5bf1ef36e4c8ecf29088acdab98ca2e4f757fc +size 1553257 diff --git a/tmp/uploaded_1745860914.pdf b/tmp/uploaded_1745860914.pdf new file mode 100644 index 0000000000000000000000000000000000000000..26c6b9f86a2b222b57d6445257ffa0822825838d Binary files /dev/null and b/tmp/uploaded_1745860914.pdf differ diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..150da63782e7fc3ddf70a05238dd68ed55a2e491 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,20 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +from pathlib import Path +from typing import Dict, Union + +import yaml + +from .infer_engine import OrtInferSession +from .load_image import LoadImage, LoadImageError +from .logger import get_logger +from .parse_parameters import UpdateParameters, init_args, update_model_path +from .process_img import add_round_letterbox, increase_min_side, reduce_max_side +from .vis_res import VisRes + + +def read_yaml(yaml_path: Union[str, Path]) -> Dict[str, Dict]: + with open(yaml_path, "rb") as f: + data = yaml.load(f, Loader=yaml.Loader) + return data diff --git a/utils/__pycache__/__init__.cpython-310.pyc b/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ba5d304aec8cc25b474e7de186cb982e8d3b9d2 Binary files /dev/null and b/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/utils/__pycache__/__init__.cpython-311.pyc b/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77b1f7bb17391afbc30651c50ef839dc750fb6ae Binary files /dev/null and b/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/utils/__pycache__/__init__.cpython-312.pyc b/utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d397553eaabaabadcff3109487c2f9652438fd52 Binary files /dev/null and b/utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/utils/__pycache__/__init__.cpython-313.pyc b/utils/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36b6ec61c3141b4b66d97419c0aba6b4f3eb6b76 Binary files /dev/null and b/utils/__pycache__/__init__.cpython-313.pyc differ diff --git a/utils/__pycache__/infer_engine.cpython-310.pyc b/utils/__pycache__/infer_engine.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd93b4b158cf70548c25ed4356ce4b96f99815ed Binary files /dev/null and b/utils/__pycache__/infer_engine.cpython-310.pyc differ diff --git a/utils/__pycache__/infer_engine.cpython-311.pyc b/utils/__pycache__/infer_engine.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17bb674c5e806cd220b135e82da075374ed2567e Binary files /dev/null and b/utils/__pycache__/infer_engine.cpython-311.pyc differ diff --git a/utils/__pycache__/infer_engine.cpython-312.pyc b/utils/__pycache__/infer_engine.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1887faac1c91fa82bdb97840eff7312ec60469b5 Binary files /dev/null and b/utils/__pycache__/infer_engine.cpython-312.pyc differ diff --git a/utils/__pycache__/infer_engine.cpython-313.pyc b/utils/__pycache__/infer_engine.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3b03eec9f7585220f2a8e6346bf87e181b42b89 Binary files /dev/null and b/utils/__pycache__/infer_engine.cpython-313.pyc differ diff --git 
a/utils/__pycache__/load_image.cpython-310.pyc b/utils/__pycache__/load_image.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d70bbb9867da1a17822c46d299801ca8cf5a847b Binary files /dev/null and b/utils/__pycache__/load_image.cpython-310.pyc differ diff --git a/utils/__pycache__/load_image.cpython-311.pyc b/utils/__pycache__/load_image.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf8578d28b74a10264b3ddab7424a21d09664d82 Binary files /dev/null and b/utils/__pycache__/load_image.cpython-311.pyc differ diff --git a/utils/__pycache__/load_image.cpython-312.pyc b/utils/__pycache__/load_image.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a976d335ed96b2bb6557997337763975b65d7f8 Binary files /dev/null and b/utils/__pycache__/load_image.cpython-312.pyc differ diff --git a/utils/__pycache__/load_image.cpython-313.pyc b/utils/__pycache__/load_image.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b332edc8d0ef683d65bb4526ebbc05f85d70291 Binary files /dev/null and b/utils/__pycache__/load_image.cpython-313.pyc differ diff --git a/utils/__pycache__/logger.cpython-310.pyc b/utils/__pycache__/logger.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca9c2470b2b6f9b067b0a6d6b193feafe9fec2f5 Binary files /dev/null and b/utils/__pycache__/logger.cpython-310.pyc differ diff --git a/utils/__pycache__/logger.cpython-311.pyc b/utils/__pycache__/logger.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a061c3df276ddd5182368285abbd2e7dd5bb4b5 Binary files /dev/null and b/utils/__pycache__/logger.cpython-311.pyc differ diff --git a/utils/__pycache__/logger.cpython-312.pyc b/utils/__pycache__/logger.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5430320c8ae35046831a1acb8b68317b501b5492 Binary files /dev/null and b/utils/__pycache__/logger.cpython-312.pyc differ diff --git a/utils/__pycache__/logger.cpython-313.pyc b/utils/__pycache__/logger.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59f9473e034f45e83506b012839c086c431542a9 Binary files /dev/null and b/utils/__pycache__/logger.cpython-313.pyc differ diff --git a/utils/__pycache__/parse_parameters.cpython-310.pyc b/utils/__pycache__/parse_parameters.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69273c53eb9ef4cbff0fb0db824e668b71c48c2f Binary files /dev/null and b/utils/__pycache__/parse_parameters.cpython-310.pyc differ diff --git a/utils/__pycache__/parse_parameters.cpython-311.pyc b/utils/__pycache__/parse_parameters.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6948afa4b414b540f7e60b879a0fbecb316ce63 Binary files /dev/null and b/utils/__pycache__/parse_parameters.cpython-311.pyc differ diff --git a/utils/__pycache__/parse_parameters.cpython-312.pyc b/utils/__pycache__/parse_parameters.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2df0b8e06f1cd9283be8063030071df83fd7c6b Binary files /dev/null and b/utils/__pycache__/parse_parameters.cpython-312.pyc differ diff --git a/utils/__pycache__/parse_parameters.cpython-313.pyc b/utils/__pycache__/parse_parameters.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..008284fae87412f2ea20e7fc8cab238a5427a601 Binary files /dev/null and 
diff --git a/utils/infer_engine.py b/utils/infer_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb1fd351d7d79248c67c795eb9c1db16387bf292
--- /dev/null
+++ b/utils/infer_engine.py
@@ -0,0 +1,231 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+import os
+import platform
+import traceback
+from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, List, Tuple, Union
+
+import numpy as np
+from onnxruntime import (
+    GraphOptimizationLevel,
+    InferenceSession,
+    SessionOptions,
+    get_available_providers,
+    get_device,
+)
+
+from .logger import get_logger
+
+
+class EP(Enum):
+    CPU_EP = "CPUExecutionProvider"
+    CUDA_EP = "CUDAExecutionProvider"
+    DIRECTML_EP = "DmlExecutionProvider"
+
+
+class OrtInferSession:
+    def __init__(self, config: Dict[str, Any]):
+        self.logger = get_logger("OrtInferSession")
+
+        model_path = config.get("model_path", None)
+        self._verify_model(model_path)
+
+        self.cfg_use_cuda = config.get("use_cuda", None)
+        self.cfg_use_dml = config.get("use_dml", None)
+
+        self.had_providers: List[str] = get_available_providers()
+        EP_list = self._get_ep_list()
+
+        sess_opt = self._init_sess_opts(config)
+        self.session = InferenceSession(
+            model_path,
+            sess_options=sess_opt,
+            providers=EP_list,
+        )
+        self._verify_providers()
+
+    @staticmethod
+    def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions:
+        sess_opt = SessionOptions()
+        sess_opt.log_severity_level = 4
+        sess_opt.enable_cpu_mem_arena = False
+        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
+
+        cpu_nums = os.cpu_count()
+        intra_op_num_threads = config.get("intra_op_num_threads", -1)
+        if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums:
+            sess_opt.intra_op_num_threads = intra_op_num_threads
+
+        inter_op_num_threads = config.get("inter_op_num_threads", -1)
+        if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums:
+            sess_opt.inter_op_num_threads = inter_op_num_threads
+
+        return sess_opt
+
+    def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
+        cpu_provider_opts = {
+            "arena_extend_strategy": "kSameAsRequested",
+        }
+        EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]
+
+        cuda_provider_opts = {
+            "device_id": 0,
+            "arena_extend_strategy": "kNextPowerOfTwo",
+            "cudnn_conv_algo_search": "EXHAUSTIVE",
+            "do_copy_in_default_stream": True,
+        }
+        self.use_cuda = self._check_cuda()
+        if self.use_cuda:
+            EP_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts))
+
+        self.use_directml = self._check_dml()
+        if self.use_directml:
+            self.logger.info(
+                "Windows 10 or above detected, try to use DirectML as primary provider"
+            )
+            directml_options = (
+                cuda_provider_opts if self.use_cuda else cpu_provider_opts
+            )
+            EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options))
+        return EP_list
+
+    def _check_cuda(self) -> bool:
+        if not self.cfg_use_cuda:
+            return False
+
+        cur_device = get_device()
+        if cur_device == "GPU" and EP.CUDA_EP.value in self.had_providers:
+            return True
+
+        self.logger.warning(
+            "%s is not in available providers (%s). Use %s inference by default.",
+            EP.CUDA_EP.value,
+            self.had_providers,
+            self.had_providers[0],
+        )
+        self.logger.info("Recommend using rapidocr_paddle for inference on GPU.")
+        self.logger.info(
+            "(For reference only) If you want to use GPU acceleration, you must do the following:"
+        )
+        self.logger.info(
+            "First, uninstall all onnxruntime packages in the current environment."
+        )
+        self.logger.info(
+            "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`."
+        )
+        self.logger.info(
+            "\tNote that the onnxruntime-gpu version must match your CUDA and cuDNN versions."
+        )
+        self.logger.info(
+            "\tYou can refer to this link: https://onnxruntime.ai/docs/execution-providers/CUDA-EP.html"
+        )
+        self.logger.info(
+            "Third, ensure %s is in the available providers list, e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']",
+            EP.CUDA_EP.value,
+        )
+        return False
+
+    def _check_dml(self) -> bool:
+        if not self.cfg_use_dml:
+            return False
+
+        cur_os = platform.system()
+        if cur_os != "Windows":
+            self.logger.warning(
+                "DirectML is only supported on Windows. The current OS is %s. Use %s inference by default.",
+                cur_os,
+                self.had_providers[0],
+            )
+            return False
+
+        cur_window_version = int(platform.release().split(".")[0])
+        if cur_window_version < 10:
+            self.logger.warning(
+                "DirectML is only supported on Windows 10 and above. The current Windows version is %s. Use %s inference by default.",
+                cur_window_version,
+                self.had_providers[0],
+            )
+            return False
+
+        if EP.DIRECTML_EP.value in self.had_providers:
+            return True
+
+        self.logger.warning(
+            "%s is not in available providers (%s). Use %s inference by default.",
+            EP.DIRECTML_EP.value,
+            self.had_providers,
+            self.had_providers[0],
+        )
+        self.logger.info("If you want to use DirectML acceleration, you must do:")
+        self.logger.info(
+            "First, uninstall all onnxruntime packages in the current environment."
+        )
+        self.logger.info(
+            "Second, install onnxruntime-directml by `pip install onnxruntime-directml`"
+        )
+        self.logger.info(
+            "Third, ensure %s is in the available providers list, e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']",
+            EP.DIRECTML_EP.value,
+        )
+        return False
+
+    def _verify_providers(self):
+        session_providers = self.session.get_providers()
+        first_provider = session_providers[0]
+
+        if self.use_cuda and first_provider != EP.CUDA_EP.value:
+            self.logger.warning(
+                "%s is not available in the current environment; inference will automatically fall back to %s.",
+                EP.CUDA_EP.value,
+                first_provider,
+            )
+
+        if self.use_directml and first_provider != EP.DIRECTML_EP.value:
+            self.logger.warning(
+                "%s is not available in the current environment; inference will automatically fall back to %s.",
+                EP.DIRECTML_EP.value,
+                first_provider,
+            )
+
+    def __call__(self, input_content: np.ndarray) -> np.ndarray:
+        input_dict = dict(zip(self.get_input_names(), [input_content]))
+        try:
+            return self.session.run(self.get_output_names(), input_dict)
+        except Exception as e:
+            error_info = traceback.format_exc()
+            raise ONNXRuntimeError(error_info) from e
+
+    def get_input_names(self) -> List[str]:
+        return [v.name for v in self.session.get_inputs()]
+
+    def get_output_names(self) -> List[str]:
+        return [v.name for v in self.session.get_outputs()]
+
+    def get_character_list(self, key: str = "character") -> List[str]:
+        meta_dict = self.session.get_modelmeta().custom_metadata_map
+        return meta_dict[key].splitlines()
+
+    def have_key(self, key: str = "character") -> bool:
+        meta_dict = self.session.get_modelmeta().custom_metadata_map
+        if key in meta_dict.keys():
+            return True
+        return False
+
+    @staticmethod
+    def _verify_model(model_path: Union[str, Path, None]):
+        if model_path is None:
+            raise ValueError("model_path is None!")
+
+        model_path = Path(model_path)
+        if not model_path.exists():
+            raise FileNotFoundError(f"{model_path} does not exist.")
+
+        if not model_path.is_file():
+            raise FileExistsError(f"{model_path} is not a file.")
+
+
+class ONNXRuntimeError(Exception):
+    pass
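A minimal end-to-end sketch of OrtInferSession; the model path and input shape are assumptions for illustration, not artifacts shipped in this diff:

    import numpy as np
    from utils import OrtInferSession

    session = OrtInferSession({"model_path": "models/det.onnx", "intra_op_num_threads": 4})  # assumed path
    dummy = np.random.rand(1, 3, 736, 736).astype(np.float32)  # assumed input layout
    outputs = session(dummy)  # wraps session.run(...); returns a list of output arrays
    print(session.get_input_names(), session.get_output_names())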
diff --git a/utils/load_image.py b/utils/load_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..f34b549fe374df24e485d39c774389e3f0169fc1
--- /dev/null
+++ b/utils/load_image.py
@@ -0,0 +1,123 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+from io import BytesIO
+from pathlib import Path
+from typing import Any, Union
+
+import cv2
+import numpy as np
+from PIL import Image, UnidentifiedImageError
+
+root_dir = Path(__file__).resolve().parent
+InputType = Union[str, np.ndarray, bytes, Path, Image.Image]
+
+
+class LoadImage:
+    def __init__(self):
+        pass
+
+    def __call__(self, img: InputType) -> np.ndarray:
+        if not isinstance(img, InputType.__args__):
+            raise LoadImageError(
+                f"The img type {type(img)} is not in {InputType.__args__}"
+            )
+
+        origin_img_type = type(img)
+        img = self.load_img(img)
+        img = self.convert_img(img, origin_img_type)
+        return img
+
+    def load_img(self, img: InputType) -> np.ndarray:
+        if isinstance(img, (str, Path)):
+            self.verify_exist(img)
+            try:
+                img = self.img_to_ndarray(Image.open(img))
+            except UnidentifiedImageError as e:
+                raise LoadImageError(f"cannot identify image file {img}") from e
+            return img
+
+        if isinstance(img, bytes):
+            img = self.img_to_ndarray(Image.open(BytesIO(img)))
+            return img
+
+        if isinstance(img, np.ndarray):
+            return img
+
+        if isinstance(img, Image.Image):
+            return self.img_to_ndarray(img)
+
+        raise LoadImageError(f"{type(img)} is not supported!")
+
+    def img_to_ndarray(self, img: Image.Image) -> np.ndarray:
+        if img.mode == "1":
+            img = img.convert("L")
+        return np.array(img)
+
+    def convert_img(self, img: np.ndarray, origin_img_type: Any) -> np.ndarray:
+        if img.ndim == 2:
+            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+        if img.ndim == 3:
+            channel = img.shape[2]
+            if channel == 1:
+                return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+            if channel == 2:
+                return self.cvt_two_to_three(img)
+
+            if channel == 3:
+                if issubclass(origin_img_type, (str, Path, bytes, Image.Image)):
+                    return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+                return img
+
+            if channel == 4:
+                return self.cvt_four_to_three(img)
+
+            raise LoadImageError(
+                f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
+            )
+
+        raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")
+
+    @staticmethod
+    def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
+        """gray + alpha → BGR"""
+        img_gray = img[..., 0]
+        img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)
+
+        img_alpha = img[..., 1]
+        not_a = cv2.bitwise_not(img_alpha)
+        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
+
+        new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
+        new_img = cv2.add(new_img, not_a)
+        return new_img
+
+    @staticmethod
+    def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
+        """RGBA → BGR"""
+        r, g, b, a = cv2.split(img)
+        new_img = cv2.merge((b, g, r))
+
+        not_a = cv2.bitwise_not(a)
+        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
+
+        new_img = cv2.bitwise_and(new_img, new_img, mask=a)
+
+        mean_color = np.mean(new_img)
+        if mean_color <= 0.0:
+            new_img = cv2.add(new_img, not_a)
+        else:
+            new_img = cv2.bitwise_not(new_img)
+        return new_img
+
+    @staticmethod
+    def verify_exist(file_path: Union[str, Path]):
+        if not Path(file_path).exists():
+            raise LoadImageError(f"{file_path} does not exist.")
+
+
+class LoadImageError(Exception):
+    pass
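LoadImage normalizes every supported input (str/Path, bytes, ndarray, PIL image) to a 3-channel BGR ndarray, flattening 2- and 4-channel images along the way. A short sketch, assuming a test.png exists (the filename is an assumption):

    from utils import LoadImage, LoadImageError

    load_img = LoadImage()
    try:
        bgr = load_img("test.png")  # assumed file; bytes, ndarray, and PIL.Image also work
        print(bgr.shape)            # (H, W, 3) in BGR order
    except LoadImageError as exc:
        print(f"could not load image: {exc}")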
diff --git a/utils/logger.py b/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..66522c489a5f59c2d11d051a4d02d6cdeff05615
--- /dev/null
+++ b/utils/logger.py
@@ -0,0 +1,21 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+import logging
+from functools import lru_cache
+
+
+@lru_cache(maxsize=32)
+def get_logger(name: str) -> logging.Logger:
+    logger = logging.getLogger(name)
+    logger.setLevel(logging.DEBUG)
+
+    fmt = "%(asctime)s - %(name)s - %(levelname)s: %(message)s"
+    format_str = logging.Formatter(fmt)
+
+    sh = logging.StreamHandler()
+    sh.setLevel(logging.DEBUG)
+    sh.setFormatter(format_str)
+
+    logger.addHandler(sh)
+    return logger
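get_logger builds a DEBUG-level stream logger once per name; thanks to lru_cache, repeated calls return the same instance instead of stacking duplicate handlers. For example:

    from utils import get_logger

    log = get_logger("demo")
    log.info("hello")                 # <asctime> - demo - INFO: hello
    assert get_logger("demo") is log  # cached: no duplicate handlers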
diff --git a/utils/parse_parameters.py b/utils/parse_parameters.py
new file mode 100644
index 0000000000000000000000000000000000000000..0614412faf3184345f3999f6340184ec1509694f
--- /dev/null
+++ b/utils/parse_parameters.py
@@ -0,0 +1,197 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+import argparse
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+from PIL import Image
+
+root_dir = Path(__file__).resolve().parent.parent
+InputType = Union[str, np.ndarray, bytes, Path, Image.Image]
+
+
+def update_model_path(config: Dict[str, Any]) -> Dict[str, Any]:
+    key = "model_path"
+    config["Det"][key] = str(root_dir / config["Det"][key])
+    config["Rec"][key] = str(root_dir / config["Rec"][key])
+    # config["Cls"][key] = str(root_dir / config["Cls"][key])
+    return config
+
+
+def init_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-img", "--img_path", type=str, default=None, required=True)
+    parser.add_argument("-p", "--print_cost", action="store_true", default=False)
+
+    global_group = parser.add_argument_group(title="Global")
+    global_group.add_argument("--text_score", type=float, default=0.5)
+
+    global_group.add_argument("--no_det", action="store_true", default=False)
+    global_group.add_argument("--no_cls", action="store_true", default=False)
+    global_group.add_argument("--no_rec", action="store_true", default=False)
+
+    global_group.add_argument("--print_verbose", action="store_true", default=False)
+    global_group.add_argument("--min_height", type=int, default=30)
+    global_group.add_argument("--width_height_ratio", type=int, default=8)
+    global_group.add_argument("--max_side_len", type=int, default=2000)
+    global_group.add_argument("--min_side_len", type=int, default=30)
+    global_group.add_argument("--return_word_box", action="store_true", default=False)
+
+    global_group.add_argument("--intra_op_num_threads", type=int, default=-1)
+    global_group.add_argument("--inter_op_num_threads", type=int, default=-1)
+
+    det_group = parser.add_argument_group(title="Det")
+    det_group.add_argument("--det_use_cuda", action="store_true", default=False)
+    det_group.add_argument("--det_use_dml", action="store_true", default=False)
+    det_group.add_argument("--det_model_path", type=str, default=None)
+    det_group.add_argument("--det_limit_side_len", type=float, default=736)
+    det_group.add_argument(
+        "--det_limit_type", type=str, default="min", choices=["max", "min"]
+    )
+    det_group.add_argument("--det_thresh", type=float, default=0.3)
+    det_group.add_argument("--det_box_thresh", type=float, default=0.5)
+    det_group.add_argument("--det_unclip_ratio", type=float, default=1.6)
+    det_group.add_argument(
+        "--det_donot_use_dilation", action="store_true", default=False
+    )
+    det_group.add_argument(
+        "--det_score_mode", type=str, default="fast", choices=["slow", "fast"]
+    )
+
+    cls_group = parser.add_argument_group(title="Cls")
+    cls_group.add_argument("--cls_use_cuda", action="store_true", default=False)
+    cls_group.add_argument("--cls_use_dml", action="store_true", default=False)
+    cls_group.add_argument("--cls_model_path", type=str, default=None)
+    cls_group.add_argument("--cls_image_shape", type=list, default=[3, 48, 192])
+    cls_group.add_argument("--cls_label_list", type=list, default=["0", "180"])
+    cls_group.add_argument("--cls_batch_num", type=int, default=6)
+    cls_group.add_argument("--cls_thresh", type=float, default=0.9)
+
+    rec_group = parser.add_argument_group(title="Rec")
+    rec_group.add_argument("--rec_use_cuda", action="store_true", default=False)
+    rec_group.add_argument("--rec_use_dml", action="store_true", default=False)
+    rec_group.add_argument("--rec_model_path", type=str, default=None)
+    rec_group.add_argument("--rec_keys_path", type=str, default=None)
+    rec_group.add_argument("--rec_img_shape", type=list, default=[3, 48, 320])
+    rec_group.add_argument("--rec_batch_num", type=int, default=6)
+
+    vis_group = parser.add_argument_group(title="Visual Result")
+    vis_group.add_argument("-vis", "--vis_res", action="store_true", default=False)
+    vis_group.add_argument(
+        "--vis_font_path",
+        type=str,
+        default=None,
+        help="When -vis is set, vis_font_path must be provided.",
+    )
+    vis_group.add_argument(
+        "--vis_save_path",
+        type=str,
+        default=".",
+        help="The directory where the visualization image is saved.",
+    )
+
+    args = parser.parse_args()
+    return args
+
+
+class UpdateParameters:
+    def __init__(self) -> None:
+        pass
+
+    def parse_kwargs(self, **kwargs):
+        global_dict, det_dict, cls_dict, rec_dict = {}, {}, {}, {}
+        for k, v in kwargs.items():
+            if k.startswith("det"):
+                k = k.split("det_")[1]
+                if k == "donot_use_dilation":
+                    k = "use_dilation"
+                    v = not v
+
+                det_dict[k] = v
+            elif k.startswith("cls"):
+                cls_dict[k] = v
+            elif k.startswith("rec"):
+                rec_dict[k] = v
+            else:
+                global_dict[k] = v
+        return global_dict, det_dict, cls_dict, rec_dict
+
+    def __call__(self, config, **kwargs):
+        global_dict, det_dict, cls_dict, rec_dict = self.parse_kwargs(**kwargs)
+        new_config = {
+            "Global": self.update_global_params(config["Global"], global_dict),
+            "Det": self.update_params(
+                config["Det"],
+                det_dict,
+                "det_",
+                ["det_model_path", "det_use_cuda", "det_use_dml"],
+            ),
+            # "Cls": self.update_params(
+            #     config["Cls"],
+            #     cls_dict,
+            #     "cls_",
+            #     ["cls_label_list", "cls_model_path", "cls_use_cuda", "cls_use_dml"],
+            # ),
+            "Rec": self.update_params(
+                config["Rec"],
+                rec_dict,
+                "rec_",
+                ["rec_model_path", "rec_use_cuda", "rec_use_dml"],
+            ),
+        }
+
+        update_params = ["intra_op_num_threads", "inter_op_num_threads"]
+        new_config = self.update_global_to_module(
+            # config, update_params, src="Global", dsts=["Det", "Cls", "Rec"]
+            config, update_params, src="Global", dsts=["Det", "Rec"]
+        )
+        return new_config
+
+    def update_global_to_module(
+        self, config, params: List[str], src: str, dsts: List[str]
+    ):
+        for dst in dsts:
+            for param in params:
+                config[dst].update({param: config[src][param]})
+        return config
+
+    def update_global_params(self, config, global_dict):
+        if global_dict:
+            config.update(global_dict)
+        return config
+
+    def update_params(
+        self,
+        config,
+        param_dict: Dict[str, str],
+        prefix: str,
+        need_remove_prefix: Optional[List[str]] = None,
+    ):
+        if not param_dict:
+            return config
+
+        filter_dict = self.remove_prefix(param_dict, prefix, need_remove_prefix)
+        model_path = filter_dict.get("model_path", None)
+        if not model_path:
+            filter_dict["model_path"] = str(root_dir / config["model_path"])
+
+        config.update(filter_dict)
+        return config
+
+    @staticmethod
+    def remove_prefix(
+        config: Dict[str, str],
+        prefix: str,
+        need_remove_prefix: Optional[List[str]] = None,
+    ) -> Dict[str, str]:
+        if not need_remove_prefix:
+            return config
+
+        new_rec_dict = {}
+        for k, v in config.items():
+            if k in need_remove_prefix:
+                k = k.split(prefix)[1]
+            new_rec_dict[k] = v
+        return new_rec_dict
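UpdateParameters routes keyword overrides by prefix: det_* keys are stripped and stored under Det (with det_donot_use_dilation inverted into use_dilation), cls_*/rec_* keys keep their prefix until update_params strips the whitelisted ones, and everything else lands in Global. The splitting step can be sketched in isolation:

    from utils import UpdateParameters

    updater = UpdateParameters()
    g, det, cls_, rec = updater.parse_kwargs(
        text_score=0.6,
        det_donot_use_dilation=True,  # becomes det["use_dilation"] = False
        rec_batch_num=8,
    )
    print(g)    # {'text_score': 0.6}
    print(det)  # {'use_dilation': False}
    print(rec)  # {'rec_batch_num': 8}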
diff --git a/utils/process_img.py b/utils/process_img.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d3e19460ca1cf9cf88fb619a180e6f459bd0de0
--- /dev/null
+++ b/utils/process_img.py
@@ -0,0 +1,87 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+from typing import Tuple
+
+import cv2
+import numpy as np
+
+
+def reduce_max_side(
+    img: np.ndarray, max_side_len: int = 2000
+) -> Tuple[np.ndarray, float, float]:
+    h, w = img.shape[:2]
+
+    ratio = 1.0
+    if max(h, w) > max_side_len:
+        if h > w:
+            ratio = float(max_side_len) / h
+        else:
+            ratio = float(max_side_len) / w
+
+    resize_h = int(h * ratio)
+    resize_w = int(w * ratio)
+
+    resize_h = int(round(resize_h / 32) * 32)
+    resize_w = int(round(resize_w / 32) * 32)
+
+    try:
+        if int(resize_w) <= 0 or int(resize_h) <= 0:
+            raise ResizeImgError("resize_w or resize_h is less than or equal to 0")
+        img = cv2.resize(img, (resize_w, resize_h))
+    except Exception as exc:
+        raise ResizeImgError() from exc
+
+    ratio_h = h / resize_h
+    ratio_w = w / resize_w
+    return img, ratio_h, ratio_w
+
+
+def increase_min_side(
+    img: np.ndarray, min_side_len: int = 30
+) -> Tuple[np.ndarray, float, float]:
+    h, w = img.shape[:2]
+
+    ratio = 1.0
+    if min(h, w) < min_side_len:
+        if h < w:
+            ratio = float(min_side_len) / h
+        else:
+            ratio = float(min_side_len) / w
+
+    resize_h = int(h * ratio)
+    resize_w = int(w * ratio)
+
+    resize_h = int(round(resize_h / 32) * 32)
+    resize_w = int(round(resize_w / 32) * 32)
+
+    try:
+        if int(resize_w) <= 0 or int(resize_h) <= 0:
+            raise ResizeImgError("resize_w or resize_h is less than or equal to 0")
+        img = cv2.resize(img, (resize_w, resize_h))
+    except Exception as exc:
+        raise ResizeImgError() from exc
+
+    ratio_h = h / resize_h
+    ratio_w = w / resize_w
+    return img, ratio_h, ratio_w
+
+
+def add_round_letterbox(
+    img: np.ndarray,
+    padding_tuple: Tuple[int, int, int, int],
+) -> np.ndarray:
+    padded_img = cv2.copyMakeBorder(
+        img,
+        padding_tuple[0],
+        padding_tuple[1],
+        padding_tuple[2],
+        padding_tuple[3],
+        cv2.BORDER_CONSTANT,
+        value=(0, 0, 0),
+    )
+    return padded_img
+
+
+class ResizeImgError(Exception):
+    pass
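Both resize helpers snap the output size to a multiple of 32 (a common stride constraint for detection backbones) and return the height/width ratios needed to map coordinates back to the original image. A short sketch on a synthetic image:

    import numpy as np
    from utils import reduce_max_side

    img = np.zeros((3000, 1500, 3), dtype=np.uint8)
    small, ratio_h, ratio_w = reduce_max_side(img, max_side_len=2000)
    print(small.shape)       # both sides <= ~2000 and divisible by 32
    print(ratio_h, ratio_w)  # multiply resized coordinates by these to undo the scaling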
diff --git a/utils/vis_res.py b/utils/vis_res.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd18031f1a27073b44e311ff24a746f5112f56d6
--- /dev/null
+++ b/utils/vis_res.py
@@ -0,0 +1,143 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+import math
+import random
+from pathlib import Path
+from typing import List, Optional, Tuple, Union
+
+import cv2
+import numpy as np
+from PIL import Image, ImageDraw, ImageFont
+
+from .load_image import LoadImage
+
+root_dir = Path(__file__).resolve().parent
+InputType = Union[str, np.ndarray, bytes, Path, Image.Image]
+
+
+class VisRes:
+    def __init__(self, text_score: float = 0.5):
+        self.text_score = text_score
+        self.load_img = LoadImage()
+
+    def __call__(
+        self,
+        img_content: InputType,
+        dt_boxes: np.ndarray,
+        txts: Optional[Union[List[str], Tuple[str]]] = None,
+        scores: Optional[Tuple[float]] = None,
+        font_path: Optional[str] = None,
+    ) -> np.ndarray:
+        if txts is None:
+            return self.draw_dt_boxes(img_content, dt_boxes)
+        return self.draw_ocr_box_txt(img_content, dt_boxes, txts, scores, font_path)
+
+    def draw_dt_boxes(self, img_content: InputType, dt_boxes: np.ndarray) -> np.ndarray:
+        img = self.load_img(img_content)
+
+        for idx, box in enumerate(dt_boxes):
+            color = self.get_random_color()
+
+            points = np.array(box)
+            cv2.polylines(img, np.int32([points]), 1, color=color, thickness=1)
+
+            start_point = round(points[0][0]), round(points[0][1])
+            cv2.putText(
+                img, f"{idx}", start_point, cv2.FONT_HERSHEY_SIMPLEX, 1, color, 3
+            )
+        return img
+
+    def draw_ocr_box_txt(
+        self,
+        img_content: InputType,
+        dt_boxes: np.ndarray,
+        txts: Union[List[str], Tuple[str]],
+        scores: Optional[Tuple[float]] = None,
+        font_path: Optional[str] = None,
+    ) -> np.ndarray:
+        font_path = self.get_font_path(font_path)
+
+        image = Image.fromarray(self.load_img(img_content))
+        h, w = image.height, image.width
+        if image.mode == "L":
+            image = image.convert("RGB")
+
+        img_left = image.copy()
+        img_right = Image.new("RGB", (w, h), (255, 255, 255))
+
+        random.seed(0)
+        draw_left = ImageDraw.Draw(img_left)
+        draw_right = ImageDraw.Draw(img_right)
+        for idx, (box, txt) in enumerate(zip(dt_boxes, txts)):
+            if scores is not None and float(scores[idx]) < self.text_score:
+                continue
+
+            color = self.get_random_color()
+
+            box_list = np.array(box).reshape(8).tolist()
+            draw_left.polygon(box_list, fill=color)
+            draw_right.polygon(box_list, outline=color)
+
+            box_height = self.get_box_height(box)
+            box_width = self.get_box_width(box)
+            if box_height > 2 * box_width:
+                font_size = max(int(box_width * 0.9), 10)
+                font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+                cur_y = box[0][1]
+
+                for c in txt:
+                    draw_right.text(
+                        (box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font
+                    )
+                    cur_y += self.get_char_size(font, c)
+            else:
+                font_size = max(int(box_height * 0.8), 10)
+                font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+                draw_right.text([box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
+
+        img_left = Image.blend(image, img_left, 0.5)
+        img_show = Image.new("RGB", (w * 2, h), (255, 255, 255))
+        img_show.paste(img_left, (0, 0, w, h))
+        img_show.paste(img_right, (w, 0, w * 2, h))
+        return np.array(img_show)
+
+    @staticmethod
+    def get_font_path(font_path: Optional[Union[str, Path]] = None) -> str:
+        if font_path is None or not Path(font_path).exists():
+            raise FileNotFoundError(
+                f"The font file {font_path} does not exist!\n"
+                f"You can download one from https://drive.google.com/file/d/1evWVX38EFNwTq_n5gTFgnlv8tdaNcyIA/view?usp=sharing"
+            )
+        return str(font_path)
+
+    @staticmethod
+    def get_random_color() -> Tuple[int, int, int]:
+        return (
+            random.randint(0, 255),
+            random.randint(0, 255),
+            random.randint(0, 255),
+        )
+
+    @staticmethod
+    def get_box_height(box: List[List[float]]) -> float:
+        return math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2)
+
+    @staticmethod
+    def get_box_width(box: List[List[float]]) -> float:
+        return math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2)
+
+    @staticmethod
+    def get_char_size(font, char_str: str) -> float:
+        # Compatible with both Pillow v9 and v10.
+        if hasattr(font, "getsize"):
+            get_size_func = getattr(font, "getsize")
+            return get_size_func(char_str)[1]
+
+        if hasattr(font, "getlength"):
+            get_size_func = getattr(font, "getlength")
+            return get_size_func(char_str)
+
+        raise ValueError(
+            "The Pillow ImageFont instance has neither a getsize nor a getlength method."
+        )
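Finally, a hedged sketch of VisRes: without txts it draws numbered detection boxes; with txts it renders blended boxes next to the recognized text, which requires a real TTF font. The image, boxes, output name, and font path below are all assumptions for illustration:

    import cv2
    import numpy as np
    from utils import VisRes

    vis = VisRes(text_score=0.5)
    img = np.full((200, 400, 3), 255, dtype=np.uint8)
    boxes = np.array([[[10, 10], [120, 10], [120, 50], [10, 50]]], dtype=np.float32)

    cv2.imwrite("vis_boxes.png", vis(img, boxes))  # boxes only (draw_dt_boxes path)
    # With text, a font file must exist at the given path:
    # panel = vis(img, boxes, txts=["hello"], scores=(0.99,), font_path="fonts/simfang.ttf")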