File size: 4,593 Bytes
7ee7e3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""
@Author     : Ali Mustofa HALOTEC
@Module     : Character Recognition Neural Network
@Created on : 2 August 2022
"""
#!/usr/bin/env python3
# Path: src/apps/char_recognition.py

import os
import cv2
import sys
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from .crnn import CRNN
from .decoder import ctc_decode

try:
	from src.utils.utils import download_and_unzip_model
except ImportError:
	SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
	sys.path.append(os.path.dirname(SCRIPT_DIR))
	from utils.utils import download_and_unzip_model

class TextRecognition:
	'''
	CRNN-based text recognizer for cropped text images.
	@params:
		- root_path: str -> directory holding (or receiving) the weights file
		- model_config: dict -> must provide 'filename', 'classes', 'url', 'file_size',
		  'img_height', 'img_width', 'map_to_seq_hidden', 'rnn_hidden', 'leaky_relu'
		- jic: bool -> when True the loaded model is replaced by a torch.jit trace of itself
	'''
	def __init__(self, root_path:str, model_config:dict, jic: bool=True) -> None:
		self.jic            = jic
		self.root_path      = root_path
		self.model_config   = model_config
		self.model_name     = f'{root_path}/{model_config["filename"]}'
		# label 0 is the CTC blank (see blank=0 in recognition), so real classes start at 1
		self.classes        = {i+1:v for i,v in enumerate(model_config['classes'])}
		self.device         = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
		self.model          = self.__load_model()
		if jic:
			# Trace with the configured input size on the model's own device; a fixed
			# 32x100 CPU dummy fails on CUDA and mismatches other configured sizes.
			self.model = self.__jic_trace(
				self.model,
				height = model_config['img_height'],
				width  = model_config['img_width'],
				device = self.device,
			)

	@staticmethod
	def __crnn_model(config) -> nn.Module:
		'''
		Build an (untrained) single-channel CRNN from the model config.
		@params:
			- config: dict -> model configuration
		@return:
			- model: nn.Module
		'''
		model = CRNN(
			img_channel = 1,
			img_height  = config['img_height'],
			img_width   = config['img_width'],
			# +1 output for the CTC blank label
			num_class   = len(config['classes'])+1,
			map_to_seq_hidden = config['map_to_seq_hidden'],
			rnn_hidden  = config['rnn_hidden'],
			leaky_relu  = config['leaky_relu']
		)
		return model

	@staticmethod
	def __jic_trace(
			model:nn.Module,
			height: int = 32,
			width: int = 100,
			device: torch.device = None
		) -> torch.jit.TracedModule:
		'''
		JIT-trace a model with a random single-image, single-channel dummy input.
		@params:
			- model: nn.Module -> model to trace (already in eval mode)
			- height: int -> dummy input height; should match the model's img_height
			- width: int -> dummy input width; should match the model's img_width
			- device: torch.device -> device of the dummy input; defaults to the device
			  of the model's first parameter, or CPU when the model has none
		@return:
			- traced module
		'''
		if device is None:
			first_param = next(model.parameters(), None)
			device = first_param.device if first_param is not None else torch.device('cpu')
		example = torch.rand(1, 1, height, width, device=device)
		return torch.jit.trace(model, example)

	@staticmethod
	def __check_model(root_path:str, model_config:dict) -> None:
		'''
		Download the weights file into root_path when it is not already there.
		@params:
			- root_path: str -> model directory
			- model_config: dict -> provides 'filename', 'url', 'file_size'
		'''
		if not os.path.isfile(f'{root_path}/{model_config["filename"]}'):
			download_and_unzip_model(
				root_dir    = root_path,
				name        = model_config['filename'],
				url         = model_config['url'],
				file_size   = model_config['file_size'],
				unzip       = False
			)
		else: print('Load model ...')

	def __load_model(self) -> nn.Module:
		'''
		Load model weights from file (downloading them first if missing).
		@return:
			- model: nn.Module on self.device, in eval mode
		'''
		self.__check_model(self.root_path, self.model_config)
		model = self.__crnn_model(self.model_config)
		# NOTE(review): torch.load unpickles the file; the weights may come from a
		# downloaded URL, so only load from trusted sources (consider weights_only=True
		# if the installed torch version supports it).
		model.load_state_dict(torch.load(self.model_name, map_location=self.device))
		model.to(self.device)
		return model.eval()

	@staticmethod
	def __image_transform(image:np.ndarray, height: int=32, width: int=100) -> torch.Tensor:
		'''
		Convert an image to the normalized (1, 1, height, width) float tensor the model expects.
		@params:
			- image: np.ndarray -> BGR/BGRA image, or an already single-channel image
			- height: int -> target height
			- width: int -> target width
		@return:
			- image: torch.Tensor -> float32 CPU tensor with values in [-1, 1]
		'''
		ndim = getattr(image, 'ndim', None)
		if ndim == 2:
			gray = image  # already single-channel
		elif ndim == 3 and image.shape[2] == 1:
			gray = image[:, :, 0]
		else:
			# Colour input; anything invalid (e.g. None from a failed imread) lets cv2 raise.
			gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
		gray = cv2.resize(gray, (width, height))
		gray = gray.reshape(1, height, width)
		# map [0, 255] to [-1, 1] in float64, then cast to float32
		scaled = ((gray / 127.5) - 1.0).astype(np.float32)
		return torch.from_numpy(scaled).unsqueeze(0)

	@staticmethod
	def __confidence(log_probs: torch.Tensor) -> float:
		'''
		Mean over time steps of the highest class probability, for the first batch item.
		@params:
			- log_probs: torch.Tensor -> log-softmax scores over dim 2 (classes); dim 0 is
			  assumed to be the time axis and dim 1 the batch axis — TODO confirm against CRNN
		@return:
			- confidence: float in [0, 1]
		'''
		best = torch.exp(log_probs).max(dim=2).values  # per-step max probability
		# .item() works on any device, unlike .numpy() on CUDA tensors
		return best.mean(dim=0)[0].item()

	def recognition(
			self,
			image: np.ndarray,
			decode: str = 'beam_search',
			beam_size: int = 10
		) -> dict:
		'''
		Recognition text from image
		@params:
			- image: np.ndarray
			- decode: str -> ['beam_search', 'greedy', 'prefix_beam_search']
			- beam_size: int -> beam size for beam search
		@return:
			- result: dict -> {'text': str, 'confidence': float}
		'''

		assert decode in ['beam_search', 'greedy', 'prefix_beam_search'], 'Decode Failed'

		# The input must live on the same device as the model.
		image_t = self.__image_transform(
			image,
			height = self.model_config['img_height'],
			width  = self.model_config['img_width'],
		).to(self.device)
		# recognize
		with torch.no_grad():
			output = self.model(image_t)
			log_probs = F.log_softmax(output, dim=2)
		# Decode and score on CPU; a no-op when the model already runs on CPU.
		log_probs = log_probs.cpu()
		preds = ctc_decode(
			log_probs, method=decode, beam_size=beam_size,
			blank=0, label2char=self.classes)
		conf = round(self.__confidence(log_probs), 2)
		return {'text': ''.join(preds[0]), 'confidence': conf}


if __name__ == '__main__':
	# Ad-hoc smoke test / timing run. Optional first CLI argument: image path.
	import time
	import string

	root_path = os.path.expanduser('~/.Halotec/Models')

	model_config = {
		'filename'  : 'crnn_008000.pt',
		'classes'   : string.digits+string.ascii_uppercase+'. ',
		'url'       : None,
		'file_size' : 592694,
		'img_height': 32,
		'img_width' : 100,
		'map_to_seq_hidden': 64,
		'rnn_hidden': 256,
		'leaky_relu': False
	}
	text_recognition = TextRecognition(root_path, model_config, jic=True)

	image_path = sys.argv[1] if len(sys.argv) > 1 else './images/12022041114405685_0.jpg'
	image = cv2.imread(image_path)
	if image is None:
		# cv2.imread returns None (instead of raising) for missing/unreadable files
		raise SystemExit(f'Could not read image: {image_path}')

	# perf_counter is the monotonic, high-resolution clock intended for timing
	start = time.perf_counter()
	for _ in range(10):
		result = text_recognition.recognition(image, decode='beam_search', beam_size=10)
		print(result)
	print(time.perf_counter() - start)