File size: 709 Bytes
5b19d10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch
from PIL import Image
from strhub.data.module import SceneTextDataModule
from strhub.models.utils import load_from_checkpoint

class TextRecognizer:
    """Single-image text recognizer backed by a strhub checkpoint.

    Loads a model with ``load_from_checkpoint``, switches it to eval mode on
    ``device``, and builds the matching image transform from the model's
    ``hparams.img_size``.

    Args:
        ckpt_path: Checkpoint identifier passed straight to
            ``strhub.models.utils.load_from_checkpoint``.
        device: Torch device the model and input tensors are moved to
            (default ``'cpu'``).
    """

    def __init__(self, ckpt_path, device='cpu'):
        self.device = device
        # NOTE: the attribute is named ``str`` (it shadows the builtin name
        # only as an attribute); kept as-is because callers may access it.
        self.str = load_from_checkpoint(ckpt_path).eval().to(device)
        # Transform sized to the image resolution the model was configured with.
        self.img_transform = SceneTextDataModule.get_transform(self.str.hparams.img_size)

    def recognize(self, image_pil):
        """Return the text predicted for a single image.

        Args:
            image_pil: A PIL image. Images whose ``mode`` is not ``'RGB'``
                (e.g. ``'L'``, ``'P'``, ``'RGBA'``) are converted to RGB before
                the transform. Objects without a ``mode`` attribute are passed
                to the transform unchanged.

        Returns:
            The decoded label string for the image.
        """
        # Grayscale/palette/RGBA images would otherwise yield a 1- or 4-channel
        # tensor. NOTE(review): assumes the checkpoint expects 3-channel RGB
        # input, as in the strhub reference read pipeline — confirm for custom
        # checkpoints.
        mode = getattr(image_pil, 'mode', None)
        if mode is not None and mode != 'RGB':
            image_pil = image_pil.convert('RGB')
        # unsqueeze(0) adds a batch dimension of size 1.
        image_tensor = self.img_transform(image_pil).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.str(image_tensor)
            probs = logits.softmax(-1)
            labels, _ = self.str.tokenizer.decode(probs)
        # Batch size is 1, so the only prediction is at index 0.
        return labels[0]