Ashu44268 committed on
Commit
586ef47
·
verified ·
1 Parent(s): 680e062

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -0
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
4
+ from PIL import Image
5
+ from torchvision import transforms
6
+ from torchvision.transforms import InterpolationMode
7
# Device configuration: this Space targets CPU-only inference, which is also
# why the model is dynamically quantized below.
device = "cpu"

# Load processor (for text tokenization/decoding only) — image preprocessing
# is done manually with torchvision transforms, not with the processor.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-printed")

# Load the base fp32 architecture, then apply dynamic int8 quantization to all
# Linear layers. Quantization must happen BEFORE load_state_dict, because the
# checkpoint ("best_model_int8.pt") holds the quantized parameter layout.
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-printed")
model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
# NOTE(review): torch.load uses pickle under the hood — only load checkpoints
# from trusted sources. strict=False tolerates keys that differ between the
# quantized module layout and the saved state dict.
model.load_state_dict(torch.load("best_model_int8.pt", map_location="cpu"), strict=False)
model.to(device)
# Inference mode: disable dropout / switch norm layers to eval statistics.
model.eval()
22
# Preprocessing pipeline for inference — must mirror the training pipeline
# exactly. LANCZOS resampling is used because it keeps thin character strokes
# sharp at the 384x384 model input resolution.
_preprocess_steps = [
    # Sharp resize to the model's expected input size (same as training).
    transforms.Resize((384, 384), interpolation=InterpolationMode.LANCZOS),
    # PIL image -> float tensor in [0, 1], shape [C, H, W].
    transforms.ToTensor(),
]
inference_transform = transforms.Compose(_preprocess_steps)
31
+
32
def predict(img: Image.Image):
    """Run OCR on a single word image.

    Applies the same preprocessing used during training (LANCZOS resize to
    384x384, then ToTensor), runs the quantized TrOCR model, and decodes the
    generated token ids back to text.

    Args:
        img: Input PIL image; converted to RGB if it is not already.

    Returns:
        The recognized text string.
    """
    # Normalize color mode to RGB, matching the training data.
    rgb = img if img.mode == 'RGB' else img.convert('RGB')

    # Training-matched transform -> [C, H, W] in [0, 1], then add the batch
    # dimension ([1, C, H, W]) and move to the inference device.
    batch = inference_transform(rgb).unsqueeze(0).to(device)

    # Autoregressive generation; no gradients needed at inference time.
    with torch.no_grad():
        token_ids = model.generate(batch)

    # Decode token ids to a plain string (single-element batch).
    return processor.batch_decode(token_ids, skip_special_tokens=True)[0]
62
# Build and launch the Gradio demo: one image in, recognized text out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload word image"),
    outputs=gr.Textbox(label="Recognized Text"),
    title="TrOCR OCR (CPU Optimized)",
    description="Fine-tuned TrOCR on IIIT-5K | CPU inference",
)
# share=True exposes a public link in addition to the local server.
demo.launch(share=True)