Spaces:
Sleeping
Sleeping
import logging
import re
from io import BytesIO

import easyocr
import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageEnhance, ImageOps
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# EasyOCR reader configured for English text.
reader = easyocr.Reader(['en'])

# Checkpoint name shared by the TrOCR processor and model below.
_TROCR_CHECKPOINT = 'microsoft/trocr-small-printed'

# Pre-trained TrOCR processor/model pair, used together with EasyOCR.
processor = TrOCRProcessor.from_pretrained(_TROCR_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_TROCR_CHECKPOINT)
model.eval()  # switch the model to evaluation mode
def enhance_image(image):
    """Return a contrast-boosted, sharpened, grayscale version of *image*.

    Contrast and sharpness are each enhanced with a factor of 2.0 (in that
    order) before the result is converted to grayscale.
    """
    contrasted = ImageEnhance.Contrast(image).enhance(2.0)
    sharpened = ImageEnhance.Sharpness(contrasted).enhance(2.0)
    return ImageOps.grayscale(sharpened)
def ocr_with_easyocr(pil_img):
    """Run EasyOCR on a PIL image and return what ``readtext`` produces.

    The image is PNG-encoded in memory and passed to the module-level
    ``reader`` as raw bytes; ``detail=0`` requests the text-only output.
    """
    with BytesIO() as png_buffer:
        pil_img.save(png_buffer, format='PNG')
        png_bytes = png_buffer.getvalue()
    return reader.readtext(png_bytes, detail=0)
def ocr_with_trocr(pil_img):
    """Recognize text in a PIL image with TrOCR and return the decoded string.

    The image is converted to RGB, preprocessed by the module-level
    ``processor`` and decoded by ``model`` with gradient tracking disabled.
    """
    rgb_img = pil_img.convert("RGB")
    encoded = processor(images=rgb_img, return_tensors="pt")
    with torch.no_grad():
        generated_ids = model.generate(encoded.pixel_values)
    decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)
    return decoded[0]
def _ocr_combined(enhanced):
    """OCR an enhanced region with both engines and join the texts with spaces."""
    return " ".join(ocr_with_easyocr(enhanced) + [ocr_with_trocr(enhanced)])


def _parse_reading(text):
    """Return the meter reading found in *text*, or ``"Not Found"``."""
    # Prefer a number directly followed by a kWh unit. The optional
    # fractional part keeps "12345.6 kWh" from being reported as just "6".
    unit_match = re.search(
        r"(\d+(?:\.\d+)?)\s*(?=kwh|kw h|k w h)", text, re.IGNORECASE
    )
    if unit_match:
        return unit_match.group(1)
    # Fallback: the first standalone 4-7 digit number.
    candidates = re.findall(r"\b\d{4,7}\b", text)
    return candidates[0] if candidates else "Not Found"


def _parse_serial(text):
    """Return the longest 6-12 digit run in *text*, or ``"Not Found"``."""
    digit_runs = re.findall(r"\d{6,12}", text)
    return max(digit_runs, key=len) if digit_runs else "Not Found"


def extract_meter_reading(image):
    """Extract the serial number and meter reading from a meter photo.

    Args:
        image: Image as a numpy array (the Gradio input is declared with
            ``type="numpy"``), or ``None`` when nothing was uploaded.

    Returns:
        ``"Serial Number: <serial>\\nMeter Reading: <reading>"`` on success,
        where either value may be ``"Not Found"``; otherwise a string that
        starts with ``"Error: "`` describing what went wrong.
    """
    if image is None:
        # Submitting without an image would otherwise surface as a cryptic
        # attribute error from the array conversion.
        return "Error: no image provided"

    try:
        pil = Image.fromarray(image)
        w, h = pil.size
        # Regions: the reading is searched in the top 40% of the image,
        # the serial number in the bottom 50%.
        top_region = pil.crop((0, 0, w, int(h * 0.4)))
        bottom_region = pil.crop((0, int(h * 0.5), w, h))
        # Enhance both regions before running OCR on them.
        top_enh = enhance_image(top_region)
        bot_enh = enhance_image(bottom_region)
        top_combined = _ocr_combined(top_enh)
        bot_combined = _ocr_combined(bot_enh)
    except Exception as e:
        # UI boundary: report the failure as text, but keep the traceback
        # in the log instead of discarding it.
        logging.getLogger(__name__).exception("Meter OCR failed")
        return f"Error: {e}"

    reading = _parse_reading(top_combined)
    serial = _parse_serial(bot_combined)
    return f"Serial Number: {serial}\nMeter Reading: {reading}"
# Gradio application entry point
def main():
    """Build the Gradio interface around ``extract_meter_reading`` and launch it."""
    image_input = gr.Image(type="numpy", label="Upload or Capture Meter Image")
    demo = gr.Interface(
        fn=extract_meter_reading,
        inputs=image_input,
        outputs="text",
        title="Meter Reading and Serial Number Extractor",
        description="Upload a meter image; extracts serial number and meter reading using region-based OCR.",
    )
    demo.launch()


if __name__ == "__main__":
    main()