OCR_AEB_Serial_Number / STR_recognize.py
Ehtesham123's picture
Upload 2 files
5b19d10 verified
raw
history blame contribute delete
709 Bytes
import torch
from PIL import Image
from strhub.data.module import SceneTextDataModule
from strhub.models.utils import load_from_checkpoint
class TextRecognizer:
    """Recognize the text in a single image with a strhub scene-text model.

    The model and its preprocessing transform are built once in ``__init__``
    and reused by every ``recognize`` call.
    """

    def __init__(self, ckpt_path, device='cpu'):
        """Load the checkpoint and build the matching image transform.

        Args:
            ckpt_path: Checkpoint reference forwarded unchanged to
                ``strhub.models.utils.load_from_checkpoint``.
            device: Device passed to ``.to()`` for both the model and the
                input tensors (e.g. ``'cpu'`` or ``'cuda'``).
        """
        self.device = device
        # NOTE: the attribute is named ``str`` for backward compatibility with
        # existing callers; as an instance attribute it does not affect the
        # builtin ``str``.
        self.str = load_from_checkpoint(ckpt_path).eval().to(device)
        # Build the preprocessing from the image size stored in the
        # checkpoint's hyper-parameters so inputs match what the model expects.
        self.img_transform = SceneTextDataModule.get_transform(self.str.hparams.img_size)

    def recognize(self, image_pil):
        """Return the predicted text for one PIL image.

        Args:
            image_pil: A PIL image. Any mode other than ``'RGB'`` (e.g. ``'L'``,
                ``'RGBA'``, ``'P'``) is converted to RGB before preprocessing.

        Returns:
            The first (and only) decoded label of the batch of size 1.
        """
        # strhub models are fed RGB images (the upstream usage example calls
        # .convert('RGB')); other modes would yield the wrong channel count.
        if image_pil.mode != 'RGB':
            image_pil = image_pil.convert('RGB')
        # Add a leading batch dimension of 1 and move to the model's device.
        image_tensor = self.img_transform(image_pil).unsqueeze(0).to(self.device)
        with torch.inference_mode():
            logits = self.str(image_tensor)
            # Softmax over the last axis (presumably the character classes —
            # confirm against the strhub tokenizer).
            pred = logits.softmax(-1)
            # decode returns (labels, second value); the second value is
            # presumably per-label confidences and is discarded here.
            label, _ = self.str.tokenizer.decode(pred)
        return label[0]