import torch
from PIL import Image
from strhub.data.module import SceneTextDataModule
from strhub.models.utils import load_from_checkpoint


class TextRecognizer:
    """Thin inference wrapper around a strhub scene-text-recognition model.

    The model is loaded once in the constructor, switched to eval mode and
    moved to the requested device; ``recognize`` then runs one image at a
    time through it and returns the decoded text.
    """

    def __init__(self, ckpt_path: str, device='cpu'):
        """Load the model and build the matching image preprocessing.

        Args:
            ckpt_path: Checkpoint identifier/path handed unchanged to
                ``load_from_checkpoint``.
            device: Device the model and every input tensor are moved to.
        """
        self.device = device
        # NOTE: the attribute is named ``str`` (scene text recognition); the
        # name is kept as-is because external code may already rely on it.
        self.str = load_from_checkpoint(ckpt_path).eval().to(device)
        # The transform is built from the model's own img_size hyper-parameter
        # so preprocessing always matches the loaded checkpoint.
        self.img_transform = SceneTextDataModule.get_transform(
            self.str.hparams.img_size
        )

    def recognize(self, image_pil: Image.Image) -> str:
        """Return the text recognized in a single PIL image.

        Args:
            image_pil: Input image. Images that are not already in RGB mode
                (grayscale, palette, RGBA, ...) are converted to RGB first.

        Returns:
            The decoded label for the image.
        """
        # Non-RGB inputs would otherwise produce a tensor with 1 or 4 channels
        # and fail inside the model; strhub's own data pipeline and usage
        # example also feed RGB images.
        if image_pil.mode != 'RGB':
            image_pil = image_pil.convert('RGB')

        # unsqueeze(0) adds the leading batch dimension for this single image.
        image_tensor = self.img_transform(image_pil).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.str(image_tensor)
            pred = logits.softmax(-1)

        # decode() returns a pair; only the labels are used here, the second
        # element (confidences, per the strhub usage example) is discarded.
        label, _ = self.str.tokenizer.decode(pred)
        # Batch size is 1, so the first label is the result for this image.
        return label[0]