| | import os |
| | import gradio as gr |
| | import omegaconf |
| | import torch |
| | import numpy |
| |
|
| | import easyocr |
| | from PIL import Image |
| |
|
| | from vietocr.model.transformerocr import VietOCR |
| | from vietocr.model.vocab import Vocab |
| | from vietocr.translate import translate, process_input |
| |
|
# EasyOCR reader for Vietnamese; predict() uses it to locate text regions.
reader = easyocr.Reader(['vi'])

# Paths of the example images in ./examples. Only the part of each entry
# name before the first tab character is kept.
examples_data = [
    os.path.join('examples', entry.split('\t')[0])
    for entry in os.listdir('examples')
]
| |
|
# Model configuration, loaded from YAML and resolved into plain dicts/lists.
config = omegaconf.OmegaConf.to_container(
    omegaconf.OmegaConf.load("vgg-seq2seq.yaml"), resolve=True
)

# Build the VietOCR network described by the config, then load the trained
# weights onto the CPU.
vocab = Vocab(config['vocab'])
model = VietOCR(
    len(vocab),
    config['backbone'],
    config['cnn'],
    config['transformer'],
    config['seq_modeling'],
)
# torch.load unpickles the checkpoint, so 'train_old.pth' must be a trusted file.
model.load_state_dict(
    torch.load('train_old.pth', map_location=torch.device('cpu'))
)
def viet_ocr_predict(inp):
    """Transcribe one cropped text-region image with the VietOCR model.

    Args:
        inp: image of a single text region. predict() passes a PIL image
            built with ``Image.fromarray``.

    Returns:
        The string decoded by the module-level ``vocab``.
    """
    dataset_cfg = config['dataset']
    img = process_input(inp, dataset_cfg['image_height'],
                        dataset_cfg['image_min_width'],
                        dataset_cfg['image_max_width'])
    result = translate(img, model)
    # NOTE(review): depending on the installed vietocr release, translate()
    # returns either the predicted id array or an (ids, char_probs) tuple.
    # Indexing the tuple with [0] yields the whole batch, which makes
    # vocab.decode fail (and predict() silently falls back to EasyOCR).
    # Unwrap the tuple first; confirm against the pinned vietocr version.
    sentences = result[0] if isinstance(result, tuple) else result
    # presumably a batch of one image -> take the first row of ids
    ids = sentences[0].tolist()
    return vocab.decode(ids)
def _clamped_box(bbox, width, height):
    """Return the (min_x, min_y, max_x, max_y) box enclosing *bbox*, clipped to the image.

    *bbox* is a sequence of (x, y) corner points, as unpacked from EasyOCR's
    readtext() results. Coordinates are truncated to int. The max values are
    used as exclusive slice ends, so they are clamped to width/height rather
    than width-1/height-1.
    """
    xs = [int(point[0]) for point in bbox]
    ys = [int(point[1]) for point in bbox]
    min_x = max(0, min(xs))
    max_x = min(width, max(xs))
    min_y = max(0, min(ys))
    max_y = min(height, max(ys))
    return min_x, min_y, max_x, max_y


def predict(filepath):
    """Recognise every text region in the image at *filepath*.

    EasyOCR (module-level ``reader``) finds the text regions. Each region is
    cropped from the image and re-read with VietOCR via viet_ocr_predict().
    If VietOCR fails on a region, or the crop is empty, EasyOCR's own
    transcription of that region is used instead.

    Args:
        filepath: path of the image to read. A falsy value (e.g. None when
            no image was submitted) yields an empty result.

    Returns:
        The per-region transcriptions joined with tab characters, or '' when
        nothing was detected.
    """
    if not filepath:
        return ''

    bounds = reader.readtext(filepath)

    # Convert to RGB so grayscale/palette/RGBA uploads all give a
    # 3-channel array. The context manager closes the file handle.
    with Image.open(filepath) as im:
        inp = numpy.asarray(im.convert('RGB'))
    # numpy image arrays are laid out (rows, cols, channels) = (height, width, channels).
    height, width = inp.shape[:2]

    texts = []
    for bbox, text, _prob in bounds:
        min_x, min_y, max_x, max_y = _clamped_box(bbox, width, height)
        out = text  # fallback: EasyOCR's reading of this region
        if max_x > min_x and max_y > min_y:
            try:
                cropped_image = Image.fromarray(inp[min_y:max_y, min_x:max_x])
                out = viet_ocr_predict(cropped_image)
            except Exception as err:  # best-effort: keep EasyOCR's reading
                print(f'VietOCR failed on region ({err!r}); using EasyOCR text')
                out = text
        print(out)
        texts.append(out)

    return '\t'.join(texts)
| |
|
# Web UI: the uploaded image reaches predict() as a file path, and the
# returned tab-separated transcription is shown as text.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='filepath'),
    outputs=gr.Text(),
    title='Vietnamese Handwriting Recognition',
)
demo.launch()