Spaces:
Build error
Build error
| """ | |
| @Author : Ali Mustofa HALOTEC | |
| @Module : Character Recognition Neural Network | |
| @Created on : 2 August 2022 | |
| """ | |
| #!/usr/bin/env python3 | |
| # Path: src/apps/char_recognition.py | |
| import os | |
| import cv2 | |
| import sys | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .crnn import CRNN | |
| from .decoder import ctc_decode | |
| try: | |
| from src.utils.utils import download_and_unzip_model | |
| except ImportError: | |
| SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| sys.path.append(os.path.dirname(SCRIPT_DIR)) | |
| from utils.utils import download_and_unzip_model | |
class TextRecognition:
    '''
    CRNN text recognizer.

    Ensures the checkpoint exists (downloading it if missing), builds a CRNN
    from ``model_config``, loads the weights, optionally JIT-traces the model,
    and decodes text from images with ``ctc_decode``.
    @params:
        - root_path: directory that holds (or will receive) the model file
        - model_config: dict read for the keys 'filename', 'classes', 'url',
          'file_size', 'img_height', 'img_width', 'map_to_seq_hidden',
          'rnn_hidden', 'leaky_relu'
        - jic: if True, replace the model with a ``torch.jit.trace``-d copy
    '''

    def __init__(self, root_path: str, model_config: dict, jic: bool = True) -> None:
        self.jic = jic
        self.root_path = root_path
        self.model_config = model_config
        self.model_name = f'{root_path}/{model_config["filename"]}'
        # Index 0 is the CTC blank (blank=0 in ctc_decode, num_class = len+1),
        # so character labels start at 1.
        self.classes = {i + 1: v for i, v in enumerate(model_config['classes'])}
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.model = self.__load_model()
        if jic:
            # Trace at the configured input size, matching what
            # __image_transform produces in recognition().
            self.model = self.__jic_trace(
                self.model, model_config['img_height'], model_config['img_width'])

    @staticmethod
    def __crnn_model(config: dict) -> nn.Module:
        '''
        Build an (untrained) single-channel CRNN from the config.
        @params:
            - config: model config dict (see class docstring)
        @return:
            - model: nn.Module with one extra output class for the CTC blank
        '''
        return CRNN(
            img_channel=1,
            img_height=config['img_height'],
            img_width=config['img_width'],
            num_class=len(config['classes']) + 1,
            map_to_seq_hidden=config['map_to_seq_hidden'],
            rnn_hidden=config['rnn_hidden'],
            leaky_relu=config['leaky_relu'],
        )

    @staticmethod
    def __jic_trace(model: nn.Module, height: int = 32, width: int = 100) -> torch.jit.ScriptModule:
        '''
        JIT tracing
        @params:
            - model: nn.Module
            - height, width: spatial size of the example input used for tracing
        @return:
            - traced module
        '''
        # The example input must live on the same device as the model's
        # parameters; a CPU tensor would fail against a CUDA model.
        device = next((p.device for p in model.parameters()), torch.device('cpu'))
        example = torch.rand(1, 1, height, width, device=device)
        return torch.jit.trace(model, example)

    @staticmethod
    def __check_model(root_path: str, model_config: dict) -> None:
        '''
        Download the model file into root_path if it is not already there.
        @params:
            - root_path: directory expected to contain model_config['filename']
            - model_config: dict providing 'filename', 'url', 'file_size'
        '''
        if os.path.isfile(f'{root_path}/{model_config["filename"]}'):
            print('Load model ...')
            return
        download_and_unzip_model(
            root_dir=root_path,
            name=model_config['filename'],
            url=model_config['url'],
            file_size=model_config['file_size'],
            unzip=False,
        )

    def __load_model(self) -> nn.Module:
        '''
        Load model from file
        @return:
            - model: nn.Module on self.device, in eval mode
        '''
        self.__check_model(self.root_path, self.model_config)
        model = self.__crnn_model(self.model_config)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        model.load_state_dict(torch.load(self.model_name, map_location=self.device))
        model.to(self.device)
        return model.eval()

    @staticmethod
    def __image_transform(image: np.ndarray, height: int = 32, width: int = 100) -> torch.Tensor:
        '''
        Image transform: grayscale, resize, scale pixels to [-1, 1].
        @params:
            - image: np.ndarray, BGR/BGRA (H, W, C) or grayscale (H, W)
            - height, width: target size
        @return:
            - image: torch.Tensor of shape (1, 1, height, width), on the CPU
        @raises:
            - ValueError: if image is None (e.g. a failed cv2.imread)
        '''
        if image is None:
            raise ValueError('image is None (was it loaded successfully?)')
        if image.ndim == 3:
            channels = image.shape[2]
            if channels == 1:
                image = image[:, :, 0]
            else:
                code = cv2.COLOR_BGRA2GRAY if channels == 4 else cv2.COLOR_BGR2GRAY
                image = cv2.cvtColor(image, code)
        image = cv2.resize(image, (width, height))
        image = image.reshape(1, height, width)
        image = (image / 127.5) - 1.0
        return torch.FloatTensor(image).unsqueeze(0)

    @staticmethod
    def __confidence(log_probs: torch.Tensor) -> float:
        '''
        Confidence score: per-step maximum probability averaged over the
        first axis, for batch item 0, rounded to 2 decimals.
        @params:
            - log_probs: log-probabilities; assumed (steps, batch, classes)
              since log_softmax is taken over dim=2 -- TODO confirm vs CRNN
        @return:
            - confidence: float
        '''
        best = torch.exp(log_probs).max(dim=2).values
        # .item() works on any device, unlike .numpy() on a CUDA tensor.
        return round(best.mean(dim=0)[0].item(), 2)

    def recognition(
        self,
        image: np.ndarray,
        decode: str = 'beam_search',
        beam_size: int = 10
    ) -> dict:
        '''
        Recognition text from image
        @params:
            - image: np.ndarray
            - decode: str -> ['beam_search', 'greedy', 'prefix_beam_search']
            - beam_size: int -> beam size for beam search
        @return:
            - result: dict -> {'text': str, 'confidence': float}
        '''
        assert decode in ['beam_search', 'greedy', 'prefix_beam_search'], 'Decode Failed'
        image_t = self.__image_transform(
            image, self.model_config['img_height'], self.model_config['img_width']
        ).to(self.device)
        # recognize
        with torch.no_grad():
            output = self.model(image_t)
            # Move to CPU once so decoding and confidence never touch CUDA tensors.
            log_probs = F.log_softmax(output, dim=2).cpu()
        # decode
        preds = ctc_decode(
            log_probs, method=decode, beam_size=beam_size,
            blank=0, label2char=self.classes)
        return {'text': ''.join(preds[0]), 'confidence': self.__confidence(log_probs)}
if __name__ == '__main__':
    # Demo / rough benchmark.
    # NOTE(review): this module uses relative imports (from .crnn import CRNN),
    # so it presumably has to be run as a module (e.g. python -m
    # src.apps.char_recognition) rather than as a plain script -- confirm.
    import time
    import string

    root_path = os.path.expanduser('~/.Halotec/Models')
    model_config = {
        'filename' : 'crnn_008000.pt',
        'classes' : string.digits+string.ascii_uppercase+'. ',
        # NOTE(review): with url=None the model file must already exist in
        # root_path; the download path cannot work.
        'url' : None,
        'file_size' : 592694,
        'img_height': 32,
        'img_width' : 100,
        'map_to_seq_hidden': 64,
        'rnn_hidden': 256,
        'leaky_relu': False
    }
    text_recognition = TextRecognition(root_path, model_config, jic=True)

    image_path = './images/12022041114405685_0.jpg'
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        sys.exit(f'Could not read image: {image_path}')

    n_runs = 10
    start = time.perf_counter()
    for _ in range(n_runs):
        result = text_recognition.recognition(image, decode='beam_search', beam_size=10)
        print(result)
    elapsed = time.perf_counter() - start
    print(f'total: {elapsed:.3f}s, per call: {elapsed / n_runs * 1000:.1f} ms')