from __future__ import annotations import io import os import time from functools import lru_cache from typing import Optional import cv2 import gradio as gr import numpy as np import torch from peft import PeftModel from PIL import Image from transformers import AutoTokenizer, TrOCRProcessor, ViTImageProcessor, VisionEncoderDecoderModel BASE_MODEL = os.getenv("TROCR_BASE_MODEL", "paudelanil/trocr-devanagari-2") ADAPTER_MODEL = os.getenv("TROCR_ADAPTER_MODEL", "waglesameer5/devgen-trocr-devanagari-lora") FALLBACK_IMAGE_PROCESSOR = os.getenv("TROCR_FALLBACK_IMAGE_PROCESSOR", "google/vit-base-patch16-224-in21k") def get_device() -> str: return "cuda" if torch.cuda.is_available() else "cpu" def image_to_bytes(image: Image.Image) -> bytes: buffer = io.BytesIO() image.convert("RGB").save(buffer, format="PNG") return buffer.getvalue() def bytes_to_cv2(image_bytes: bytes) -> np.ndarray: nparr = np.frombuffer(image_bytes, np.uint8) return cv2.imdecode(nparr, cv2.IMREAD_COLOR) def cv2_to_pil(img: np.ndarray) -> Image.Image: rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) return Image.fromarray(rgb) def crop_to_foreground(img: np.ndarray, padding_ratio: float = 0.18) -> np.ndarray: gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if len(img.shape) == 3 else img blurred = cv2.GaussianBlur(gray, (5, 5), 0) _, mask = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU) kernel = np.ones((3, 3), np.uint8) mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel) mask = cv2.dilate(mask, kernel, iterations=1) contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) if not contours: return img h, w = gray.shape[:2] min_area = max(12, int(h * w * 0.0001)) boxes = [cv2.boundingRect(contour) for contour in contours if cv2.contourArea(contour) >= min_area] if not boxes: return img x1 = min(x for x, _, _, _ in boxes) y1 = min(y for _, y, _, _ in boxes) x2 = max(x + bw for x, _, bw, _ in boxes) y2 = max(y + bh for _, y, _, bh in boxes) pad_x = max(8, int((x2 - x1) * padding_ratio)) pad_y = max(8, int((y2 - y1) * padding_ratio)) return img[max(0, y1 - pad_y):min(h, y2 + pad_y), max(0, x1 - pad_x):min(w, x2 + pad_x)] def normalize_for_model(img: np.ndarray, target_height: int = 384, target_width: int = 384) -> np.ndarray: h, w = img.shape[:2] scale = min(target_height / h, target_width / w) new_h = max(1, int(h * scale)) new_w = max(1, int(w * scale)) resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA) if len(img.shape) == 3: canvas = np.ones((target_height, target_width, 3), dtype=np.uint8) * 255 else: canvas = np.ones((target_height, target_width), dtype=np.uint8) * 255 y_offset = (target_height - new_h) // 2 x_offset = (target_width - new_w) // 2 canvas[y_offset:y_offset + new_h, x_offset:x_offset + new_w] = resized return canvas def preprocess(image: Image.Image, mode: str) -> Image.Image: image = image.convert("RGB") if mode == "Original": return image img = bytes_to_cv2(image_to_bytes(image)) if mode == "Foreground crop": return cv2_to_pil(crop_to_foreground(img)) if mode == "Square pad": return cv2_to_pil(normalize_for_model(img)) if mode == "Crop + square pad": return cv2_to_pil(normalize_for_model(crop_to_foreground(img))) return image def load_processor() -> TrOCRProcessor: try: return TrOCRProcessor.from_pretrained(BASE_MODEL) except Exception: try: image_processor = ViTImageProcessor.from_pretrained(ADAPTER_MODEL) except Exception: image_processor = ViTImageProcessor.from_pretrained(FALLBACK_IMAGE_PROCESSOR) tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL) return TrOCRProcessor(image_processor=image_processor, tokenizer=tokenizer) @lru_cache(maxsize=1) def load_model() -> tuple[VisionEncoderDecoderModel, TrOCRProcessor, str]: device = get_device() processor = load_processor() base_model = VisionEncoderDecoderModel.from_pretrained(BASE_MODEL) base_model.config.decoder_start_token_id = processor.tokenizer.cls_token_id base_model.config.pad_token_id = processor.tokenizer.pad_token_id base_model.config.eos_token_id = processor.tokenizer.sep_token_id base_model.config.vocab_size = base_model.config.decoder.vocab_size peft_model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL) try: model = peft_model.merge_and_unload() except Exception: model = peft_model model.to(device) model.eval() return model, processor, device def recognize(image: Optional[Image.Image], preprocessing: str, max_length: int) -> tuple[str, Image.Image | None, dict]: if image is None: return "", None, {"error": "Upload an image first."} processed = preprocess(image, preprocessing) model, processor, device = load_model() started_at = time.perf_counter() pixel_values = processor(images=processed.convert("RGB"), return_tensors="pt").pixel_values.to(device) with torch.inference_mode(): outputs = model.generate( pixel_values, max_length=max_length, num_beams=4, return_dict_in_generate=True, output_scores=True, ) text = processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0].strip() elapsed_ms = round((time.perf_counter() - started_at) * 1000, 2) details = { "base_model": BASE_MODEL, "adapter": ADAPTER_MODEL, "device": device, "preprocessing": preprocessing, "inference_ms": elapsed_ms, } return text, processed, details CSS = """ .gradio-container { max-width: 1120px !important; } #result_text textarea { font-size: 1.35rem; line-height: 1.8; } """ with gr.Blocks(title="DevGen Devanagari OCR") as demo: gr.Markdown( """ # DevGen Devanagari OCR Upload a Devanagari word or short line image. The demo runs a TrOCR base model with the DevGen LoRA adapter hosted on Hugging Face. """ ) with gr.Row(): with gr.Column(scale=1): image_input = gr.Image(type="pil", label="Image") preprocessing_input = gr.Radio( ["Foreground crop", "Original", "Square pad", "Crop + square pad"], value="Foreground crop", label="Preprocessing", ) max_length_input = gr.Slider(16, 128, value=64, step=1, label="Max output length") submit = gr.Button("Recognize", variant="primary") with gr.Column(scale=1): text_output = gr.Textbox(label="Recognized text", lines=4, elem_id="result_text") processed_output = gr.Image(type="pil", label="Processed image") details_output = gr.JSON(label="Run details") submit.click( fn=recognize, inputs=[image_input, preprocessing_input, max_length_input], outputs=[text_output, processed_output, details_output], ) if __name__ == "__main__": demo.queue(max_size=8).launch(css=CSS, ssr_mode=False)