import torch
from PIL import Image
from models.base import OCRModel
from transformers import AutoModelForVision2Seq, AutoProcessor, pipeline


class TrOCROCR(OCRModel):
    """OCR model backed by a Hugging Face TrOCR checkpoint.

    The processor and the vision-to-sequence model are loaded once at
    construction time; ``predict`` then runs recognition on one image.
    """

    def __init__(
        self,
        model_name: str = "microsoft/trocr-base-handwritten",
        device: str | None = None,
    ):
        """Load the processor and model and place the model on ``device``.

        Args:
            model_name: Checkpoint identifier passed to ``from_pretrained``
                for both the processor and the model.
            device: Target device string. When falsy, ``"cuda"`` is used if
                ``torch.cuda.is_available()`` reports a GPU, else ``"cpu"``.
        """
        # NOTE(review): OCRModel.__init__ is not invoked here; confirm the
        # base class needs no initialization of its own.
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        self.processor = AutoProcessor.from_pretrained(model_name)
        # NOTE: a half-precision load (torch_dtype=torch.float16) was left
        # disabled in the original code; the default dtype is used.
        self.model = AutoModelForVision2Seq.from_pretrained(model_name)

        self.model.to(self.device)
        self.model.eval()

    @torch.no_grad()
    def predict(self, image: Image.Image) -> str:
        """Recognize the text in ``image`` and return it stripped.

        Args:
            image: The image to transcribe. PIL images that are not already
                in ``"RGB"`` mode (e.g. grayscale, palette or RGBA scans)
                are converted to RGB before being handed to the processor.

        Returns:
            The first decoded sequence with special tokens removed and
            surrounding whitespace stripped.
        """
        # Normalize to 3-channel RGB before preprocessing, as the TrOCR
        # inference example does. Non-PIL inputs are passed through as-is
        # so existing callers are unaffected.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")

        inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = inputs.pixel_values.to(self.device)

        generated_ids = self.model.generate(pixel_values)

        decoded = self.processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
        )
        # A single image was supplied, so only the first decoded entry is used.
        return decoded[0].strip()