pmkhanh7890 committed on
Commit
e420e5d
·
1 Parent(s): 95ded62

push code

Browse files
.gitattributes CHANGED
@@ -1,35 +1,3 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.pdiparams filter=lfs diff=lfs merge=lfs -text
2
+ *.pdmodel filter=lfs diff=lfs merge=lfs -text
3
+ *.otf filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # virtual environment
2
+ .venv
3
+ common/cpp
4
+ common/cpp_gapi
5
+ sample_videos/*
6
+ models/vietocr/*
7
+ models/soundbar/*
8
+ results.json
9
+ output/*
10
+ models/soundbar_detection_yolov7/*
11
+ # Byte-compiled / optimized / DLL files
12
+ __pycache__/
13
+ *.py[cod]
14
+ *$py.class
15
+ # C extensions
16
+ *.so
17
+ # Distribution / packaging
18
+ .Python
19
+ build/
20
+ develop-eggs/
21
+ dist/
22
+ downloads/
23
+ eggs/
24
+ .eggs/
25
+ lib/
26
+ lib64/
27
+ parts/
28
+ sdist/
29
+ var/
30
+ wheels/
31
+ share/python-wheels/
32
+ *.egg-info/
33
+ .installed.cfg
34
+ *.egg
35
+ MANIFEST
36
+ # source/yolov7/*
37
+ # PyInstaller
38
+ # Usually these files are written by a python script from a template
39
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
40
+ *.manifest
41
+ *.spec
42
+ # Installer logs
43
+ pip-log.txt
44
+ pip-delete-this-directory.txt
45
+ # Unit test / coverage reports
46
+ htmlcov/
47
+ .tox/
48
+ .nox/
49
+ .coverage
50
+ .coverage.*
51
+ .cache
52
+ nosetests.xml
53
+ coverage.xml
54
+ *.cover
55
+ .hypothesis/
56
+ .pytest_cache/
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+ # Django stuff:
61
+ *.log
62
+ local_settings.py
63
+ db.sqlite3
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+ # Scrapy stuff:
68
+ .scrapy
69
+ # Sphinx documentation
70
+ docs/_build/
71
+ # PyBuilder
72
+ target/
73
+ # Jupyter Notebook
74
+ .ipynb_checkpoints
75
+ # IPython
76
+ profile_default/
77
+ ipython_config.py
78
+ # pyenv
79
+ .python-version
80
+ # celery beat schedule file
81
+ celerybeat-schedule
82
+ # SageMath parsed files
83
+ *.sage.py
84
+ # Environments
85
+ .env
86
+ .venv
87
+ env/
88
+ venv/
89
+ ENV/
90
+ env.bak/
91
+ venv.bak/
92
+ # Spyder project settings
93
+ .spyderproject
94
+ .spyproject
95
+ # Rope project settings
96
+ .ropeproject
97
+ # mkdocs documentation
98
+ /site
99
+ # mypy
100
+ .mypy_cache/
101
+ .dmypy.json
102
+ dmypy.json
103
+ # Pyre type checker
104
+ .pyre/
105
+ # custom
106
+ old/
107
+ notebook/
108
+ static/images/face
109
+ static/images/celeba
110
+ celeba_pics.gz
111
+ celeba_vecs.gz
112
+ face_pics.gz
113
+ face_vecs.gz
.pre-commit-config.yaml ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # See https://pre-commit.com for more information
2
+ # See https://pre-commit.com/hooks.html for more hooks
3
+ repos:
4
+ - repo: https://github.com/pre-commit/pre-commit-hooks
5
+ rev: v4.5.0
6
+ hooks:
7
+ #- id: check-added-large-files
8
+ - id: fix-byte-order-marker
9
+ - id: check-case-conflict
10
+ - id: check-json
11
+ - id: check-yaml
12
+ args: ['--unsafe']
13
+ - id: detect-aws-credentials
14
+ args: [--allow-missing-credentials]
15
+ - id: detect-private-key
16
+ - id: end-of-file-fixer
17
+ - id: mixed-line-ending
18
+ - id: trailing-whitespace
19
+ - repo: https://github.com/asottile/add-trailing-comma
20
+ rev: v3.1.0
21
+ hooks:
22
+ - id: add-trailing-comma
23
+ - repo: https://github.com/pycqa/isort
24
+ rev: 5.12.0
25
+ hooks:
26
+ - id: isort
27
+ name: isort (python)
28
+ args: [--settings-path=pyproject.toml]
29
+ - id: isort
30
+ name: isort (cython)
31
+ types: [cython]
32
+ - id: isort
33
+ name: isort (pyi)
34
+ types: [pyi]
35
+ - repo: https://github.com/psf/black
36
+ rev: 23.11.0
37
+ hooks:
38
+ - id: black
39
+ args: [--config=pyproject.toml]
40
+ - repo: https://github.com/pycqa/flake8.git
41
+ rev: 6.1.0
42
+ hooks:
43
+ - id: flake8
44
+ args: [--ignore, "E203", --max-line-length, "79"]
45
+ - repo: https://github.com/kynan/nbstripout
46
+ rev: 0.6.1
47
+ hooks:
48
+ - id: nbstripout
49
+ - repo: https://github.com/asottile/pyupgrade
50
+ rev: v3.15.0
51
+ hooks:
52
+ - id: pyupgrade
53
+ args: [--py36-plus]
.pre-commit-setting.toml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # See https://pre-commit.com for more information
2
+ # See https://pre-commit.com/hooks.html for more hooks
3
+ [tool.black]
4
+ line-length = 79
5
+ include = '\.pyi?$'
6
+ exclude = '''
7
+ /(
8
+ \.git
9
+ | \.idea
10
+ | \.pytest_cache
11
+ | \.tox
12
+ | \.venv
13
+ | _build
14
+ | buck-out
15
+ | build
16
+ | dist
17
+ )/
18
+ '''
19
+
20
+ [flake8]
21
+ ignore = E203
22
+ max-line-length = 79
README.md CHANGED
@@ -1,14 +1,59 @@
1
- ---
2
- title: Kleverocr
3
- emoji: 🏆
4
- colorFrom: purple
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 4.1.2
8
- app_file: app.py
9
- pinned: false
10
- models:
11
- - kleverocr/
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <h1 align="center">Project title</h1>
2
+
3
+ [Demo on Streamlit](https://khanhphantt-demo.streamlitapp.com/)
4
+
5
+ Project description
6
+
7
+ <p align="center"><img src="data/images/out.jpeg" width="700" height="400"></p>
8
+
9
+
10
+ ### :file_folder: Dataset
11
+ This dataset consists of __XXX samples__ (X classes):
12
+ * __class_1: YYY samples__
13
+ * __class_2: YYY samples__
14
+
15
+ ## Installation
16
+ 1. Clone the repo
17
+ ```
18
+ $ git clone https://github.com/khanhphantt/project
19
+ ```
20
+
21
+ 2. Change your directory to the cloned repo
22
+ ```
23
+ $ cd project
24
+ ```
25
+
26
+ 3. Create a Python virtual environment named '.venv' and activate it
27
+ ```
28
+ $ python -m venv .venv
29
+ $ source .venv/bin/activate
30
+ ```
31
+
32
+ 4. Install the libraries required
33
+ ```
34
+ $ pip install -r requirements.txt
35
+ ```
36
+
37
+ ## Working
38
+
39
+ 1. Train data:
40
+ ```
41
+ $ python3
42
+ ```
43
+
44
+ 2. Detect face masks in an image:
45
+ ```
46
+ $ python3
47
+ ```
48
+
49
+ 3. Detect face masks in real-time video (webcam):
50
+ ```
51
+ $ python3
52
+ ```
53
+
54
+ ## Streamlit app
55
+
56
+ Face Mask Detector webapp using Tensorflow & Streamlit:
57
+ ```
58
+ $ streamlit run app.py
59
+ ```
_config.yml ADDED
@@ -0,0 +1 @@
 
 
1
+ theme: jekyll-theme-modernist
app.py CHANGED
@@ -15,11 +15,12 @@ from src.visualization import visualize_result
15
  print("[INFO] Loaded recognition model")
16
 
17
 
18
- def do_ocr(inp):
19
- print(type(inp))
20
  # img = cv2.imread(inp, cv2.IMREAD_GRAYSCALE)
21
  # image = cv2.imread(inp, cv2.IMREAD_GRAYSCALE)
22
- image = cv2.cvtColor(inp, cv2.COLOR_BGR2GRAY)
 
23
  """
24
  # Recognize all text with boxes and scores
25
  cls = False: to improve the performance
@@ -36,7 +37,7 @@ def do_ocr(inp):
36
  print(f"OCR time: {time_ocr - time_start}")
37
  print(f"post time: {time.time() - time_ocr}")
38
  time_total = time.time() - time_start
39
- img_box, img_text = visualize_result(result, inp)
40
  return img_text, img_box, time_total
41
 
42
 
 
15
  print("[INFO] Loaded recognition model")
16
 
17
 
18
+ def do_ocr(image):
19
+ print(type(image))
20
  # img = cv2.imread(inp, cv2.IMREAD_GRAYSCALE)
21
  # image = cv2.imread(inp, cv2.IMREAD_GRAYSCALE)
22
+ visualized_image = image
23
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
24
  """
25
  # Recognize all text with boxes and scores
26
  cls = False: to improve the performance
 
37
  print(f"OCR time: {time_ocr - time_start}")
38
  print(f"post time: {time.time() - time_ocr}")
39
  time_total = time.time() - time_start
40
+ img_box, img_text = visualize_result(result, visualized_image)
41
  return img_text, img_box, time_total
42
 
43
 
ocr.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from os import listdir
3
+ from os.path import (
4
+ isfile,
5
+ join,
6
+ )
7
+
8
+ import cv2
9
+
10
+ from src.postprocessing import postprocess_result
11
+ from src.settings import (
12
+ OCR_JA,
13
+ OCR_ML,
14
+ )
15
+ from src.visualization import visualize_result
16
+
17
+ path = "data/"
18
+
19
+
20
def paddleOCR(path):
    """
    Run OCR on the images of a folder and save the visualized results.

    Only regular files whose name ends with "JA_0.png" are processed.
    For each of them the box image and the text image returned by
    visualize_result are joined side by side and written to
    "<path>/output/<file name>".
    args:
        path(str): path to input folder
    return(None):
        nothing, the visualizations are written to disk
    """
    # Local import: this module only pulls ``listdir`` from ``os`` at the top.
    from os import makedirs

    for img_file in listdir(path):
        img_path = join(path, img_file)
        if not isfile(img_path) or not img_file.endswith("JA_0.png"):
            continue

        out_path = join(path, "output", img_file)
        print(f"Process {img_path}")
        # Keep a colour copy for drawing; recognition runs on grayscale.
        color_image = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if color_image is None:
            # cv2.imread returns None (it does not raise) on unreadable files
            print(f"Cannot read {img_path}")
            continue
        image = cv2.cvtColor(color_image, cv2.COLOR_BGR2GRAY)

        # cls=True also runs the angle classifier on every detected box.
        time_start = time.time()
        result = OCR_JA.ocr(image, cls=True, det=True, rec=True)
        time_ocr = time.time()
        result = postprocess_result(image, result, OCR_ML)
        print(f"OCR time: {time_ocr - time_start}")
        print(f"post time: {time.time() - time_ocr}")

        # visualize_result takes (result, image array); passing the two
        # paths as before raised a TypeError, so the saving is done here.
        img_boxes, img_text = visualize_result(result, color_image)
        makedirs(join(path, "output"), exist_ok=True)
        cv2.imwrite(out_path, cv2.hconcat([img_boxes, img_text]))
51
+
52
+
53
+ if __name__ == "__main__":
54
+ paddleOCR(path)
pyproject.toml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.black]
2
+ line-length = 79
3
+ include = '\.pyi?$'
4
+ exclude = '''
5
+ /(
6
+ \.git
7
+ | \.idea
8
+ | \.pytest_cache
9
+ | \.tox
10
+ | \.venv
11
+ | _build
12
+ | buck-out
13
+ | build
14
+ | dist
15
+ )/
16
+ '''
17
+
18
+ [tool.isort]
19
+ profile = "black"
20
+ force_grid_wrap=2
21
+ multi_line_output=3
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ paddlepaddle
2
+ paddleocr
result.jpg ADDED
src/__init__.py ADDED
File without changes
src/postprocessing.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.settings import RECOGNITION_THRESHOLD
2
+ from src.utilities import crop_image
3
+
4
+
5
def postprocess_result(image, result, OCR):
    """
    Post-processing steps to improve the results.

    Every detection whose recognition score is below
    RECOGNITION_THRESHOLD is recognized a second time with ``OCR``.
    args:
        image(array): image the detections were computed on
        result(list): one-element list wrapping the detections, each
            detection being [box, (text, score)]
        OCR: recognizer passed on to recognize_text_by_multilanguage
    return(list):
        one-element list wrapping the updated detections; ``[[]]`` when
        there is nothing to process
    """
    # An empty result, or a ``None``/empty first entry, means no detections.
    detections = result[0] if result else None
    if not detections:
        return [[]]

    new_result = []
    for line in detections:
        # line[0] is the box, line[1][0] the text and line[1][1] the score
        if line[1][1] < RECOGNITION_THRESHOLD:
            # Build a fresh entry so tuple-shaped detections work as well.
            line = [
                line[0],
                recognize_text_by_multilanguage(image, line, OCR),
            ]
        new_result.append(line)

    return [new_result]
24
+
25
+
26
def recognize_text_by_multilanguage(image, line, OCR):
    """
    Do recognition again on the text having low recognition score.
    args:
        image(array): image the detection was computed on
        line(list): one detection as [box, (text, score)]
        OCR: object exposing ocr(img, cls=..., det=..., rec=...)
    return(tuple):
        (text, score) of whichever recognition scored higher; the
        original pair when the box cannot be re-recognized
    """
    box = line[0]
    txt, score = line[1][0], line[1][1]

    cropped_image = crop_image(image, box)
    if cropped_image.size == 0:
        # Degenerate box: there are no pixels to recognize.
        return (txt, score)

    result = OCR.ocr(cropped_image, cls=True, det=False, rec=True)
    if not result or not result[0]:
        # The recognizer produced nothing; keep the first recognition.
        return (txt, score)

    new_txt, new_score = result[0][0][0], result[0][0][1]
    if new_score > score:
        print(f"[{score}]{txt} -----> [{new_score}]{new_txt}")
        txt, score = new_txt, new_score
    else:
        print(f"[{score}]{txt} --X--> [{new_score}]{new_txt}")

    return (txt, score)
src/settings.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import sys
3
+
4
+ from paddleocr import PaddleOCR
5
+
6
+ # check GPU
7
+ # gpu = True if torch.cuda.is_available() else False
8
+
9
+ RECOGNITION_THRESHOLD = 0.90
10
+ FONTPATH = "fonts/NotoSerifJP-SemiBold.otf"
11
+ """ Models list
12
+ # Detection:
13
+ detection_ch_PP-OCRv3: Original lightweight model,
14
+ supporting Chinese, English, multilingual text detection
15
+ detection_ml_PP-OCRv3: Original lightweight detection model,
16
+ supporting English, multilingual text detection
17
+
18
+ # Recognition:
19
+ recognition_ch_PP-OCRv3: [New] Original lightweight model,
20
+ supporting Chinese, English, multilingual text recognition.
21
+ No dict.
22
+ recognition_japan_PP-OCRv3_rec: Lightweight model for Japanese recognition
23
+ Need dict: ppocr/utils/dict/japan_dict.txt
24
+
25
+ """
26
+
27
+
28
+ OCR_JA = PaddleOCR(
29
+ det_model_dir="models/detection/ch_PP-OCRv4_server_det/",
30
+ rec_model_dir="models/recognition/japan_PP-OCRv4_rec/",
31
+ rec_char_dict_path="models/char_dict/japan_dict.txt",
32
+ cls_model_dir="models/cls/ch_ppocr_mobile_v2.0_cls",
33
+ use_angle_cls=True,
34
+ )
35
+
36
+ OCR_ML = PaddleOCR(
37
+ det_model_dir="models/detection/ch_PP-OCRv4_server_det/",
38
+ rec_model_dir="models/recognition/ch_PP-OCRv4_rec/",
39
+ # rec_char_dict_path="models/char_dict/japan_dict.txt",
40
+ cls_model_dir="models/cls/ch_ppocr_mobile_v2.0_cls",
41
+ use_angle_cls=True,
42
+ )
43
+
44
+ # ocr = PaddleOCR(use_angle_cls=True, lang='japan')
45
+
46
+ """Define logging"""
47
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
48
+ # logger_value = logging.getLogger("Value")
49
+ # logger_value.disabled()
50
+ # logger_date = logging.getLogger("Date")
51
+ # logger_date.disable()
52
+ """
53
+ PaddleOCR logger
54
+ ppstrcuture//table/predict_table.py
55
+ dt_boxes num: line 81
56
+ rec_res num: line 93
57
+ tools/infer/predict_system.py
58
+ dt_boxes num: line 71
59
+ rec_res num: line 90
60
+ """
src/utilities.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def crop_image(image, bbox):
    """
    crop the text area (got from text detection) to recognize
    args:
        image(array): image indexed as image[row, column]
        bbox(list): four (x, y) corner points; the code reads them as
            top-left, top-right, bottom-right, bottom-left
    return(array):
        cropped image
    """
    # Clamp to 0: a negative coordinate would otherwise be taken as an
    # index counted from the end and yield an empty or wrong crop.
    top = max(0, int(min(bbox[0][1], bbox[1][1])))
    bot = max(0, int(max(bbox[2][1], bbox[3][1])))
    left = max(0, int(min(bbox[0][0], bbox[3][0])))
    right = max(0, int(max(bbox[1][0], bbox[2][0])))
    return image[top:bot, left:right]
src/visualization.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import string
3
+
4
+ import cv2
5
+ import numpy as np
6
+ from PIL import (
7
+ Image,
8
+ ImageDraw,
9
+ ImageFont,
10
+ )
11
+
12
+ from src.settings import FONTPATH
13
+
14
+
15
def count_characters(str: str):
    """
    Count the display length of a string in full-width units.

    ASCII letters, digits and whitespace are treated as half-width: the
    rounded-up half of their count is subtracted from the total length.
    Every other character counts as one unit.
    args:
        str(string): the input string
    return(int):
        the length in full-width units
    """
    half_width = sum(
        1
        for c in str
        if c in string.ascii_letters or c.isdigit() or c.isspace()
    )
    return len(str) - math.ceil(half_width / 2)
37
+
38
+
39
def create_blank_img(img_h, img_w):
    """
    create new blank img
    args:
        img_h(int): the height of blank img
        img_w(int): the width of blank img
    return(tuple):
        the white RGB image, whose last column is black, and the
        ImageDraw object bound to it
    """
    # uint8 holds 255; the former int8 buffer overflowed on ``* 255``.
    blank_img = np.full((img_h, img_w), 255, dtype=np.uint8)
    blank_img[:, img_w - 1 :] = 0
    blank_img = Image.fromarray(blank_img).convert("RGB")
    draw_txt = ImageDraw.Draw(blank_img)
    return blank_img, draw_txt
53
+
54
+
55
def text_visual(
    texts,
    scores,
    img_h=400,
    img_w=600,
    threshold=0.0,
    font_path=FONTPATH,
):
    """
    create new img with recognized text
    args:
        texts(list): the text will be draw
        scores(list|None): corresponding score of each txt; with None
            nothing is filtered and no score is printed
        img_h(int): the height of blank img
        img_w(int): the width of blank img
        threshold(float): texts scoring below it (or NaN) are skipped
        font_path: the path of font which is used to draw text
    return(array): image with recognized text; when the text does not
        fit on one page the pages are joined side by side
    """
    if scores is not None:
        assert len(texts) == len(
            scores,
        ), "The number of txts and corresponding scores must match"

    # create_blank_img needs the page size; it was called without it.
    blank_img, draw_txt = create_blank_img(img_h, img_w)

    font_size = 20
    txt_color = (0, 0, 0)
    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")

    gap = font_size + 5
    # At least 1, otherwise the wrapping loop below never consumes text.
    chars_per_line = max(1, img_w // font_size - 4)
    last_row = img_h // gap - 1
    txt_img_list = []
    count, index = 1, 0
    for idx, txt in enumerate(texts):
        score = None if scores is None else scores[idx]
        if score is not None and (score < threshold or math.isnan(score)):
            continue
        index += 1
        first_line = True
        # Wrap long text: draw one full row per iteration.
        while count_characters(txt) >= chars_per_line:
            tmp = txt
            txt = tmp[:chars_per_line]
            if first_line:
                new_txt = str(index) + ": " + txt
                first_line = False
            else:
                new_txt = " " + txt
            draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
            txt = tmp[chars_per_line:]
            if count >= last_row:
                txt_img_list.append(np.array(blank_img))
                blank_img, draw_txt = create_blank_img(img_h, img_w)
                count = 0
            count += 1
        score_txt = "" if score is None else " " + "%.3f" % (score)
        if first_line:
            new_txt = str(index) + ": " + txt + score_txt
        else:
            new_txt = " " + txt + score_txt
        draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
        # whether add new blank img or not
        if count >= last_row and idx + 1 < len(texts):
            txt_img_list.append(np.array(blank_img))
            blank_img, draw_txt = create_blank_img(img_h, img_w)
            count = 0
        count += 1
    txt_img_list.append(np.array(blank_img))
    if len(txt_img_list) == 1:
        blank_img = np.array(txt_img_list[0])
    else:
        blank_img = np.concatenate(txt_img_list, axis=1)
    return np.array(blank_img)
125
+
126
+
127
def resize_img(img, input_size=600):
    """
    Scale an image so that its longest side becomes input_size.
    args:
        img(np.array): original image
        input_size(int): new size of the longest side of the image
    return(array):
        the rescaled image, both axes scaled by the same factor
    """
    pixels = np.array(img)
    longest_side = np.max(pixels.shape[0:2])
    scale = float(input_size) / float(longest_side)
    return cv2.resize(pixels, None, None, fx=scale, fy=scale)
142
+
143
+
144
def draw_ocr(
    image,
    boxes,
    txts=None,
    scores=None,
    drop_score=0.0,
    font_path=FONTPATH,
):
    """
    Draw the detected boxes and, optionally, a panel with the texts.
    args:
        image(Image|array): RGB image
        boxes(list): boxes with shape(N, 4, 2)
        txts(list): the texts
        scores(list): the scores matching the texts
        drop_score(float): boxes scoring below it (or NaN) are not drawn
        font_path: the path of font which is used to draw text
    return(array):
        the image with boxes; when txts is given, the text panel built
        by text_visual is appended on its right
    """
    if scores is None:
        scores = [1] * len(boxes)

    for idx, raw_box in enumerate(boxes):
        score = scores[idx]
        if score < drop_score or math.isnan(score):
            continue
        polygon = np.reshape(np.array(raw_box), [-1, 1, 2]).astype(np.int64)
        image = cv2.polylines(
            np.array(image),
            [polygon],
            True,
            (255, 0, 0),
            2,
        )

    if txts is None:
        return image

    canvas = np.array(image)
    txt_img = text_visual(
        txts,
        scores,
        img_h=canvas.shape[0],
        img_w=600,
        threshold=drop_score,
        font_path=font_path,
    )
    return np.concatenate([np.array(canvas), np.array(txt_img)], axis=1)
187
+
188
+
189
def draw_ocr_2(img, results):
    """
    Visualize the results of OCR detection and recognition
    args:
        img(Image|array): RGB image
        results(list): detections, each being [box, (text, score)] with
            box made of four (x, y) points
    return(tuple):
        the image with the boxes drawn on it, and a white image of the
        same size with each text drawn at its box position
    """
    img = np.asarray(img)
    img_text = np.ones((img.shape[0], img.shape[1], 3), np.uint8) * 255

    for line in results:
        text = line[1][0]

        top = int(min(line[0][0][1], line[0][1][1]))
        left = int(min(line[0][0][0], line[0][3][0]))
        right = int(max(line[0][1][0], line[0][2][0]))

        # Plain ints: the same colour is used for the box and its text.
        color = tuple(int(c) for c in np.random.randint(0, 255, size=3))

        box = np.reshape(np.array(line[0]), [-1, 1, 2]).astype(np.int64)
        img = cv2.polylines(np.array(img), [box], True, color, 2)

        if not text:
            # Nothing to write, and len(text) == 0 would divide by zero.
            continue
        # Fit the text to the box width; the font size must be >= 1.
        text_size = max(1, int((right - left) / len(text)))
        img_text = place_text(img_text, text, (left, top), text_size, color)

    return img, img_text
224
+
225
+
226
def place_text(img, text, top_left_point, text_size, text_color):
    """
    Draw a text on a copy of an image.
    args:
        img(array): image to draw on
        text(str): text to be put
        top_left_point(tuple): top-left point where the text starts
        text_size(int): font size used for the text
        text_color(tuple): fill colour of the text
    return(array):
        a new array holding the image with the text drawn on it
    """
    text_font = ImageFont.truetype(FONTPATH, text_size, encoding="utf-8")
    canvas = Image.fromarray(img)
    ImageDraw.Draw(canvas).text(
        top_left_point,
        text,
        fill=text_color,
        font=text_font,
    )
    return np.array(canvas)
242
+
243
+
244
def visualize_result(result, image, out_path=None):
    """
    make visualization in image format
    args:
        result(list): one-element list wrapping the detections, each
            being [box, (text, score)]
        image(array|str): the input image, or the path to it
        out_path(str|None): when given, the box image and the text image
            are joined side by side and saved to this path
    return(tuple):
        the image with boxes and the image with the recognized text
    """
    result = result[0]
    # The type check has to happen before the conversion; it used to run
    # on the already converted image and could never be true.
    if isinstance(image, str):
        image = Image.open(image).convert("RGB")
    else:
        image = Image.fromarray(image)

    img_boxes, img_text = draw_ocr_2(image, result)
    if out_path is not None:
        img_combination = np.concatenate(
            [np.array(img_boxes), np.array(img_text)],
            axis=1,
        )
        Image.fromarray(img_combination).save(out_path)

    return img_boxes, img_text