"""
@Author     : Ali Mustofa HALOTEC
@Module     : Character Recognition Neural Network
@Created on : 2 August 2022
"""
#!/usr/bin/env python3
# Path: src/apps/char_recognition.py

import os
import cv2
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .crnn import CRNN
from .decoder import ctc_decode

try:
    from src.utils.utils import download_and_unzip_model
except ImportError:
    # Fallback for running from inside the package directory: put the parent
    # of this file's directory on sys.path so `utils` is importable.
    SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
    sys.path.append(os.path.dirname(SCRIPT_DIR))
    from utils.utils import download_and_unzip_model


class TextRecognition:
    '''
    CRNN based text recognizer.
    Loads the CRNN weights described by `model_config` from `root_path`
    (downloading them when the file is missing) and exposes `recognition`
    to read the text of a single image.
    '''

    def __init__(self, root_path: str, model_config: dict, jic: bool = True) -> None:
        '''
        Build the recognizer and load its weights
        @params:
            - root_path: str -> directory that holds (or will receive) the model file
            - model_config: dict -> needs 'filename', 'classes', 'url', 'file_size',
              'img_height', 'img_width', 'map_to_seq_hidden', 'rnn_hidden', 'leaky_relu'
            - jic: bool -> when True the loaded model is wrapped with torch.jit.trace
        '''
        self.jic = jic
        self.root_path = root_path
        self.model_config = model_config
        self.model_name = f'{root_path}/{model_config["filename"]}'
        # Labels start at 1 because index 0 is passed to the decoder as the CTC blank.
        self.classes = {i + 1: v for i, v in enumerate(model_config['classes'])}
        self.img_height = model_config['img_height']
        self.img_width = model_config['img_width']
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.model = self.__load_model()
        if jic:
            # Trace with an example that matches the configured input size and
            # lives on the same device as the model weights.
            self.model = self.__jic_trace(
                self.model,
                height=self.img_height,
                width=self.img_width,
                device=self.device,
            )

    @staticmethod
    def __crnn_model(config) -> nn.Module:
        '''
        Build an untrained CRNN from the config
        @params:
            - config: dict -> model configuration (see __init__)
        @return:
            - model: nn.Module
        '''
        model = CRNN(
            img_channel=1,
            img_height=config['img_height'],
            img_width=config['img_width'],
            # one extra output class for the CTC blank
            num_class=len(config['classes']) + 1,
            map_to_seq_hidden=config['map_to_seq_hidden'],
            rnn_hidden=config['rnn_hidden'],
            leaky_relu=config['leaky_relu'],
        )
        return model

    @staticmethod
    def __jic_trace(
        model: nn.Module,
        height: int = 32,
        width: int = 100,
        device=None,
    ) -> torch.jit.TracedModule:
        '''
        JIT tracing
        @params:
            - model: nn.Module
            - height: int -> height of the example input used for tracing
            - width: int -> width of the example input used for tracing
            - device: torch.device -> device of the example input (None = default device)
        @return:
            - traced model
        '''
        example = torch.rand(1, 1, height, width, device=device)
        return torch.jit.trace(model, example)

    @staticmethod
    def __check_model(root_path: str, model_config: dict) -> None:
        '''
        Make sure the model file exists, downloading it when it does not
        @params:
            - root_path: str
            - model_config: dict -> uses 'filename', 'url' and 'file_size'
        '''
        if not os.path.isfile(f'{root_path}/{model_config["filename"]}'):
            download_and_unzip_model(
                root_dir=root_path,
                name=model_config['filename'],
                url=model_config['url'],
                file_size=model_config['file_size'],
                unzip=False,
            )
        else:
            print('Load model ...')

    def __load_model(self) -> nn.Module:
        '''
        Load model from file
        @return:
            - model: nn.Module (in eval mode, on self.device)
        '''
        self.__check_model(self.root_path, self.model_config)
        model = self.__crnn_model(self.model_config)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        model.load_state_dict(torch.load(self.model_name, map_location=self.device))
        model.to(self.device)
        return model.eval()

    @staticmethod
    def __image_transform(image: np.ndarray, height: int = 32, width: int = 100) -> torch.Tensor:
        '''
        Image transform
        @params:
            - image: np.ndarray -> colour image (converted with COLOR_BGR2GRAY)
              or an already single-channel image
            - height: int -> target height
            - width: int -> target width
        @return:
            - image: torch.Tensor -> float tensor of shape (1, 1, height, width)
        '''
        # Only colour images need the conversion; cvtColor would fail on
        # inputs that are already single-channel.
        if image.ndim == 3 and image.shape[2] != 1:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        image = cv2.resize(image, (width, height))
        image = image.reshape(1, height, width)
        # maps 0..255 pixel values onto -1..1
        image = (image / 127.5) - 1.0
        image = torch.FloatTensor(image)
        return image.unsqueeze(0)

    def recognition(
        self,
        image: np.ndarray,
        decode: str = 'beam_search',
        beam_size: int = 10
    ) -> dict:
        '''
        Recognition text from image
        @params:
            - image: np.ndarray
            - decode: str -> ['beam_search', 'greedy', 'prefix_beam_search']
            - beam_size: int -> beam size for beam search
        @return:
            - result: dict -> {'text': str, 'confidence': float}
        '''
        assert decode in ['beam_search', 'greedy', 'prefix_beam_search'], 'Decode Failed'
        image_t = self.__image_transform(image, self.img_height, self.img_width)
        # The model lives on self.device, so the input has to be there as well.
        image_t = image_t.to(self.device)

        # recognize
        with torch.no_grad():
            output = self.model(image_t)
            # Bring the result back to the CPU once; decoding and the
            # confidence computation below then work for CPU and CUDA alike.
            log_probs = F.log_softmax(output, dim=2).cpu()

        # decode
        preds = ctc_decode(
            log_probs,
            method=decode,
            beam_size=beam_size,
            blank=0,
            label2char=self.classes,
        )

        # calculate confidence: best class probability per step, averaged over
        # dim 0, for the first entry of dim 1
        # (presumably dims are (seq_len, batch, num_class) - TODO confirm)
        exps = torch.exp(log_probs)
        best_probs = torch.max(exps, dim=2)[0]
        conf = round(float(best_probs.mean(dim=0)[0].item()), 2)

        text = ''.join(preds[0])
        return {'text': text, 'confidence': conf}


if __name__ == '__main__':
    import time
    import string

    root_path = os.path.expanduser('~/.Halotec/Models')
    model_config = {
        'filename': 'crnn_008000.pt',
        # NOTE(review): the class string must match the checkpoint
        # (num_class = len(classes) + 1) - confirm the trailing characters.
        'classes': string.digits + string.ascii_uppercase + '. ',
        'url': None,
        'file_size': 592694,
        'img_height': 32,
        'img_width': 100,
        'map_to_seq_hidden': 64,
        'rnn_hidden': 256,
        'leaky_relu': False
    }

    text_recognition = TextRecognition(root_path, model_config, jic=True)

    image_path = './images/12022041114405685_0.jpg'
    image = cv2.imread(image_path)
    # cv2.imread returns None instead of raising when the file cannot be read.
    if image is None:
        raise FileNotFoundError(f'Cannot read image: {image_path}')

    start = time.time()
    for _ in range(10):
        result = text_recognition.recognition(image, decode='beam_search', beam_size=10)
        print(result)
    print(time.time() - start)