trocr_htr / app.py
Partha11's picture
Update app.py (#1)
5b86184 verified
import spaces
import gradio as gr
from models.unet import MaskPredictor
from models.trocr import WordPredictor
from utils.image import ImageUtils
# Shared helper instances, created once at import time and reused by every
# call to predict() below.
# NOTE(review): the predictors presumably load model weights in their
# constructors -- confirm against models/unet.py and models/trocr.py.
mask_predictor = MaskPredictor()  # produces the mask consumed by extract_masks
word_predictor = WordPredictor()  # produces one text prediction per extracted word
image_utils = ImageUtils()  # conversion / resize / mask post-processing helpers
@spaces.GPU(duration=120)
def predict(image):
    """Transcribe the text found in ``image``.

    The image is converted with ``image_utils.to_cv2``, resized (target size
    768) and normalised, then passed to ``mask_predictor``. The predicted mask
    is restored to the original image size and used to extract word regions,
    which are ordered and fed one by one to ``word_predictor``.

    Args:
        image: The value delivered by the Gradio ``Image`` input (configured
            with ``type="numpy"``), or ``None`` when the user submits the
            form without providing an image.

    Returns:
        The per-word predictions joined by single spaces, with leading and
        trailing whitespace removed. An empty string when ``image`` is
        ``None``.
    """
    # Gradio hands over None when nothing was uploaded; bail out early
    # instead of failing somewhere inside the image pipeline.
    if image is None:
        return ""

    image = image_utils.to_cv2(image)

    # resize() also reports the scale/offsets and original size needed to map
    # the predicted mask back onto the untouched input image.
    sample, scale, top, left, (orig_h, orig_w) = image_utils.resize(
        image, target_size=768, is_mask=False
    )
    sample = image_utils.normalize(sample)

    mask = mask_predictor.predict(sample)
    mask = image_utils.restore_size(mask, scale, top, left, orig_h, orig_w)

    words = image_utils.extract_masks(image, mask)
    ordered_words, _ = image_utils.order_words(
        words, vertical_padding_factor=1.2, height_mode='median'
    )

    # Join once instead of growing a string with += inside the loop.
    predictions = [word_predictor.predict(word) for word in ordered_words]
    return " ".join(predictions).strip()
# Build the UI components first (input image, then output textbox) and wire
# them to predict() through a single gr.Interface.
_image_input = gr.Image(type="numpy", label="Image")
_text_output = gr.Textbox(lines=20, label="Text")

iface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_text_output,
    title="HTR",
    description="...",
)

# Start serving the app as soon as the module is executed.
iface.launch()