import os

import torch
import torch.nn as nn
from torch.nn import init

from util.util import to_device
from .networks import *
from params import *


class BidirectionalLSTM(nn.Module):

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()

        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        # input: (T, b, nIn); the bidirectional LSTM returns (T, b, 2 * nHidden)
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)

        output = self.embedding(t_rec)  # (T * b, nOut)
        output = output.view(T, b, -1)  # (T, b, nOut)

        return output
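

# A minimal shape sketch (illustrative only; the sizes are arbitrary, not
# values fixed by this repo):
#
#     rnn = BidirectionalLSTM(nIn=512, nHidden=256, nOut=80)
#     x = torch.randn(26, 4, 512)   # (T, b, nIn) sequence batch
#     y = rnn(x)                    # -> torch.Size([26, 4, 80]) == (T, b, nOut)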


class CRNN(nn.Module):

    def __init__(self, leakyRelu=False):
        super(CRNN, self).__init__()
        self.name = 'OCR'

        # Kernel size, padding, stride and output channels for each of the
        # seven convolutional blocks.
        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()
        nh = 256
        # When enabled, register a backward hook that zeroes NaN gradients
        # (e.g. produced by CTC loss on degenerate alignments).
        dealwith_lossnone = False

        def convRelu(i, batchNormalization=False):
            # The first block reads the single-channel image; later blocks
            # read the previous block's feature maps.
            nIn = 1 if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))
        convRelu(2, True)
        convRelu(3)
        # From here on, pool with stride (2, 1): halve the height but keep
        # the width, which becomes the output sequence length.
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))
        convRelu(4, True)
        if resolution == 63:
            cnn.add_module('pooling{0}'.format(3),
                           nn.MaxPool2d((2, 2), (2, 1), (0, 1)))
        convRelu(5)
        cnn.add_module('pooling{0}'.format(4),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))
        convRelu(6, True)
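
        # Height bookkeeping (a sketch, assuming the standard 32-pixel-high
        # CRNN input): 32 -> 16 -> 8 via the two 2x2 poolings, then
        # 8 -> 4 -> 2 via the (2, 1)-stride poolings, and the final 2x2
        # convolution removes the last row, leaving height 1. The extra
        # pooling taken when `resolution == 63` presumably accommodates
        # inputs twice as tall; forward() asserts that the height has in
        # fact collapsed to 1.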

        self.cnn = cnn
        self.use_rnn = False
        if self.use_rnn:
            # The second LSTM projects to the vocabulary, mirroring the
            # linear head below.
            self.rnn = nn.Sequential(
                BidirectionalLSTM(512, nh, nh),
                BidirectionalLSTM(nh, nh, VOCAB_SIZE))
        else:
            self.linear = nn.Linear(512, VOCAB_SIZE)

        if dealwith_lossnone:
            self.register_backward_hook(self.backward_hook)

        self.device = torch.device('cuda:{}'.format(0))
        # 'N02' presumably selects N(0, 0.02) initialisation inside
        # init_weights (imported from .networks).
        self.init = 'N02'
        self = init_weights(self, self.init)

    def forward(self, input):
        conv = self.cnn(input)             # (b, 512, 1, w)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)             # (b, 512, w)
        conv = conv.permute(2, 0, 1)       # (w, b, 512): width-major sequence

        if self.use_rnn:
            output = self.rnn(conv)
        else:
            output = self.linear(conv)     # (w, b, VOCAB_SIZE)
        return output

    def backward_hook(self, module, grad_input, grad_output):
        # Zero out NaN gradients in place (NaN != NaN), so a single bad
        # sample cannot poison the whole update. Entries of grad_input may
        # be None and must be skipped.
        for g in grad_input:
            if g is not None:
                g[g != g] = 0
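

# Illustrative usage sketch (not part of the module): VOCAB_SIZE and
# resolution are imported from params.py; the 32-pixel input height below
# is an assumption that holds for the standard CRNN geometry
# (resolution != 63).
#
#     crnn = CRNN()
#     crnn = crnn.to(crnn.device)            # the constructor pins cuda:0
#     imgs = torch.randn(4, 1, 32, 128, device=crnn.device)
#     logits = crnn(imgs)                    # (w, 4, VOCAB_SIZE)
#     log_probs = logits.log_softmax(2)      # input format for nn.CTCLoss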


class OCRLabelConverter(object):
    """Convert between str and label.

    NOTE:
        Inserts `blank` into the alphabet for CTC.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=False): whether to ignore case.
    """

    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet + '-'  # '-' stands in for the CTC blank

        self.dict = {}
        for i, char in enumerate(alphabet):
            # Index 0 is reserved for the CTC blank.
            self.dict[char] = i + 1

    def encode(self, text):
        """Support batch or single str.

        Args:
            text (str or list of str): texts to convert.

        Returns:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.
        """
        length = []
        result = []
        for item in text:
            # Labels may arrive as bytes from the dataset pipeline.
            if isinstance(item, bytes):
                item = item.decode('utf-8', 'strict')
            length.append(len(item))
            for char in item:
                index = self.dict[char]
                result.append(index)

        return (torch.IntTensor(result), torch.IntTensor(length))

    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.

        Raises:
            AssertionError: when the texts and their lengths do not match.

        Returns:
            text (str or list of str): texts to convert.
        """
        if length.numel() == 1:
            length = length[0]
            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(
                t.numel(), length)
            if raw:
                return ''.join([self.alphabet[i - 1] for i in t])
            else:
                # Greedy CTC collapse: drop blanks (index 0) and repeated
                # symbols.
                char_list = []
                for i in range(length):
                    if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                        char_list.append(self.alphabet[t[i] - 1])
                return ''.join(char_list)
        else:
            # Batch mode: split the flat tensor by the per-item lengths and
            # decode each piece recursively.
            assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(
                t.numel(), length.sum())
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
                texts.append(
                    self.decode(
                        t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts
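

# Round-trip sketch (illustrative; the alphabet is a toy example, not the
# repo's real charset):
#
#     conv = OCRLabelConverter('abcdefghijklmnopqrstuvwxyz')
#     t, l = conv.encode(['ocr', 'demo'])   # flat codes + per-item lengths
#     conv.decode(t, l)                     # -> ['ocr', 'demo']
#
# Note that decode() collapses repeats, so round-tripping raw labels only
# works for strings without doubled characters; real network output
# separates repeats with blanks (index 0).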


class strLabelConverter(object):
    """Convert between str and label.

    NOTE:
        Inserts `blank` into the alphabet for CTC.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=False): whether to ignore case.
    """

    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet + '-'  # '-' stands in for the CTC blank

        self.dict = {}
        for i, char in enumerate(alphabet):
            # Index 0 is reserved for the CTC blank.
            self.dict[char] = i + 1

    def encode(self, text):
        """Support batch or single str.

        Args:
            text (str or list of str): texts to convert.

        Returns:
            torch.LongTensor [n, max_length]: encoded texts, padded with 0 (blank).
            torch.IntTensor [n]: length of each text.
        """
        length = []
        result = []
        results = []
        for item in text:
            # Labels may arrive as bytes from the dataset pipeline.
            if isinstance(item, bytes):
                item = item.decode('utf-8', 'strict')
            length.append(len(item))
            for char in item:
                index = self.dict[char]
                result.append(index)
            results.append(result)
            result = []

        padded = torch.nn.utils.rnn.pad_sequence(
            [torch.LongTensor(item) for item in results], batch_first=True)
        return (padded, torch.IntTensor(length))

    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.

        Raises:
            AssertionError: when the texts and their lengths do not match.

        Returns:
            text (str or list of str): texts to convert.
        """
        if length.numel() == 1:
            length = length[0]
            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(
                t.numel(), length)
            if raw:
                return ''.join([self.alphabet[i - 1] for i in t])
            else:
                # Greedy CTC collapse: drop blanks (index 0) and repeated
                # symbols.
                char_list = []
                for i in range(length):
                    if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                        char_list.append(self.alphabet[t[i] - 1])
                return ''.join(char_list)
        else:
            # Batch mode: split the flat tensor by the per-item lengths and
            # decode each piece recursively.
            assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(
                t.numel(), length.sum())
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
                texts.append(
                    self.decode(
                        t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts
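

# strLabelConverter differs from OCRLabelConverter only in encode(): it
# returns a (batch, max_length) LongTensor padded with 0 (the blank)
# instead of one flat concatenated tensor. A toy sketch:
#
#     conv = strLabelConverter('abc')
#     padded, lengths = conv.encode(['ab', 'c'])
#     # padded  -> tensor([[1, 2], [3, 0]])
#     # lengths -> tensor([2, 1], dtype=torch.int32)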