Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import io | |
| import os | |
| import time | |
| from functools import lru_cache | |
| from typing import Optional | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from peft import PeftModel | |
| from PIL import Image | |
| from transformers import AutoTokenizer, TrOCRProcessor, ViTImageProcessor, VisionEncoderDecoderModel | |
# Model repo ids handed to the various ``from_pretrained`` loaders below; each can be
# overridden through the environment variable named in the first argument.
BASE_MODEL = os.environ.get("TROCR_BASE_MODEL", "paudelanil/trocr-devanagari-2")  # VisionEncoderDecoder checkpoint + tokenizer
ADAPTER_MODEL = os.environ.get("TROCR_ADAPTER_MODEL", "waglesameer5/devgen-trocr-devanagari-lora")  # PEFT/LoRA adapter
FALLBACK_IMAGE_PROCESSOR = os.environ.get("TROCR_FALLBACK_IMAGE_PROCESSOR", "google/vit-base-patch16-224-in21k")  # last-resort ViT image processor
def get_device() -> str:
    """Return the torch device string to run on: ``"cuda"`` if a GPU is visible, otherwise ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
def image_to_bytes(image: Image.Image) -> bytes:
    """Encode *image* as an RGB PNG and return the encoded file contents.

    Args:
        image: Any PIL image; it is converted to RGB before encoding.

    Returns:
        The PNG-encoded bytes.
    """
    rgb_image = image.convert("RGB")
    with io.BytesIO() as stream:
        rgb_image.save(stream, format="PNG")
        # Read the buffer before the context manager closes it.
        return stream.getvalue()
def bytes_to_cv2(image_bytes: bytes) -> np.ndarray:
    """Decode encoded image bytes (PNG, JPEG, ...) into an OpenCV BGR array.

    Args:
        image_bytes: Contents of an encoded image file.

    Returns:
        The decoded image as returned by ``cv2.imdecode`` with ``cv2.IMREAD_COLOR``
        (3-channel BGR).

    Raises:
        ValueError: If *image_bytes* is empty or cannot be decoded as an image.
    """
    if not image_bytes:
        # Don't depend on imdecode's version-dependent handling of an empty buffer.
        raise ValueError("Cannot decode an empty image buffer.")
    decoded = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
    if decoded is None:
        # imdecode reports corrupt/unsupported data by returning None instead of raising;
        # surface it here rather than as an obscure error in a later cv2 call.
        raise ValueError("Could not decode image bytes.")
    return decoded
def cv2_to_pil(img: np.ndarray) -> Image.Image:
    """Convert an OpenCV BGR array into a PIL image with RGB channel order."""
    return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
def crop_to_foreground(img: np.ndarray, padding_ratio: float = 0.18) -> np.ndarray:
    """Crop *img* to the region containing its dark foreground, plus padding.

    Foreground detection:
      1. grayscale (if *img* is 3-D) and a 5x5 Gaussian blur;
      2. inverted Otsu threshold, so pixels at or below the threshold become 255;
      3. a 3x3 morphological open followed by one 3x3 dilation;
      4. external contours whose area is at least ``max(12, 0.01% of the image area)``.
    The union of those contours' bounding boxes is grown on each side by
    ``max(8, padding_ratio * box size)`` pixels and clamped to the image bounds.

    Args:
        img: OpenCV image, 3-channel BGR or 2-D grayscale.
        padding_ratio: Per-side padding as a fraction of the foreground box's width/height.

    Returns:
        A slice (view) of *img*, or *img* itself when no qualifying foreground is found.
    """
    gray = img if len(img.shape) != 3 else cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    smoothed = cv2.GaussianBlur(gray, (5, 5), 0)
    _, fg_mask = cv2.threshold(smoothed, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    struct_elem = np.ones((3, 3), np.uint8)
    fg_mask = cv2.dilate(cv2.morphologyEx(fg_mask, cv2.MORPH_OPEN, struct_elem), struct_elem, iterations=1)
    contours, _ = cv2.findContours(fg_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    img_h, img_w = gray.shape[:2]
    area_floor = max(12, int(img_h * img_w * 0.0001))
    rects = [cv2.boundingRect(c) for c in contours if cv2.contourArea(c) >= area_floor]
    if not rects:
        # Either no contours at all, or only specks below the area floor.
        return img

    left = min(rx for rx, _, _, _ in rects)
    top = min(ry for _, ry, _, _ in rects)
    right = max(rx + rw for rx, _, rw, _ in rects)
    bottom = max(ry + rh for _, ry, _, rh in rects)

    margin_x = max(8, int((right - left) * padding_ratio))
    margin_y = max(8, int((bottom - top) * padding_ratio))
    rows = slice(max(0, top - margin_y), min(img_h, bottom + margin_y))
    cols = slice(max(0, left - margin_x), min(img_w, right + margin_x))
    return img[rows, cols]
def normalize_for_model(img: np.ndarray, target_height: int = 384, target_width: int = 384) -> np.ndarray:
    """Letterbox *img* onto a white ``target_height x target_width`` canvas.

    The image is scaled (up or down) by one factor so it fits inside the target size
    with its aspect ratio preserved, then centred on a white ``uint8`` canvas. The
    canvas has the same channel layout as the input: ``(h, w)`` inputs give a 2-D
    canvas and ``(h, w, c)`` inputs give a ``c``-channel canvas.

    Args:
        img: Image array in OpenCV layout, ``(h, w)`` or ``(h, w, c)``.
        target_height: Output height in pixels.
        target_width: Output width in pixels.

    Returns:
        A ``uint8`` array of shape ``(target_height, target_width, *img.shape[2:])``.

    Raises:
        ValueError: If *img* has zero height or width.
    """
    h, w = img.shape[:2]
    if h == 0 or w == 0:
        # Would otherwise fail with ZeroDivisionError in the scale computation.
        raise ValueError(f"Cannot normalize an empty image of shape {img.shape}.")
    scale = min(target_height / h, target_width / w)
    new_h = max(1, int(h * scale))
    new_w = max(1, int(w * scale))
    # NOTE(review): INTER_AREA is also used when enlarging, where OpenCV recommends
    # INTER_LINEAR/INTER_CUBIC. Kept unchanged in case the adapter was trained on images
    # produced by this exact pipeline — confirm before switching.
    resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
    channel_shape = img.shape[2:]  # () for grayscale, (c,) for multi-channel input
    canvas = np.full((target_height, target_width, *channel_shape), 255, dtype=np.uint8)
    y_offset = (target_height - new_h) // 2
    x_offset = (target_width - new_w) // 2
    # cv2.resize returns (h, w) for single-channel data even when given (h, w, 1),
    # so restore the input's channel layout before pasting.
    canvas[y_offset:y_offset + new_h, x_offset:x_offset + new_w] = resized.reshape(new_h, new_w, *channel_shape)
    return canvas
def preprocess(image: Image.Image, mode: str) -> Image.Image:
    """Apply the UI-selected preprocessing *mode* to *image*; always returns an RGB PIL image.

    Modes:
        ``"Original"``: the image converted to RGB, otherwise untouched.
        ``"Foreground crop"``: cropped to the dark foreground (``crop_to_foreground``).
        ``"Square pad"``: letterboxed onto a white 384x384 canvas (``normalize_for_model``).
        ``"Crop + square pad"``: foreground crop followed by the square pad.
    Any other *mode* falls back to the RGB-converted image.
    """
    rgb = image.convert("RGB")
    if mode not in ("Foreground crop", "Square pad", "Crop + square pad"):
        # "Original" and unknown modes need no OpenCV work.
        return rgb
    # Convert PIL RGB -> OpenCV BGR in memory. A PNG encode/decode round trip gives the
    # same pixels (PNG is lossless) but costs a full compress + decompress per request.
    bgr = cv2.cvtColor(np.asarray(rgb), cv2.COLOR_RGB2BGR)
    if mode == "Foreground crop":
        return cv2_to_pil(crop_to_foreground(bgr))
    if mode == "Square pad":
        return cv2_to_pil(normalize_for_model(bgr))
    return cv2_to_pil(normalize_for_model(crop_to_foreground(bgr)))
def load_processor() -> TrOCRProcessor:
    """Build the TrOCR processor (image processor + tokenizer) for ``BASE_MODEL``.

    The processor published with ``BASE_MODEL`` is preferred. If loading it fails for any
    reason, one is assembled from a ``ViTImageProcessor`` — taken from ``ADAPTER_MODEL``
    when possible, otherwise from ``FALLBACK_IMAGE_PROCESSOR`` — plus ``BASE_MODEL``'s
    tokenizer. Failures in the final fallback propagate to the caller.
    """
    try:
        processor = TrOCRProcessor.from_pretrained(BASE_MODEL)
    except Exception:
        # Best-effort chain: the base repo may not ship an image-processor config.
        try:
            vit_processor = ViTImageProcessor.from_pretrained(ADAPTER_MODEL)
        except Exception:
            vit_processor = ViTImageProcessor.from_pretrained(FALLBACK_IMAGE_PROCESSOR)
        processor = TrOCRProcessor(
            image_processor=vit_processor,
            tokenizer=AutoTokenizer.from_pretrained(BASE_MODEL),
        )
    return processor
@lru_cache(maxsize=1)
def load_model() -> tuple[VisionEncoderDecoderModel, TrOCRProcessor, str]:
    """Load the TrOCR base model with the LoRA adapter applied, once per process.

    ``recognize`` calls this on every request, so the result is memoized with
    ``lru_cache(maxsize=1)``. The first call loads the processor, base model and adapter;
    later calls return the same ``(model, processor, device)`` tuple. If loading raises,
    nothing is cached and the next call retries.

    Returns:
        ``(model, processor, device)`` where *model* has been moved to *device* and put in
        eval mode, and *device* is the string from ``get_device()``.
    """
    device = get_device()
    processor = load_processor()
    base_model = VisionEncoderDecoderModel.from_pretrained(BASE_MODEL)
    # Align the generation special tokens with the tokenizer in use.
    config = base_model.config
    config.decoder_start_token_id = processor.tokenizer.cls_token_id
    config.pad_token_id = processor.tokenizer.pad_token_id
    config.eos_token_id = processor.tokenizer.sep_token_id
    config.vocab_size = config.decoder.vocab_size
    peft_model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL)
    try:
        # Fold the LoRA weights into the base model so inference runs without the wrapper.
        model = peft_model.merge_and_unload()
    except Exception:
        # Best effort: if merging fails, run inference through the unmerged PEFT model.
        model = peft_model
    model.to(device)
    model.eval()
    return model, processor, device
def recognize(image: Optional[Image.Image], preprocessing: str, max_length: int) -> tuple[str, Image.Image | None, dict]:
    """Run OCR on *image*; this is the Gradio click handler.

    Args:
        image: Uploaded image, or ``None`` when nothing has been uploaded.
        preprocessing: Preprocessing mode name forwarded to ``preprocess``.
        max_length: Maximum generated sequence length, from the UI slider.

    Returns:
        ``(text, processed_image, details)``. When *image* is ``None`` this is
        ``("", None, {"error": ...})`` and the model is not touched. Otherwise *details*
        records the model ids, device, preprocessing mode and inference time in ms
        (model loading is excluded from the timing).
    """
    if image is None:
        return "", None, {"error": "Upload an image first."}
    processed = preprocess(image, preprocessing)
    model, processor, device = load_model()
    started_at = time.perf_counter()
    pixel_values = processor(images=processed.convert("RGB"), return_tensors="pt").pixel_values.to(device)
    with torch.inference_mode():
        # Only the token ids are used, so don't ask generate() to keep per-step beam scores.
        sequences = model.generate(
            pixel_values,
            max_length=int(max_length),  # coerce in case the slider delivers a float
            num_beams=4,
        )
    text = processor.batch_decode(sequences, skip_special_tokens=True)[0].strip()
    elapsed_ms = round((time.perf_counter() - started_at) * 1000, 2)
    details = {
        "base_model": BASE_MODEL,
        "adapter": ADAPTER_MODEL,
        "device": device,
        "preprocessing": preprocessing,
        "inference_ms": elapsed_ms,
    }
    return text, processed, details
# Custom stylesheet handed to ``launch(css=CSS)`` in the main guard: caps the page
# container width and enlarges the recognized-text textarea, which is selected through
# the ``elem_id="result_text"`` set on that Textbox.
CSS = """
.gradio-container { max-width: 1120px !important; }
#result_text textarea { font-size: 1.35rem; line-height: 1.8; }
"""
# Gradio UI. Left column: image upload, preprocessing/max-length controls and the
# Recognize button. Right column: recognized text, the preprocessed image actually fed
# to the model, and the run-details JSON returned by recognize(). Component creation
# order inside the Row/Column contexts determines on-screen placement.
with gr.Blocks(title="DevGen Devanagari OCR") as demo:
    gr.Markdown(
        """
# DevGen Devanagari OCR
Upload a Devanagari word or short line image. The demo runs a TrOCR base model with the DevGen LoRA adapter hosted on Hugging Face.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            # type="pil": recognize() receives a PIL image (or None when empty).
            image_input = gr.Image(type="pil", label="Image")
            # These labels must match the mode strings handled in preprocess().
            preprocessing_input = gr.Radio(
                ["Foreground crop", "Original", "Square pad", "Crop + square pad"],
                value="Foreground crop",
                label="Preprocessing",
            )
            # Forwarded to model.generate() as max_length.
            max_length_input = gr.Slider(16, 128, value=64, step=1, label="Max output length")
            submit = gr.Button("Recognize", variant="primary")
        with gr.Column(scale=1):
            # elem_id "result_text" is targeted by the CSS rule that enlarges this textbox.
            text_output = gr.Textbox(label="Recognized text", lines=4, elem_id="result_text")
            processed_output = gr.Image(type="pil", label="Processed image")
            details_output = gr.JSON(label="Run details")
    # inputs map positionally onto recognize()'s parameters; outputs map onto the three
    # elements of its returned tuple, in order.
    submit.click(
        fn=recognize,
        inputs=[image_input, preprocessing_input, max_length_input],
        outputs=[text_output, processed_output, details_output],
    )
if __name__ == "__main__":
    # Enable the request queue (at most 8 pending events), then serve the app with the
    # custom stylesheet and server-side rendering turned off.
    queued_demo = demo.queue(max_size=8)
    queued_demo.launch(css=CSS, ssr_mode=False)