reader / app.py
Shivam3002's picture
Update app.py
af5ce57 verified
import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image, ImageEnhance, ImageOps
import torch
import re
import easyocr
from io import BytesIO
import numpy as np
# OCR engines are built once at import time; the OCR helper functions below
# use these module-level objects on every call.

# EasyOCR reader (English only).
reader = easyocr.Reader(['en'])

# TrOCR processor and model are loaded from the same pretrained checkpoint.
_TROCR_CHECKPOINT = 'microsoft/trocr-small-printed'
processor = TrOCRProcessor.from_pretrained(_TROCR_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_TROCR_CHECKPOINT)
# Put the model into evaluation mode.
model.eval()
def enhance_image(image):
    """Boost contrast and sharpness of ``image`` and convert it to grayscale.

    Both enhancers are applied with a factor of 2.0: contrast first, then
    sharpness. The grayscale conversion happens last and its result is
    returned.
    """
    boost = 2.0
    contrasted = ImageEnhance.Contrast(image).enhance(boost)
    sharpened = ImageEnhance.Sharpness(contrasted).enhance(boost)
    return ImageOps.grayscale(sharpened)
def ocr_with_easyocr(pil_img):
    """Run the module-level EasyOCR ``reader`` on a PIL image.

    The image is serialised to in-memory PNG bytes, which are handed to
    ``reader.readtext`` with ``detail=0``. Whatever that call returns is
    returned unchanged.
    """
    with BytesIO() as png_buffer:
        pil_img.save(png_buffer, format='PNG')
        png_bytes = png_buffer.getvalue()
    return reader.readtext(png_bytes, detail=0)
def ocr_with_trocr(pil_img):
    """Recognise text in a PIL image with the module-level TrOCR model.

    The image is converted to RGB, encoded by ``processor``, and decoded from
    the ids produced by ``model.generate`` (run without gradient tracking).
    Returns the first string of the decoded batch.
    """
    rgb_image = pil_img.convert("RGB")
    encoded = processor(images=rgb_image, return_tensors="pt")
    pixel_values = encoded.pixel_values
    with torch.no_grad():
        generated_ids = model.generate(pixel_values)
    decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)
    return decoded[0]
# Fractions of the image height that bound each search region. The band
# between 40% and 50% of the height is not inspected at all.
_READING_REGION_END = 0.4    # reading is searched in the top 40%
_SERIAL_REGION_START = 0.5   # serial is searched in the bottom 50%

# A number immediately followed by a kWh unit. The number may contain '.' or
# ',' separated groups so that "12345.6 kWh" yields "12345.6" instead of just
# the trailing "6". Whitespace inside the unit is tolerated ("kw h", "k w h").
_READING_WITH_UNIT_RE = re.compile(
    r"(\d+(?:[.,]\d+)*)\s*k\s*w\s*h", re.IGNORECASE
)
# Used when no unit is found: any standalone 4-7 digit number.
_READING_FALLBACK_RE = re.compile(r"\b\d{4,7}\b")
# Serial candidates: runs of 6-12 digits.
_SERIAL_RE = re.compile(r"\d{6,12}")


def _ocr_region(region):
    """Enhance ``region`` and return all OCR output joined by single spaces."""
    enhanced = enhance_image(region)
    texts = ocr_with_easyocr(enhanced) + [ocr_with_trocr(enhanced)]
    return " ".join(texts)


def _parse_reading(text):
    """Return the meter reading found in ``text``, or "Not Found"."""
    with_unit = _READING_WITH_UNIT_RE.search(text)
    if with_unit:
        return with_unit.group(1)
    candidates = _READING_FALLBACK_RE.findall(text)
    return candidates[0] if candidates else "Not Found"


def _parse_serial(text):
    """Return the longest 6-12 digit run in ``text``, or "Not Found"."""
    candidates = _SERIAL_RE.findall(text)
    return max(candidates, key=len) if candidates else "Not Found"


def extract_meter_reading(image):
    """Extract the serial number and meter reading from a meter photo.

    The top 40% of the image is searched for the reading and the bottom 50%
    for the serial number. Each region is enhanced and passed through both
    EasyOCR and TrOCR; the combined text is then parsed with regular
    expressions.

    Args:
        image: Image as an array accepted by ``PIL.Image.fromarray``, or
            ``None`` when nothing was supplied.

    Returns:
        A two-line string ``"Serial Number: ...\\nMeter Reading: ..."`` where
        a missing field is reported as ``"Not Found"``, or a string starting
        with ``"Error: "`` if no image was given or processing failed.
    """
    # Without this guard a missing image surfaces as an opaque internal
    # exception message instead of something the user can act on.
    if image is None:
        return "Error: no image provided"

    try:
        pil = Image.fromarray(image)
        width, height = pil.size

        top_region = pil.crop((0, 0, width, int(height * _READING_REGION_END)))
        bottom_region = pil.crop(
            (0, int(height * _SERIAL_REGION_START), width, height)
        )

        reading = _parse_reading(_ocr_region(top_region))
        serial = _parse_serial(_ocr_region(bottom_region))

        return f"Serial Number: {serial}\nMeter Reading: {reading}"
    except Exception as e:
        # This is the UI boundary: report the failure as text rather than
        # letting the exception propagate out of the handler.
        return f"Error: {str(e)}"
# Gradio app
def main():
    """Build the Gradio interface around ``extract_meter_reading`` and launch it."""
    image_input = gr.Image(type="numpy", label="Upload or Capture Meter Image")
    app = gr.Interface(
        fn=extract_meter_reading,
        inputs=image_input,
        outputs="text",
        title="Meter Reading and Serial Number Extractor",
        description="Upload a meter image; extracts serial number and meter reading using region-based OCR.",
    )
    app.launch()
# Start the app only when this file is executed as a script.
if __name__ == '__main__':
    main()