Spaces:
Build error
Build error
Commit ·
7ee7e3a
1
Parent(s): 29496ec
first commit
Browse files- .gitignore +109 -0
- app.py +39 -0
- requirements.txt +11 -0
- src/app/__init__.py +0 -0
- src/app/crnn.py +90 -0
- src/app/decoder.py +148 -0
- src/app/main.py +28 -0
- src/app/text_recognition.py +161 -0
- src/utils/__init__.py +0 -0
- src/utils/utils.py +73 -0
.gitignore
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
env/
|
| 12 |
+
build/
|
| 13 |
+
develop-eggs/
|
| 14 |
+
dist/
|
| 15 |
+
downloads/
|
| 16 |
+
eggs/
|
| 17 |
+
.eggs/
|
| 18 |
+
lib/
|
| 19 |
+
lib64/
|
| 20 |
+
parts/
|
| 21 |
+
sdist/
|
| 22 |
+
var/
|
| 23 |
+
wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
|
| 28 |
+
# PyInstaller
|
| 29 |
+
# Usually these files are written by a python script from a template
|
| 30 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 31 |
+
*.manifest
|
| 32 |
+
*.spec
|
| 33 |
+
|
| 34 |
+
# Installer logs
|
| 35 |
+
pip-log.txt
|
| 36 |
+
pip-delete-this-directory.txt
|
| 37 |
+
|
| 38 |
+
# Unit test / coverage reports
|
| 39 |
+
htmlcov/
|
| 40 |
+
.tox/
|
| 41 |
+
.coverage
|
| 42 |
+
.coverage.*
|
| 43 |
+
.cache
|
| 44 |
+
nosetests.xml
|
| 45 |
+
coverage.xml
|
| 46 |
+
*.cover
|
| 47 |
+
.hypothesis/
|
| 48 |
+
|
| 49 |
+
# Translations
|
| 50 |
+
*.mo
|
| 51 |
+
*.pot
|
| 52 |
+
|
| 53 |
+
# Django stuff:
|
| 54 |
+
*.log
|
| 55 |
+
local_settings.py
|
| 56 |
+
|
| 57 |
+
# Flask stuff:
|
| 58 |
+
instance/
|
| 59 |
+
.webassets-cache
|
| 60 |
+
|
| 61 |
+
# Scrapy stuff:
|
| 62 |
+
.scrapy
|
| 63 |
+
|
| 64 |
+
# Sphinx documentation
|
| 65 |
+
docs/_build/
|
| 66 |
+
|
| 67 |
+
# PyBuilder
|
| 68 |
+
target/
|
| 69 |
+
|
| 70 |
+
# Jupyter Notebook
|
| 71 |
+
.ipynb_checkpoints
|
| 72 |
+
|
| 73 |
+
# pyenv
|
| 74 |
+
.python-version
|
| 75 |
+
|
| 76 |
+
# celery beat schedule file
|
| 77 |
+
celerybeat-schedule
|
| 78 |
+
|
| 79 |
+
# SageMath parsed files
|
| 80 |
+
*.sage.py
|
| 81 |
+
|
| 82 |
+
# dotenv
|
| 83 |
+
.env
|
| 84 |
+
|
| 85 |
+
# virtualenv
|
| 86 |
+
.venv
|
| 87 |
+
venv/
|
| 88 |
+
ENV/
|
| 89 |
+
|
| 90 |
+
# Spyder project settings
|
| 91 |
+
.spyderproject
|
| 92 |
+
.spyproject
|
| 93 |
+
|
| 94 |
+
# Rope project settings
|
| 95 |
+
.ropeproject
|
| 96 |
+
|
| 97 |
+
# mkdocs documentation
|
| 98 |
+
/site
|
| 99 |
+
|
| 100 |
+
# mypy
|
| 101 |
+
.mypy_cache/
|
| 102 |
+
|
| 103 |
+
# macOS
|
| 104 |
+
*.DS_Store
|
| 105 |
+
|
| 106 |
+
# IDEs
|
| 107 |
+
.vscode/
|
| 108 |
+
.vs/
|
| 109 |
+
.idea/
|
app.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
import string
import gradio as gr

from src.app.text_recognition import TextRecognition

# Weights are cached under the user's home directory; they are fetched
# from the release URL below on first run.
root_path = os.path.expanduser('~/.Halotec/Models')

# Hyper-parameters must match the published checkpoint exactly.
model_config = {
    'filename'  : 'crnn_008000.pt',
    'classes'   : string.digits + string.ascii_uppercase + '. ',
    'url'       : 'https://github.com/Alimustoofaa/Research-OCR-License-Plate/releases/download/crnn/crnn_008000.pt',
    'file_size' : 31379595,
    'img_height': 32,
    'img_width' : 100,
    'map_to_seq_hidden': 64,
    'rnn_hidden': 256,
    'leaky_relu': False,
}
model_ocr = TextRecognition(root_path, model_config, jic=True)


def recognition(image):
    '''Run OCR on one uploaded image; return (text, confidence) for Gradio.'''
    result = model_ocr.recognition(image, decode='beam_search', beam_size=10)
    return result['text'], result['confidence']


title = "OCR License Plate Indonesia"
css = ".image-preview {height: auto !important;}"

iface = gr.Interface(
    title=title,
    fn=recognition,
    inputs=[gr.Image()],
    outputs=['text', 'text'],
    css=css,
)

iface.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
docopt==0.6.2
|
| 2 |
+
numpy==1.17.2
|
| 3 |
+
opencv-python==4.5.1.48
|
| 4 |
+
pillow==6.1.0
|
| 5 |
+
scipy==1.5.2
|
| 6 |
+
six==1.12.0
|
| 7 |
+
#torch==1.2.0
|
| 8 |
+
tqdm==4.49.0
|
| 9 |
+
torch
|
| 10 |
+
torchvision
|
| 11 |
+
torchaudio
|
src/app/__init__.py
ADDED
|
File without changes
|
src/app/crnn.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class CRNN(nn.Module):
    """CRNN text recognizer: a VGG-style CNN feature extractor followed by
    two stacked bidirectional LSTMs and a per-timestep linear classifier.

    Input:  (batch, img_channel, img_height, img_width)
    Output: (seq_len, batch, num_class) un-normalized scores for CTC decoding.
    """

    # Per-stage convolution hyper-parameters, indexed by stage number.
    _OUT_CHANNELS = (64, 128, 256, 256, 512, 512, 512)
    _KERNELS = (3, 3, 3, 3, 3, 3, 2)
    _PADDINGS = (1, 1, 1, 1, 1, 1, 0)

    def __init__(self, img_channel, img_height, img_width, num_class,
                 map_to_seq_hidden=64, rnn_hidden=256, leaky_relu=False):
        super().__init__()

        self.cnn, (feat_channel, feat_height, _feat_width) = \
            self._cnn_backbone(img_channel, img_height, img_width, leaky_relu)

        # Each horizontal CNN column (channel x height) becomes one sequence step.
        self.map_to_seq = nn.Linear(feat_channel * feat_height, map_to_seq_hidden)

        self.rnn1 = nn.LSTM(map_to_seq_hidden, rnn_hidden, bidirectional=True)
        self.rnn2 = nn.LSTM(2 * rnn_hidden, rnn_hidden, bidirectional=True)

        self.dense = nn.Linear(2 * rnn_hidden, num_class)

    def _cnn_backbone(self, img_channel, img_height, img_width, leaky_relu):
        # Height is divided by 16 (and reduced by 1 by the last conv);
        # width is divided by 4 (and reduced by 1 by the last conv).
        assert img_height % 16 == 0
        assert img_width % 4 == 0

        widths = (img_channel,) + self._OUT_CHANNELS
        cnn = nn.Sequential()

        def add_stage(i, batch_norm=False):
            # conv -> (optional batchnorm) -> relu.  Module names must stay
            # 'conv{i}' / 'batchnorm{i}' / 'relu{i}' so existing checkpoints
            # keep loading via load_state_dict.
            cnn.add_module(
                f'conv{i}',
                nn.Conv2d(widths[i], widths[i + 1],
                          self._KERNELS[i], 1, self._PADDINGS[i])
            )
            if batch_norm:
                cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(widths[i + 1]))
            activation = (nn.LeakyReLU(0.2, inplace=True) if leaky_relu
                          else nn.ReLU(inplace=True))
            cnn.add_module(f'relu{i}', activation)

        add_stage(0)
        cnn.add_module('pooling0', nn.MaxPool2d(kernel_size=2, stride=2))
        # (64, img_height // 2, img_width // 2)

        add_stage(1)
        cnn.add_module('pooling1', nn.MaxPool2d(kernel_size=2, stride=2))
        # (128, img_height // 4, img_width // 4)

        add_stage(2)
        add_stage(3)
        # Pool only along the height axis so the sequence length is preserved.
        cnn.add_module('pooling2', nn.MaxPool2d(kernel_size=(2, 1)))
        # (256, img_height // 8, img_width // 4)

        add_stage(4, batch_norm=True)
        add_stage(5, batch_norm=True)
        cnn.add_module('pooling3', nn.MaxPool2d(kernel_size=(2, 1)))
        # (512, img_height // 16, img_width // 4)

        add_stage(6)  # (512, img_height // 16 - 1, img_width // 4 - 1)

        out_shape = (widths[-1], img_height // 16 - 1, img_width // 4 - 1)
        return cnn, out_shape

    def forward(self, images):
        """Map a batch of images to per-timestep class scores."""
        features = self.cnn(images)  # (batch, channel, height, width)
        batch, channel, height, width = features.size()

        # Fold (channel, height) together and make width the sequence axis.
        features = features.view(batch, channel * height, width)
        features = features.permute(2, 0, 1)  # (width, batch, feature)
        seq = self.map_to_seq(features)

        hidden, _ = self.rnn1(seq)
        hidden, _ = self.rnn2(hidden)

        return self.dense(hidden)  # (seq_len, batch, num_class)
|
src/app/decoder.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
from scipy.special import logsumexp # log(p1 + p2) = logsumexp([log_p1, log_p2])
|
| 6 |
+
|
| 7 |
+
NINF = -1 * float('inf')
|
| 8 |
+
DEFAULT_EMISSION_THRESHOLD = 0.01
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _reconstruct(labels, blank=0):
|
| 12 |
+
new_labels = []
|
| 13 |
+
# merge same labels
|
| 14 |
+
previous = None
|
| 15 |
+
for l in labels:
|
| 16 |
+
if l != previous:
|
| 17 |
+
new_labels.append(l)
|
| 18 |
+
previous = l
|
| 19 |
+
# delete blank
|
| 20 |
+
new_labels = [l for l in new_labels if l != blank]
|
| 21 |
+
|
| 22 |
+
return new_labels
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def greedy_decode(emission_log_prob, blank=0, **kwargs):
    '''
    Best-path decoding: take the most likely class at each time step,
    then collapse the path with the CTC rule.
    '''
    best_path = np.argmax(emission_log_prob, axis=-1)
    return _reconstruct(best_path, blank=blank)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def beam_search_decode(emission_log_prob, blank=0, **kwargs):
    '''
    Plain beam search over per-step emissions.
    @params:
        - emission_log_prob: np.ndarray (length, class) of log-probabilities
        - blank: index of the CTC blank label
        - beam_size: int (required keyword)
        - emission_threshold: optional float; classes whose log-prob falls
          below it are pruned (default log(0.01))
    @return: list of int labels (best collapsed sequence)
    '''
    beam_size = kwargs['beam_size']
    emission_threshold = kwargs.get('emission_threshold', np.log(DEFAULT_EMISSION_THRESHOLD))

    length, class_count = emission_log_prob.shape

    beams = [([], 0)]  # (prefix, accumulated_log_prob)
    for t in range(length):
        new_beams = []
        for prefix, accumulated_log_prob in beams:
            for c in range(class_count):
                log_prob = emission_log_prob[t, c]
                if log_prob < emission_threshold:
                    continue
                # log(p1 * p2) = log_p1 + log_p2
                new_beams.append((prefix + [c], accumulated_log_prob + log_prob))

        # Bug fix: when every class at step t falls below the threshold the
        # original emptied the beam set, which later crashed with an
        # IndexError on labels_beams[0]. Carry the previous beams forward
        # instead of dropping everything.
        if not new_beams:
            continue

        # keep the beam_size most probable prefixes
        new_beams.sort(key=lambda x: x[1], reverse=True)
        beams = new_beams[:beam_size]

    # Merge beams whose paths collapse to the same label sequence.
    total_accu_log_prob = {}
    for prefix, accu_log_prob in beams:
        labels = tuple(_reconstruct(prefix, blank))
        # log(p1 + p2) = logsumexp([log_p1, log_p2])
        total_accu_log_prob[labels] = \
            logsumexp([accu_log_prob, total_accu_log_prob.get(labels, NINF)])

    best_labels, _ = max(total_accu_log_prob.items(), key=lambda x: x[1])
    return list(best_labels)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def prefix_beam_decode(emission_log_prob, blank=0, **kwargs):
    '''
    CTC prefix beam search.

    Unlike plain beam search over paths, this searches over *collapsed*
    prefixes, tracking for each prefix two probabilities: ending in a
    blank (lp_b) and ending in a non-blank (lp_nb). All values are kept
    in log space.

    @params:
        - emission_log_prob: np.ndarray (length, class) of log-probabilities
        - blank: index of the CTC blank label
        - beam_size: int (required keyword)
        - emission_threshold: optional float; classes below it are pruned
    @return: list of int labels (best prefix)
    '''
    beam_size = kwargs['beam_size']
    emission_threshold = kwargs.get('emission_threshold', np.log(DEFAULT_EMISSION_THRESHOLD))

    length, class_count = emission_log_prob.shape

    beams = [(tuple(), (0, NINF))]  # (prefix, (blank_log_prob, non_blank_log_prob))
    # initial of beams: (empty_str, (log(1.0), log(0.0)))

    for t in range(length):
        new_beams_dict = defaultdict(lambda: (NINF, NINF))  # log(0.0) = NINF

        for prefix, (lp_b, lp_nb) in beams:
            for c in range(class_count):
                log_prob = emission_log_prob[t, c]
                if log_prob < emission_threshold:
                    continue

                # last emitted (non-blank) label of this prefix, if any
                end_t = prefix[-1] if prefix else None

                # if new_prefix == prefix
                new_lp_b, new_lp_nb = new_beams_dict[prefix]

                if c == blank:
                    # blank keeps the prefix unchanged and ends it in a blank
                    new_beams_dict[prefix] = (
                        logsumexp([new_lp_b, lp_b + log_prob, lp_nb + log_prob]),
                        new_lp_nb
                    )
                    continue
                if c == end_t:
                    # repeating the last label without an intervening blank
                    # collapses into the same prefix (non-blank ending)
                    new_beams_dict[prefix] = (
                        new_lp_b,
                        logsumexp([new_lp_nb, lp_nb + log_prob])
                    )

                # if new_prefix == prefix + (c,)
                new_prefix = prefix + (c,)
                new_lp_b, new_lp_nb = new_beams_dict[new_prefix]

                if c != end_t:
                    new_beams_dict[new_prefix] = (
                        new_lp_b,
                        logsumexp([new_lp_nb, lp_b + log_prob, lp_nb + log_prob])
                    )
                else:
                    # same label as the prefix end: it only extends the prefix
                    # when the previous path ended in a blank
                    new_beams_dict[new_prefix] = (
                        new_lp_b,
                        logsumexp([new_lp_nb, lp_b + log_prob])
                    )

        # sorted by log(blank_prob + non_blank_prob)
        beams = sorted(new_beams_dict.items(), key=lambda x: logsumexp(x[1]), reverse=True)
        beams = beams[:beam_size]

    # NOTE(review): if every class at some step falls below the threshold,
    # new_beams_dict stays empty and beams[0] below raises IndexError —
    # same latent issue as beam_search_decode; confirm with real emissions.
    labels = list(beams[0][0])
    return labels
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def ctc_decode(log_probs, label2char=None, blank=0, method='beam_search', beam_size=10):
    '''
    Decode a batch of CTC emission log-probabilities into label sequences.

    @params:
        - log_probs: torch.Tensor of shape (seq_len, batch, class)
        - label2char: optional mapping from int label to character
        - blank: index of the CTC blank label
        - method: 'greedy' | 'beam_search' | 'prefix_beam_search'
        - beam_size: beam width for the beam-search decoders
    @return:
        - list (one entry per batch item) of label lists, or of character
          lists when label2char is given
    '''
    # Bug fix: the original tried .cpu().numpy() and fell back to
    # .detach().numpy() on RuntimeError — but a CUDA tensor raises
    # TypeError from .numpy(), so the fallback never fired on GPU.
    # detach() + cpu() handles autograd tensors and CUDA tensors alike.
    emission_log_probs = np.transpose(log_probs.detach().cpu().numpy(), (1, 0, 2))
    # size of emission_log_probs: (batch, length, class)

    decoders = {
        'greedy': greedy_decode,
        'beam_search': beam_search_decode,
        'prefix_beam_search': prefix_beam_decode,
    }
    decoder = decoders[method]

    decoded_list = []
    for emission_log_prob in emission_log_probs:
        decoded = decoder(emission_log_prob, blank=blank, beam_size=beam_size)
        if label2char:
            decoded = [label2char[l] for l in decoded]
        decoded_list.append(decoded)
    return decoded_list
|
src/app/main.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from text_recognition import TextRecognition

if __name__ == '__main__':
    import os
    import cv2
    import time
    import string

    # local directory where the checkpoint is expected to live
    root_path = os.path.expanduser('~/.Halotec/Models')

    # url=None: the weight file must already be on disk
    model_config = {
        'filename'  : 'crnn_008000.pt',
        'classes'   : string.digits + string.ascii_uppercase + '. ',
        'url'       : None,
        'file_size' : 592694,
        'img_height': 32,
        'img_width' : 100,
        'map_to_seq_hidden': 64,
        'rnn_hidden': 256,
        'leaky_relu': False,
    }

    text_recognition = TextRecognition(root_path, model_config, jic=True)
    image = cv2.imread('./images/12022041113414598_14.jpg')

    # time ten consecutive recognitions as a rough throughput check
    start = time.time()
    for _ in range(10):
        result = text_recognition.recognition(image, decode='beam_search', beam_size=10)
        print(result)
    print(time.time() - start)
|
src/app/text_recognition.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
@Author : Ali Mustofa HALOTEC
|
| 3 |
+
@Module : Character Recognition Neural Network
|
| 4 |
+
@Created on : 2 Agust 2022
|
| 5 |
+
"""
|
| 6 |
+
#!/usr/bin/env python3
|
| 7 |
+
# Path: src/apps/char_recognition.py
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import cv2
|
| 11 |
+
import sys
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from .crnn import CRNN
|
| 18 |
+
from .decoder import ctc_decode
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
from src.utils.utils import download_and_unzip_model
|
| 22 |
+
except ImportError:
|
| 23 |
+
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 24 |
+
sys.path.append(os.path.dirname(SCRIPT_DIR))
|
| 25 |
+
from utils.utils import download_and_unzip_model
|
| 26 |
+
|
| 27 |
+
class TextRecognition:
    '''
    CRNN-based text recognizer.

    Ensures the weight file exists (downloading it if needed), loads the
    CRNN onto the best available device, optionally JIT-traces it, and
    exposes `recognition` to read text from a BGR image.
    '''

    def __init__(self, root_path:str, model_config:dict, jic: bool=True) -> None:
        '''
        @params:
            - root_path: directory holding (or receiving) the weight file
            - model_config: dict with keys 'filename', 'classes', 'url',
              'file_size', 'img_height', 'img_width', 'map_to_seq_hidden',
              'rnn_hidden', 'leaky_relu'
            - jic: if True, JIT-trace the model after loading
        '''
        self.jic = jic
        self.root_path = root_path
        self.model_config = model_config
        self.model_name = f'{root_path}/{model_config["filename"]}'
        # label 0 is reserved for the CTC blank, so real classes start at 1
        self.classes = {i+1: v for i, v in enumerate(model_config['classes'])}
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.model = self.__load_model()
        if jic:
            # Bug fix: trace with an example on the model's device and with
            # the configured input size; the original always used a CPU
            # 1x1x32x100 tensor, which fails for CUDA models.
            self.model = self.__jic_trace(
                self.model, self.device,
                model_config['img_height'], model_config['img_width'])

    @staticmethod
    def __crnn_model(config) -> nn.Module:
        '''Build a CRNN whose hyper-parameters match the checkpoint.'''
        model = CRNN(
            img_channel = 1,
            img_height = config['img_height'],
            img_width = config['img_width'],
            num_class = len(config['classes']) + 1,  # +1 for the CTC blank
            map_to_seq_hidden = config['map_to_seq_hidden'],
            rnn_hidden = config['rnn_hidden'],
            leaky_relu = config['leaky_relu']
        )
        return model

    @staticmethod
    def __jic_trace(model:nn.Module, device=None, height: int=32, width: int=100) -> torch.jit.TracedModule:
        '''
        JIT-trace the model with a dummy grayscale input.
        @params:
            - model: nn.Module already moved to `device`
            - device: device to create the example input on (None = CPU)
            - height, width: example input size; must match the real inputs
        '''
        example = torch.rand(1, 1, height, width, device=device)
        return torch.jit.trace(model, example)

    @staticmethod
    def __check_model(root_path:str, model_config:dict) -> None:
        '''Download the weight file when it is not on disk yet.'''
        if not os.path.isfile(f'{root_path}/{model_config["filename"]}'):
            download_and_unzip_model(
                root_dir = root_path,
                name = model_config['filename'],
                url = model_config['url'],
                file_size = model_config['file_size'],
                unzip = False
            )
        else:
            print('Load model ...')

    def __load_model(self) -> nn.Module:
        '''
        Load the checkpoint from disk onto self.device.
        @return:
            - model: nn.Module in eval mode
        '''
        self.__check_model(self.root_path, self.model_config)
        model = self.__crnn_model(self.model_config)
        model.load_state_dict(torch.load(self.model_name, map_location=self.device))
        model.to(self.device)
        return model.eval()

    @staticmethod
    def __image_transform(image:np.ndarray, height: int=32, width: int=100) -> torch.Tensor:
        '''
        Convert a BGR image to the normalized grayscale tensor the CRNN expects.
        @params:
            - image: np.ndarray (BGR, as returned by cv2.imread)
        @return:
            - torch.Tensor of shape (1, 1, height, width), values in [-1, 1]
        '''
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        image = cv2.resize(image, (width, height))
        image = image.reshape(1, height, width)
        image = (image / 127.5) - 1.0  # map [0, 255] -> [-1, 1]
        image = torch.FloatTensor(image)
        return image.unsqueeze(0)

    def recognition(
        self,
        image: np.ndarray,
        decode: str = 'beam_search',
        beam_size: int = 10
    ) -> dict:
        '''
        Recognize text in an image.
        @params:
            - image: np.ndarray (BGR)
            - decode: str -> ['beam_search', 'greedy', 'prefix_beam_search']
            - beam_size: int -> beam size for beam search
        @return:
            - result: dict -> {'text': str, 'confidence': float}
        '''
        assert decode in ['beam_search', 'greedy', 'prefix_beam_search'], 'Decode Failed'

        # Use the checkpoint's configured input size rather than the
        # transform's hard-coded defaults, and move the tensor to the
        # model's device (the original left it on the CPU, which crashes
        # when the model runs on CUDA).
        image_t = self.__image_transform(
            image,
            height=self.model_config['img_height'],
            width=self.model_config['img_width']
        ).to(self.device)

        # recognize
        with torch.no_grad():
            output = self.model(image_t)
            log_probs = F.log_softmax(output, dim=2)
        # decode
        preds = ctc_decode(
            log_probs, method=decode, beam_size=beam_size,
            blank=0, label2char=self.classes)

        # Confidence: mean over time steps of the per-step max probability.
        # detach() + cpu() works for both CPU and CUDA tensors; the original
        # 'except RuntimeError' never caught the TypeError CUDA raises.
        exps = torch.exp(log_probs)
        probs = float((torch.max(exps, dim=2)[0] / len(exps)).sum(dim=0)[0].detach().cpu())

        return {'text': ''.join(preds[0]), 'confidence': round(probs, 2)}
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
if __name__ == '__main__':
    import time
    import string

    # local directory where the checkpoint is expected to live
    root_path = os.path.expanduser('~/.Halotec/Models')

    # url=None: the weight file must already be on disk
    model_config = {
        'filename'  : 'crnn_008000.pt',
        'classes'   : string.digits + string.ascii_uppercase + '. ',
        'url'       : None,
        'file_size' : 592694,
        'img_height': 32,
        'img_width' : 100,
        'map_to_seq_hidden': 64,
        'rnn_hidden': 256,
        'leaky_relu': False,
    }

    text_recognition = TextRecognition(root_path, model_config, jic=True)
    image = cv2.imread('./images/12022041114405685_0.jpg')

    # time ten consecutive recognitions as a rough throughput check
    start = time.time()
    for _ in range(10):
        result = text_recognition.recognition(image, decode='beam_search', beam_size=10)
        print(result)
    print(time.time() - start)
|
src/utils/__init__.py
ADDED
|
File without changes
|
src/utils/utils.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import base64
|
| 4 |
+
import requests
|
| 5 |
+
import numpy as np
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from zipfile import ZipFile
|
| 9 |
+
|
| 10 |
+
def download_and_unzip_model(root_dir:str, name:str,
                             url:str, file_size:int, unzip:bool = False):
    '''
    Download a model file into root_dir, showing a byte-accurate progress
    bar, and optionally unzip it (removing the archive afterwards).
    @params:
        root_dir(str): The root directory of model.
        name(str): The name of model.
        url(str): The url of model.
        file_size(int): The expected size of the download in bytes.
        unzip(bool): Unzip the model or not.
    '''
    Path(root_dir).mkdir(parents=True, exist_ok=True)

    print(f'Downloading {root_dir.split("/")[-1]} model, please wait.')
    response = requests.get(url, stream=True)
    # Robustness fix: fail loudly on HTTP errors instead of writing an
    # error page into the weight file.
    response.raise_for_status()

    save_dir = f'{root_dir}/{name}'
    # Bug fix: the original wrapped iter_content in tqdm AND called
    # progress.update(len(data)), so the bar advanced once per chunk
    # (in the wrong unit) plus once per byte count — double counting.
    # Drive the bar manually with byte counts only.
    with tqdm(desc='Downloading model', total=file_size, unit='B',
              unit_scale=True, unit_divisor=1024) as progress:
        with open(save_dir, 'wb') as f:
            for data in response.iter_content(1024):
                f.write(data)
                progress.update(len(data))
    print(f'Done downloading {root_dir.split("/")[-1]} model.')

    # unzip model and drop the archive
    if unzip:
        with ZipFile(save_dir, 'r') as zip_obj:
            zip_obj.extractall(root_dir)
            print(f'Done unzip {root_dir.split("/")[-1]} model.')
        os.remove(save_dir)
|
| 45 |
+
|
| 46 |
+
def encode_image2string(image):
    '''Encode an image array as a base64 string of its JPEG bytes.'''
    jpeg_buffer = cv2.imencode('.jpg', image)[1]
    return base64.b64encode(jpeg_buffer.tobytes())
|
| 51 |
+
|
| 52 |
+
def decode_string2image(image_encoded):
    '''Decode a base64-encoded JPEG string back into a color image array.'''
    jpeg_bytes = base64.b64decode(image_encoded)
    pixel_buffer = np.frombuffer(jpeg_bytes, dtype=np.uint8)
    return cv2.imdecode(pixel_buffer, flags=1)
|
| 57 |
+
|
| 58 |
+
def resize_image(image, size_percent):
    '''
    Scale an image to a percentage of its original size,
    preserving the aspect ratio.
    Args:
        image(np.ndarray): The input image (as returned by cv2.imread).
        size_percent(int): Target size as a percentage of the original
            (e.g. 50 halves both dimensions, 200 doubles them).
    Returns:
        np.ndarray: The resized image.
    '''
    width = int(image.shape[1] * size_percent / 100)
    height = int(image.shape[0] * size_percent / 100)
    dim = (width, height)

    # INTER_AREA gives good quality for downscaling
    resized = cv2.resize(image, dim, interpolation = cv2.INTER_AREA)
    return resized
|