# Container-Number-OCR / src/app/text_recognition.py
# Author: Alimustoofaa — first commit 7ee7e3a
"""
@Author : Ali Mustofa HALOTEC
@Module : Character Recognition Neural Network
@Created on : 2 August 2022
"""
#!/usr/bin/env python3
# Path: src/app/text_recognition.py
import os
import cv2
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .crnn import CRNN
from .decoder import ctc_decode
try:
from src.utils.utils import download_and_unzip_model
except ImportError:
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))
from utils.utils import download_and_unzip_model
class TextRecognition:
    '''
    CRNN-based text recognizer for cropped text images.

    Loads CRNN weights from ``root_path/model_config['filename']`` (calling
    ``download_and_unzip_model`` first when that file is missing), optionally
    replaces the model with a ``torch.jit.trace`` version, and exposes
    ``recognition`` which returns the decoded text and a confidence score.
    '''

    # Decoding strategies accepted by ``recognition`` and forwarded to ``ctc_decode``.
    _DECODE_METHODS = ('beam_search', 'greedy', 'prefix_beam_search')

    def __init__(self, root_path: str, model_config: dict, jic: bool = True) -> None:
        '''
        Build the recognizer and load its weights.
        @params:
            - root_path: directory holding (or receiving) the weight file
            - model_config: dict read for 'filename', 'url', 'file_size',
              'classes', 'img_height', 'img_width', 'map_to_seq_hidden',
              'rnn_hidden' and 'leaky_relu'
            - jic: when True, the loaded model is replaced by a traced module
        '''
        self.jic = jic
        self.root_path = root_path
        self.model_config = model_config
        self.model_name = f'{root_path}/{model_config["filename"]}'
        # Label 0 is the CTC blank (see ``blank=0`` in ``recognition``), so real
        # characters are numbered from 1.
        self.classes = {i + 1: v for i, v in enumerate(model_config['classes'])}
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.model = self.__load_model()
        if jic:
            # The example input must match the configured input size and live on
            # the same device as the weights, otherwise tracing fails on GPU.
            self.model = self.__jic_trace(
                self.model,
                height=model_config['img_height'],
                width=model_config['img_width'],
                device=self.device,
            )

    @staticmethod
    def __crnn_model(config) -> nn.Module:
        '''
        Instantiate an (untrained) single-channel CRNN from ``config``.
        @params:
            - config: model configuration dict
        @return:
            - model: nn.Module with ``len(config['classes']) + 1`` outputs
              (the extra class is the CTC blank)
        '''
        model = CRNN(
            img_channel=1,
            img_height=config['img_height'],
            img_width=config['img_width'],
            num_class=len(config['classes']) + 1,
            map_to_seq_hidden=config['map_to_seq_hidden'],
            rnn_hidden=config['rnn_hidden'],
            leaky_relu=config['leaky_relu']
        )
        return model

    @staticmethod
    def __jic_trace(model: nn.Module, height: int = 32, width: int = 100,
                    device=None) -> torch.jit.TracedModule:
        '''
        JIT-trace ``model`` with a random 1x1xHxW example input.
        @params:
            - model: nn.Module
            - height, width: spatial size of the example input (defaults are
              the previously hard-coded 32x100)
            - device: device for the example input; CPU when None
        @return:
            - traced module
        '''
        if device is None:
            device = torch.device('cpu')
        example = torch.rand(1, 1, height, width, device=device)
        return torch.jit.trace(model, example)

    @staticmethod
    def __check_model(root_path: str, model_config: dict) -> None:
        '''
        Download the weight file when it is not present in ``root_path``.
        @params:
            - root_path: directory expected to hold the weight file
            - model_config: dict read for 'filename', 'url' and 'file_size'
        '''
        if not os.path.isfile(f'{root_path}/{model_config["filename"]}'):
            download_and_unzip_model(
                root_dir=root_path,
                name=model_config['filename'],
                url=model_config['url'],
                file_size=model_config['file_size'],
                unzip=False
            )
        else:
            print('Load model ...')

    def __load_model(self) -> nn.Module:
        '''
        Load the CRNN weights from ``self.model_name`` onto ``self.device``.
        @return:
            - model: nn.Module in eval mode
        '''
        self.__check_model(self.root_path, self.model_config)
        model = self.__crnn_model(self.model_config)
        model.load_state_dict(torch.load(self.model_name, map_location=self.device))
        model.to(self.device)
        return model.eval()

    @staticmethod
    def __image_transform(image: np.ndarray, height: int = 32, width: int = 100) -> torch.Tensor:
        '''
        Convert an image into a model input tensor.
        Multi-channel images are converted with BGR->GRAY (single-channel
        images are used as-is), resized to ``width`` x ``height`` and mapped
        with x / 127.5 - 1 (8-bit [0, 255] becomes [-1, 1]).
        @params:
            - image: np.ndarray, BGR or grayscale
            - height, width: target size
        @return:
            - image: float32 torch.Tensor of shape (1, 1, height, width)
        @raises:
            - ValueError: ``image`` is None (e.g. a failed ``cv2.imread``)
        '''
        if image is None:
            raise ValueError('image is None; was it read successfully?')
        if image.ndim == 3 and image.shape[2] > 1:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        image = cv2.resize(image, (width, height))
        image = image.reshape(1, height, width)
        image = ((image / 127.5) - 1.0).astype(np.float32)
        return torch.from_numpy(image).unsqueeze(0)

    @staticmethod
    def _sequence_confidence(log_probs: torch.Tensor) -> float:
        '''
        Average best-class probability over the decoded sequence.
        Takes the max probability over dim 2 (the axis ``log_softmax`` is
        applied to), averages it over dim 0 and returns the entry at index 0
        of the remaining axis, rounded to 2 decimals.
        @params:
            - log_probs: 3-D tensor of log-probabilities; presumably
              (seq_len, batch, num_class) from CRNN — TODO confirm
        @return:
            - confidence: Python float in [0, 1]
        '''
        best = torch.exp(log_probs).max(dim=2).values
        return round(best.mean(dim=0)[0].item(), 2)

    def recognition(
        self,
        image: np.ndarray,
        decode: str = 'beam_search',
        beam_size: int = 10
    ) -> dict:
        '''
        Recognize the text in a single image.
        @params:
            - image: np.ndarray
            - decode: str -> ['beam_search', 'greedy', 'prefix_beam_search']
            - beam_size: int -> beam size for beam search
        @return:
            - result: dict -> {'text': str, 'confidence': float}
        @raises:
            - ValueError: unsupported ``decode`` or ``image`` is None
        '''
        # Explicit check (not ``assert``) so validation survives ``python -O``.
        if decode not in self._DECODE_METHODS:
            raise ValueError(
                f'Decode Failed: unsupported method {decode!r}, '
                f'expected one of {self._DECODE_METHODS}')
        # Input must be on the model's device and use the configured size.
        image_t = self.__image_transform(
            image,
            height=self.model_config['img_height'],
            width=self.model_config['img_width'],
        ).to(self.device)
        with torch.no_grad():
            output = self.model(image_t)
            # Bring the result to CPU once so numpy-based decoding and the
            # confidence computation work for CUDA models too.
            log_probs = F.log_softmax(output, dim=2).cpu()
        preds = ctc_decode(
            log_probs, method=decode, beam_size=beam_size,
            blank=0, label2char=self.classes)
        return {
            'text': ''.join(preds[0]),
            'confidence': self._sequence_confidence(log_probs),
        }
if __name__ == '__main__':
    # Manual smoke test / rough latency benchmark for TextRecognition.
    import time
    import string
    root_path = os.path.expanduser('~/.Halotec/Models')
    model_config = {
        'filename'  : 'crnn_008000.pt',
        'classes'   : string.digits + string.ascii_uppercase + '. ',
        # No download URL here, so presumably the weight file must already
        # exist in root_path.
        'url'       : None,
        'file_size' : 592694,
        'img_height': 32,
        'img_width' : 100,
        'map_to_seq_hidden': 64,
        'rnn_hidden': 256,
        'leaky_relu': False
    }
    text_recognition = TextRecognition(root_path, model_config, jic=True)
    image_path = './images/12022041114405685_0.jpg'
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        raise SystemExit(f'Could not read image: {image_path}')
    n_runs = 10
    start = time.perf_counter()
    for _ in range(n_runs):
        result = text_recognition.recognition(image, decode='beam_search', beam_size=10)
        print(result)
    elapsed = time.perf_counter() - start
    print(f'total: {elapsed:.4f}s, per image: {elapsed / n_runs:.4f}s')