"""Gradio app that reads a meter value and serial number from a meter photo.

The uploaded image is split into a top and a bottom region. Each region is
contrast/sharpness-enhanced, run through both EasyOCR and TrOCR, and the
combined text is searched with regular expressions for the reading (top
region) and the serial number (bottom region).
"""

import logging
import re
from io import BytesIO

import easyocr
import gradio as gr
import numpy as np  # noqa: F401  (kept: imported by the original file)
import torch
from PIL import Image, ImageEnhance, ImageOps
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

logger = logging.getLogger(__name__)

# Initialize EasyOCR Reader
reader = easyocr.Reader(['en'])

# Load the pre-trained TrOCR model and processor. TrOCR is run on every
# region in addition to EasyOCR; its output is appended to EasyOCR's.
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-small-printed')
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-small-printed')
model.eval()

# Fraction of the image height (measured from the top) that ends the
# "reading" region, and the fraction at which the "serial" region starts.
# The serial region therefore covers the bottom 50% of the image.
TOP_REGION_END = 0.4
BOTTOM_REGION_START = 0.5

NOT_FOUND = "Not Found"

# Digits (optionally with a decimal part) immediately followed by a kWh unit.
# The optional fraction keeps "12345.6 kWh" from being read as just "6".
_READING_WITH_UNIT = re.compile(
    r"(\d+(?:\.\d+)?)\s*(?=kwh|kw h|k w h)", re.IGNORECASE
)
# Fallback when no unit is found: any standalone 4-7 digit number.
_READING_FALLBACK = re.compile(r"\b\d{4,7}\b")
# Serial candidates: runs of 6-12 digits.
_SERIAL = re.compile(r"\d{6,12}")


def enhance_image(image):
    """Return *image* with contrast and sharpness doubled, in grayscale.

    Args:
        image: A PIL image.

    Returns:
        The enhanced PIL image converted to grayscale.
    """
    image = ImageEnhance.Contrast(image).enhance(2.0)
    image = ImageEnhance.Sharpness(image).enhance(2.0)
    image = ImageOps.grayscale(image)
    return image


def ocr_with_easyocr(pil_img):
    """Run EasyOCR on a PIL image and return the recognized text fragments.

    The image is PNG-encoded in memory and handed to the reader as bytes;
    ``detail=0`` asks EasyOCR for text only.
    """
    buf = BytesIO()
    pil_img.save(buf, format='PNG')
    return reader.readtext(buf.getvalue(), detail=0)


def ocr_with_trocr(pil_img):
    """Run TrOCR on a PIL image and return the decoded text as one string."""
    # The processor is given an RGB image, so convert (the enhanced regions
    # are grayscale).
    pixel_values = processor(
        images=pil_img.convert("RGB"), return_tensors="pt"
    ).pixel_values
    with torch.no_grad():
        ids = model.generate(pixel_values)
    return processor.batch_decode(ids, skip_special_tokens=True)[0]


def _ocr_region(region):
    """Enhance *region*, OCR it with both engines, and join the text."""
    enhanced = enhance_image(region)
    texts = ocr_with_easyocr(enhanced) + [ocr_with_trocr(enhanced)]
    return " ".join(texts)


def _parse_reading(text):
    """Return the meter reading found in *text*, or ``NOT_FOUND``.

    Prefers a number directly followed by a kWh unit; otherwise falls back
    to the first standalone 4-7 digit number.
    """
    match = _READING_WITH_UNIT.search(text)
    if match:
        return match.group(1)
    candidates = _READING_FALLBACK.findall(text)
    return candidates[0] if candidates else NOT_FOUND


def _parse_serial(text):
    """Return the longest 6-12 digit run in *text*, or ``NOT_FOUND``."""
    serials = _SERIAL.findall(text)
    return max(serials, key=len) if serials else NOT_FOUND


def extract_meter_reading(image):
    """Extract the serial number and meter reading from a meter image.

    Args:
        image: The image as a NumPy array (as delivered by
            ``gr.Image(type="numpy")``), or ``None`` when nothing was
            uploaded.

    Returns:
        A two-line string ``"Serial Number: ...\\nMeter Reading: ..."``, with
        ``"Not Found"`` for any value that could not be located, or a string
        starting with ``"Error: "`` if processing failed.
    """
    if image is None:
        # Submitting without an image would otherwise surface as an opaque
        # PIL conversion error.
        return "Error: no image provided"

    try:
        pil = Image.fromarray(image)
        w, h = pil.size

        # Reading is searched in the top 40% of the image, the serial
        # number in the bottom 50%.
        top_region = pil.crop((0, 0, w, int(h * TOP_REGION_END)))
        bottom_region = pil.crop((0, int(h * BOTTOM_REGION_START), w, h))

        top_combined = _ocr_region(top_region)
        bot_combined = _ocr_region(bottom_region)

        reading = _parse_reading(top_combined)
        serial = _parse_serial(bot_combined)
        return f"Serial Number: {serial}\nMeter Reading: {reading}"
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing
        # the handler, but keep the traceback in the logs.
        logger.exception("Meter extraction failed")
        return f"Error: {str(e)}"


# Gradio app
def main():
    """Build the Gradio interface and launch it."""
    iface = gr.Interface(
        fn=extract_meter_reading,
        inputs=gr.Image(type="numpy", label="Upload or Capture Meter Image"),
        outputs="text",
        title="Meter Reading and Serial Number Extractor",
        description="Upload a meter image; extracts serial number and meter reading using region-based OCR."
    )
    iface.launch()


if __name__ == '__main__':
    main()