"""Gradio demo: OCR on word images with a dynamically quantized TrOCR model (CPU)."""

import warnings
from typing import Optional

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Device configuration
device = "cpu"

# Load processor (used here for decoding token IDs back to text only)
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-printed")

# Load the base architecture, then apply dynamic int8 quantization to the
# Linear layers so the module structure matches the quantized checkpoint.
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-printed")
model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

# NOTE(security): torch.load unpickles the file; only load checkpoints that
# come from a trusted source.
_state_dict = torch.load("best_model_int8.pt", map_location="cpu")

# strict=False tolerates key mismatches, but a mismatch means some weights
# silently stay at their pre-trained values. Surface that instead of hiding it.
_load_result = model.load_state_dict(_state_dict, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    warnings.warn(
        "Checkpoint did not match the model exactly: "
        f"{len(_load_result.missing_keys)} missing key(s), "
        f"{len(_load_result.unexpected_keys)} unexpected key(s).",
        stacklevel=1,
    )

model.to(device)
model.eval()

# Inference-time preprocessing. According to the original author this must
# match the training pipeline's resize method (LANCZOS interpolation).
# NOTE(review): the image never goes through `processor`, so whatever
# preprocessing the processor would apply is skipped here -- confirm that
# training also fed plain [0, 1] tensors to the model.
inference_transform = transforms.Compose([
    # 1. Resize to 384x384 with LANCZOS interpolation
    transforms.Resize((384, 384), interpolation=InterpolationMode.LANCZOS),
    # 2. Convert to tensor (range [0, 1])
    transforms.ToTensor(),
])


def predict(img: Optional[Image.Image]) -> str:
    """Run OCR on a single image and return the recognized text.

    Args:
        img: PIL image supplied by the Gradio image input. It is converted
            to RGB if it is in any other mode. ``None`` (no image provided)
            is accepted.

    Returns:
        The decoded text of the first generated sequence, or an empty
        string when ``img`` is ``None``.
    """
    # Gradio passes None when the user submits without uploading an image;
    # without this guard the `.mode` access below raises AttributeError.
    if img is None:
        return ""

    # Step 1: Ensure the image is in RGB mode
    if img.mode != "RGB":
        img = img.convert("RGB")

    # Step 2: Resize + convert to a [C, H, W] tensor in range [0, 1]
    pixel_values = inference_transform(img)

    # Step 3: Add batch dimension -> [1, C, H, W] and move to the device
    pixel_values = pixel_values.unsqueeze(0).to(device)

    # Step 4: Run generation without tracking gradients
    with torch.no_grad():
        generated_ids = model.generate(pixel_values)

    # Step 5: Decode the generated token IDs; the batch holds one image
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


# Create the Gradio interface and start the server (public share link enabled)
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload word image"),
    outputs=gr.Textbox(label="Recognized Text"),
    title="TrOCR OCR (CPU Optimized)",
    description="Fine-tuned TrOCR on IIIT-5K | CPU inference",
)
demo.launch(share=True)