|
|
import warnings

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
|
|
|
|
|
|
|
# Inference is pinned to CPU; the model below is dynamically quantized to int8.
device = "cpu"

# Hugging Face model id used for both the processor and the base model, so the
# two can never drift apart.
MODEL_ID = "microsoft/trocr-small-printed"

# Local fine-tuned int8 checkpoint loaded on top of the quantized base model.
CHECKPOINT_PATH = "best_model_int8.pt"

processor = TrOCRProcessor.from_pretrained(MODEL_ID)

# Rebuild the quantized architecture first: every nn.Linear is replaced by a
# dynamically quantized int8 version. Presumably the checkpoint was saved from a
# model quantized the same way -- its keys must line up with this structure.
model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID)
model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

# strict=False tolerates key mismatches, but ignoring them silently would leave
# the affected weights at their base-model values with no indication. Report
# any mismatch instead of failing, so the app still starts.
_load_result = model.load_state_dict(
    torch.load(CHECKPOINT_PATH, map_location="cpu"), strict=False
)
if _load_result.missing_keys or _load_result.unexpected_keys:
    warnings.warn(
        f"{CHECKPOINT_PATH}: {len(_load_result.missing_keys)} missing and "
        f"{len(_load_result.unexpected_keys)} unexpected state-dict keys; "
        "missing weights keep their base-model values and unexpected ones "
        "are ignored."
    )

model.to(device)
model.eval()
|
|
|
|
|
|
|
|
|
|
|
# Preprocessing applied to every uploaded image before inference:
#   1. resize to a fixed 384x384 using Lanczos resampling;
#   2. convert the PIL image to a float tensor (ToTensor scales 8-bit pixels
#      to [0, 1]).
# NOTE(review): there is no Normalize step here. Presumably this mirrors the
# fine-tuning pipeline exactly -- confirm before adding one.
inference_transform = transforms.Compose(
    [
        transforms.Resize(
            size=(384, 384),
            interpolation=InterpolationMode.LANCZOS,
        ),
        transforms.ToTensor(),
    ]
)
|
|
|
|
|
def predict(img: "Image.Image | None") -> str:
    """Run OCR on a single word image and return the recognized text.

    The image is converted to RGB if needed and passed through
    ``inference_transform`` (intended to match the training preprocessing).
    It is then batched to size 1 and decoded with ``model.generate`` using the
    model's default generation settings.

    Args:
        img: The uploaded image in any PIL mode, or ``None``. Gradio passes
            ``None`` when the user submits without uploading an image.

    Returns:
        The recognized text, or an empty string when no image was provided.
    """
    if img is None:
        # Nothing uploaded: return early instead of failing on ``img.mode``.
        return ""

    if img.mode != "RGB":
        img = img.convert("RGB")

    # Add a leading batch dimension of 1 and move to the inference device.
    pixel_values = inference_transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        generated_ids = model.generate(pixel_values)

    # batch_decode yields one string per batch item; the batch has exactly one.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
|
|
|
|
|
|
|
# Single-input / single-output web UI around ``predict``: the image component
# hands ``predict`` a PIL image, and the returned string is shown in a textbox.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload word image"),
    outputs=gr.Textbox(label="Recognized Text"),
    title="TrOCR OCR (CPU Optimized)",
    description="Fine-tuned TrOCR on IIIT-5K | CPU inference",
)

# share=True asks Gradio to also create a public share link for the app.
demo.launch(share=True)