import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor, pipeline

from models.base import OCRModel


class TrOCROCR(OCRModel):
    """OCR model backed by a Hugging Face TrOCR checkpoint.

    The processor and the vision-to-sequence model are loaded once in the
    constructor; the model is moved to the selected device and put into
    evaluation mode. ``predict`` then turns a single PIL image into text.
    """

    def __init__(
        self,
        model_name: str = "microsoft/trocr-base-handwritten",
        device: str | None = None,
    ):
        """Load the processor and model for ``model_name``.

        Args:
            model_name: Checkpoint identifier handed to ``from_pretrained``
                for both the processor and the model.
            device: Torch device string. When omitted (or empty), ``"cuda"``
                is chosen if ``torch.cuda.is_available()`` and ``"cpu"``
                otherwise.
        """
        # NOTE(review): OCRModel.__init__ is not invoked here; confirm the
        # base class needs no initialisation of its own.
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = AutoModelForVision2Seq.from_pretrained(model_name)

        self.model.to(self.device)
        self.model.eval()

    @torch.no_grad()
    def predict(self, image: Image.Image) -> str:
        """Recognise the text in ``image``.

        Args:
            image: The PIL image to transcribe. Images that are not already
                in ``"RGB"`` mode are converted first.

        Returns:
            The decoded text with special tokens removed and surrounding
            whitespace stripped.
        """
        # Grayscale scans, RGBA PNGs and palette images do not provide the
        # three channels the image processor is fed in the Hugging Face
        # TrOCR usage examples, so normalise the mode up front.
        if image.mode != "RGB":
            image = image.convert("RGB")

        pixel_values = self.processor(
            images=image,
            return_tensors="pt",
        ).pixel_values.to(self.device)

        generated_ids = self.model.generate(pixel_values)
        # A single image was submitted, so the batch holds one transcription.
        text = self.processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
        )[0]

        return text.strip()