Spaces:
Runtime error
Runtime error
| from argparse import ArgumentParser | |
| from itertools import groupby | |
| import os | |
| import cv2 | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| import utils_ | |
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a per-timestep linear projection.

    Args:
        nIn: Feature size of each input timestep.
        nHidden: Hidden size of the LSTM (per direction).
        nOut: Feature size of each output timestep.
    """

    def __init__(self, nIn, nHidden, nOut):
        super().__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        # Forward and backward hidden states are concatenated, hence 2 * nHidden.
        self.embedding = nn.Linear(2 * nHidden, nOut)

    def forward(self, input):
        """Run the LSTM over ``input`` and project every timestep to ``nOut``.

        The LSTM output is unpacked as ``(T, b, h)``; the result is returned
        with shape ``(T, b, nOut)``.
        """
        lstm_out, _ = self.rnn(input)
        seq_len, batch, hidden = lstm_out.size()
        # Fold time and batch together so the linear layer sees a 2-D input.
        flat = lstm_out.view(seq_len * batch, hidden)
        projected = self.embedding(flat)  # [T * b, nOut]
        return projected.view(seq_len, batch, -1)
class CRNN(nn.Module):
    """Convolutional feature extractor followed by two stacked BidirectionalLSTM layers.

    Args:
        imgH: Input image height; must be a multiple of 16.
        nc: Number of channels of the input image.
        nclass: Size of the per-timestep output of the last recurrent layer.
        nh: Hidden size used by both recurrent layers.
        n_rnn: Accepted for interface compatibility; not used by this
            implementation.
        leakyRelu: If true, use ``LeakyReLU(0.2)`` instead of ``ReLU`` after
            every convolution.
    """

    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        super().__init__()
        assert imgH % 16 == 0, "imgH has to be a multiple of 16"

        # Per-convolution hyper-parameters, indexed by convolution number.
        kernel_sizes = (3, 3, 3, 3, 3, 3, 2)
        paddings = (1, 1, 1, 1, 1, 1, 0)
        strides = (1, 1, 1, 1, 1, 1, 1)
        widths = (64, 128, 256, 256, 512, 512, 512)

        # Pooling configurations: the square pool halves both spatial dims,
        # the "tall" pool uses stride (2, 1) so it halves the height only.
        square_pool = (2, 2)
        tall_pool = ((2, 2), (2, 1), (0, 1))

        # (convolution index, add batch norm?, MaxPool2d args that follow, or None)
        layout = (
            (0, False, square_pool),
            (1, False, square_pool),
            (2, True, None),
            (3, False, tall_pool),
            (4, True, None),
            (5, False, tall_pool),
            (6, True, None),
        )

        cnn = nn.Sequential()
        pool_count = 0
        for conv_idx, use_batchnorm, pool_args in layout:
            in_channels = nc if conv_idx == 0 else widths[conv_idx - 1]
            out_channels = widths[conv_idx]
            cnn.add_module(
                f"conv{conv_idx}",
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_sizes[conv_idx],
                    strides[conv_idx],
                    paddings[conv_idx],
                ),
            )
            if use_batchnorm:
                cnn.add_module(f"batchnorm{conv_idx}", nn.BatchNorm2d(out_channels))
            if leakyRelu:
                activation = nn.LeakyReLU(0.2, inplace=True)
            else:
                activation = nn.ReLU(True)
            cnn.add_module(f"relu{conv_idx}", activation)
            if pool_args is not None:
                cnn.add_module(f"pooling{pool_count}", nn.MaxPool2d(*pool_args))
                pool_count += 1

        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass),
        )

    def forward(self, input):
        """Return the recurrent layers' output for a batch of images.

        The convolutional output must have height 1; it is rearranged to
        ``[w, b, c]`` (width acts as the sequence axis) before being passed
        to the recurrent layers.
        """
        features = self.cnn(input)
        _, _, height, _ = features.size()
        assert height == 1, "the height of conv must be 1"
        # Drop the singleton height axis, then move width to the front.
        sequence = features.squeeze(2).permute(2, 0, 1)  # [w, b, c]
        return self.rnn(sequence)
# Class-index -> symbol lookup table. Index 0 is the "BLANK" entry; the other
# 36 entries are the digits and uppercase letters, one per character of the
# string below, in exactly this order.
# NOTE(review): the order presumably has to match the label encoding used when
# the checkpoint was trained -- confirm before reordering.
VOCAB = ["BLANK", *"ZB4XR2UDGQSANK0CJPYH7WV5FL81ITM3O9E6"]
def add_text(image, text, pos):
    """Draw ``text`` on ``image`` just above the top-left corner of a box.

    Args:
        image: Image passed straight to ``cv2.putText``.
        text: String to render.
        pos: Four values ``(xmin, ymin, xmax, ymax)``; only ``xmin`` and
            ``ymin`` are used.

    Returns:
        The image returned by ``cv2.putText``.
    """
    left, top, _right, _bottom = pos
    # Anchor the text 15 pixels above the box's top edge.
    anchor = (left, top - 15)
    font_scale = 0.85
    colour = (0, 0, 255)
    thickness = 2
    return cv2.putText(
        image,
        text,
        anchor,
        cv2.FONT_HERSHEY_COMPLEX,
        font_scale,
        colour,
        thickness,
        cv2.LINE_AA,
    )
def greedy_decode(preds):
    """Collapse a per-timestep symbol sequence into the decoded string.

    Consecutive repeats of the same symbol are merged into one occurrence,
    then every ``"BLANK"`` symbol is dropped and the rest are concatenated.

    Args:
        preds: Iterable of symbol strings, one per timestep.

    Returns:
        The decoded text (empty string when nothing but blanks remain).
    """
    kept = []
    previous = object()  # sentinel that never equals a real symbol
    for symbol in preds:
        if symbol == previous:
            # Same symbol as the previous timestep: part of the same run.
            continue
        previous = symbol
        if symbol != "BLANK":
            kept.append(symbol)
    return "".join(kept)
def read_image(file):
    """Load an image from disk and return it in RGB channel order.

    Args:
        file: Path of the image file to read.

    Returns:
        The decoded image, converted from BGR to RGB.

    Raises:
        ValueError: If OpenCV could not read or decode ``file``.
    """
    img = cv2.imread(file)
    # cv2.imread reports failure by returning None rather than raising, which
    # would otherwise surface as an opaque error inside cvtColor.
    if img is None:
        raise ValueError(f"could not read image: {file!r}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def idx2char(preds):
    """Return the ``VOCAB`` symbol for each class index in ``preds``, in order."""
    return list(map(VOCAB.__getitem__, preds))
def post_process(preds):
    """Turn per-timestep class scores into a list of symbols.

    Args:
        preds: Tensor of shape ``(seq_len, num_class)``.

    Returns:
        One ``VOCAB`` symbol per timestep: the class with the highest score.
    """
    best_classes = torch.max(preds, dim=1).indices
    return idx2char(best_classes.tolist())
# Preprocessing applied to every image before it reaches the network:
# convert to a tensor, reduce to a single channel, resize to 32 x 128
# (height x width), then normalise with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Grayscale(),
        transforms.Resize((32, 128)),
        transforms.Normalize(0.5, 0.5),
    ]
)

# imgH=32, one input channel, 37 output classes (one per VOCAB entry),
# 512 LSTM hidden units.
model = CRNN(32, 1, 37, 512)

# The checkpoint is loaded at import time. map_location="cpu" lets a
# checkpoint that was saved from a GPU load on a CPU-only host; the model is
# never moved off the CPU in this file, so this does not change where
# inference runs.
state = torch.load("./out/ocr_point08.pt", map_location="cpu")
model.load_state_dict(state["model"])
def recognize(image):
    """Run the OCR model on a single image and return the decoded text.

    Args:
        image: Image accepted by the module-level ``transform`` pipeline
            (presumably an RGB array such as ``read_image`` returns --
            confirm against callers).

    Returns:
        The recognised string after greedy decoding.
    """
    model.eval()
    # Add a batch axis: the network expects a batch of images.
    batch = transform(image).unsqueeze(0)
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        preds = model(batch)
    # preds is indexed [timestep, batch, class]; take the only batch element.
    symbols = post_process(preds[:, 0, :])
    return greedy_decode(symbols)
if __name__ == "__main__":
    # Command-line entry point: recognise the text in one image file.
    parser = ArgumentParser()
    parser.add_argument(
        "--image",
        # Required: with the previous default of None, omitting the flag
        # crashed inside os.path.exists with a TypeError.
        required=True,
        type=str,
        help="path to image on which prediction will be made",
    )
    args = parser.parse_args()
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(args.image):
        parser.error(f"given path {args.image} does not exists")
    im = read_image(args.image)
    # NOTE(review): `text` is not printed or returned within the lines visible
    # here -- confirm the result is consumed further down.
    text = recognize(im)