# waglesameer5's picture
# Clean Gradio 6 launch options
# 2cd3b3d verified
from __future__ import annotations
import io
import os
import time
from functools import lru_cache
from typing import Optional
import cv2
import gradio as gr
import numpy as np
import torch
from peft import PeftModel
from PIL import Image
from transformers import AutoTokenizer, TrOCRProcessor, ViTImageProcessor, VisionEncoderDecoderModel
# Hugging Face Hub repository ids. Each can be overridden via an environment variable.
# Base TrOCR checkpoint (model weights and tokenizer).
BASE_MODEL = os.environ.get("TROCR_BASE_MODEL", "paudelanil/trocr-devanagari-2")
# LoRA adapter applied on top of BASE_MODEL.
ADAPTER_MODEL = os.environ.get("TROCR_ADAPTER_MODEL", "waglesameer5/devgen-trocr-devanagari-lora")
# Image processor used only when neither BASE_MODEL's processor nor
# ADAPTER_MODEL's image processor can be loaded (see load_processor).
FALLBACK_IMAGE_PROCESSOR = os.environ.get("TROCR_FALLBACK_IMAGE_PROCESSOR", "google/vit-base-patch16-224-in21k")
def get_device() -> str:
    """Return the torch device string to run inference on.

    ``"cuda"`` when torch reports a usable CUDA device, otherwise ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
def image_to_bytes(image: Image.Image) -> bytes:
    """Serialize *image* to PNG-encoded bytes after converting it to RGB."""
    rgb_image = image.convert("RGB")
    with io.BytesIO() as sink:
        rgb_image.save(sink, format="PNG")
        encoded = sink.getvalue()
    return encoded
def bytes_to_cv2(image_bytes: bytes) -> np.ndarray:
    """Decode encoded image bytes into an OpenCV colour array.

    The bytes are decoded with ``cv2.IMREAD_COLOR``, so the result uses
    OpenCV's BGR channel order.

    Raises:
        ValueError: if *image_bytes* is empty or cannot be decoded.
    """
    if not image_bytes:
        raise ValueError("cannot decode empty image bytes")
    buffer = np.frombuffer(image_bytes, np.uint8)
    decoded = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    # imdecode reports failure by returning None instead of raising; surface
    # it here rather than letting a later cv2 call fail cryptically.
    if decoded is None:
        raise ValueError("could not decode image bytes")
    return decoded
def cv2_to_pil(img: np.ndarray) -> Image.Image:
    """Convert an OpenCV image array to a PIL image.

    * 2-D arrays are treated as single-channel grayscale and passed through as-is.
    * 4-channel arrays are treated as BGRA and converted to RGBA.
    * Anything else is treated as BGR and converted to RGB (the original behaviour).
    """
    if img.ndim == 2:
        # The sibling helpers accept grayscale input; BGR2RGB would reject it.
        return Image.fromarray(img)
    if img.shape[2] == 4:
        return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA))
    return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
def crop_to_foreground(
    img: np.ndarray,
    padding_ratio: float = 0.18,
    min_area_ratio: float = 0.0001,
    min_padding: int = 8,
) -> np.ndarray:
    """Crop *img* to the union bounding box of its dark foreground regions, plus padding.

    The Otsu mask is inverted, so this expects dark ink on a lighter background.

    Args:
        img: 3-channel (BGR) or 2-D grayscale uint8 image.
        padding_ratio: Padding added on each side, as a fraction of the box size.
        min_area_ratio: Contours smaller than this fraction of the image area
            (with a floor of 12 px) are ignored as noise.
        min_padding: Minimum padding in pixels on each side.

    Returns:
        The cropped view of *img*. *img* itself is returned when it is empty or
        when no sufficiently large foreground contour is found.
    """
    if img.size == 0:
        return img
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if img.ndim == 3 else img
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    # Otsu chooses the threshold automatically; INV makes dark ink the white mask.
    _, mask = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)
    kernel = np.ones((3, 3), np.uint8)
    # Opening removes small specks. The dilation then grows the mask slightly,
    # so neighbouring stroke fragments are more likely to share a contour.
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
    mask = cv2.dilate(mask, kernel, iterations=1)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    h, w = gray.shape[:2]
    min_area = max(12, int(h * w * min_area_ratio))
    boxes = [cv2.boundingRect(c) for c in contours if cv2.contourArea(c) >= min_area]
    if not boxes:
        return img

    x1 = min(x for x, _, _, _ in boxes)
    y1 = min(y for _, y, _, _ in boxes)
    x2 = max(x + bw for x, _, bw, _ in boxes)
    y2 = max(y + bh for _, y, _, bh in boxes)
    pad_x = max(min_padding, int((x2 - x1) * padding_ratio))
    pad_y = max(min_padding, int((y2 - y1) * padding_ratio))
    # Clamp the padded box to the image bounds.
    return img[max(0, y1 - pad_y):min(h, y2 + pad_y), max(0, x1 - pad_x):min(w, x2 + pad_x)]
def normalize_for_model(img: np.ndarray, target_height: int = 384, target_width: int = 384) -> np.ndarray:
    """Letterbox *img* onto a white ``target_height x target_width`` uint8 canvas.

    The image is scaled to fit while preserving its aspect ratio, then centred.
    The canvas has the same trailing channel layout as *img* (2-D, (h, w, C)).

    Raises:
        ValueError: if *img* has zero height or width.
    """
    h, w = img.shape[:2]
    if h == 0 or w == 0:
        raise ValueError("cannot normalize an empty image")
    scale = min(target_height / h, target_width / w)
    new_h = max(1, int(h * scale))
    new_w = max(1, int(w * scale))
    # INTER_AREA is the right choice for shrinking, but for enlarging OpenCV
    # recommends INTER_CUBIC/INTER_LINEAR; INTER_AREA upscales blockily.
    interpolation = cv2.INTER_AREA if scale < 1 else cv2.INTER_CUBIC
    resized = cv2.resize(img, (new_w, new_h), interpolation=interpolation)
    # cv2.resize returns a 2-D array for (h, w, 1) input; restore the input layout.
    channel_shape = img.shape[2:]
    resized = resized.reshape((new_h, new_w) + channel_shape)
    canvas = np.full((target_height, target_width) + channel_shape, 255, dtype=np.uint8)
    y_offset = (target_height - new_h) // 2
    x_offset = (target_width - new_w) // 2
    canvas[y_offset:y_offset + new_h, x_offset:x_offset + new_w] = resized
    return canvas
def preprocess(image: Image.Image, mode: str) -> Image.Image:
    """Apply the preprocessing *mode* to *image* and return a PIL image.

    Supported modes are "Original", "Foreground crop", "Square pad", and
    "Crop + square pad". Any other value returns the RGB image unchanged.
    """
    image = image.convert("RGB")
    pipelines = {
        "Foreground crop": crop_to_foreground,
        "Square pad": normalize_for_model,
        "Crop + square pad": lambda arr: normalize_for_model(crop_to_foreground(arr)),
    }
    transform = pipelines.get(mode)
    if transform is None:
        # "Original" and unknown modes skip the OpenCV conversion entirely.
        return image
    # PIL holds RGB while the OpenCV helpers expect BGR. Convert in memory
    # instead of round-tripping through a PNG encode/decode.
    bgr = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
    return cv2_to_pil(transform(bgr))
def load_processor() -> TrOCRProcessor:
    """Build the TrOCR processor (image processor + tokenizer) for BASE_MODEL.

    First tries the processor saved with BASE_MODEL. If that fails, it builds
    one from a ViT image processor plus the BASE_MODEL tokenizer. The image
    processor comes from ADAPTER_MODEL, or from FALLBACK_IMAGE_PROCESSOR if
    loading from the adapter fails.
    Each fallback is logged, so the reason for the failure is not lost.
    """
    import logging

    logger = logging.getLogger(__name__)
    try:
        return TrOCRProcessor.from_pretrained(BASE_MODEL)
    except Exception:
        logger.warning(
            "Could not load TrOCRProcessor from %s; assembling one manually",
            BASE_MODEL,
            exc_info=True,
        )
    try:
        image_processor = ViTImageProcessor.from_pretrained(ADAPTER_MODEL)
    except Exception:
        logger.warning(
            "Could not load image processor from %s; falling back to %s",
            ADAPTER_MODEL,
            FALLBACK_IMAGE_PROCESSOR,
            exc_info=True,
        )
        image_processor = ViTImageProcessor.from_pretrained(FALLBACK_IMAGE_PROCESSOR)
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    return TrOCRProcessor(image_processor=image_processor, tokenizer=tokenizer)
@lru_cache(maxsize=1)
def load_model() -> tuple[VisionEncoderDecoderModel, TrOCRProcessor, str]:
    """Load BASE_MODEL, apply the ADAPTER_MODEL LoRA weights, and move the model to the device.

    Cached with ``lru_cache(maxsize=1)``, so loading happens once per process
    and later calls reuse the same objects.

    Returns:
        ``(model, processor, device)``; *model* is in eval mode on *device*.
    """
    import logging

    logger = logging.getLogger(__name__)
    device = get_device()
    processor = load_processor()
    base_model = VisionEncoderDecoderModel.from_pretrained(BASE_MODEL)

    tokenizer = processor.tokenizer
    special_ids = {
        "decoder_start_token_id": tokenizer.cls_token_id,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.sep_token_id,
    }
    # Recent transformers versions read generation settings from
    # generation_config, so mirror the ids there as well as in config.
    generation_config = getattr(base_model, "generation_config", None)
    for name, token_id in special_ids.items():
        if token_id is None:
            # The tokenizer lacks this special token; keep the checkpoint's
            # configured id instead of overwriting it with None.
            continue
        setattr(base_model.config, name, token_id)
        if generation_config is not None:
            setattr(generation_config, name, token_id)
    base_model.config.vocab_size = base_model.config.decoder.vocab_size

    peft_model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL)
    try:
        # Merge the LoRA weights into the base layers; fall back to the wrapped model on failure.
        model = peft_model.merge_and_unload()
    except Exception:
        logger.warning("merge_and_unload failed; using the unmerged PEFT model", exc_info=True)
        model = peft_model
    model.to(device)
    model.eval()
    return model, processor, device
def recognize(image: Optional[Image.Image], preprocessing: str, max_length: int) -> tuple[str, Image.Image | None, dict]:
    """Run OCR on *image* and return ``(text, processed_image, details)``.

    Args:
        image: The uploaded image, or None when nothing has been uploaded.
        preprocessing: Mode name forwarded to ``preprocess``.
        max_length: Maximum generated sequence length. It is coerced to an
            int >= 1 because UI sliders may deliver it as a float.

    Returns:
        The decoded text, the preprocessed image fed to the model, and a dict
        of run details. The timing in the dict excludes model loading.
        When *image* is None, returns ``("", None, {"error": ...})``.
    """
    if image is None:
        return "", None, {"error": "Upload an image first."}
    max_length = max(1, int(max_length))
    processed = preprocess(image, preprocessing)
    model, processor, device = load_model()
    started_at = time.perf_counter()
    pixel_values = processor(images=processed.convert("RGB"), return_tensors="pt").pixel_values.to(device)
    with torch.inference_mode():
        # Only the token ids are used, so per-step scores are not requested.
        sequences = model.generate(
            pixel_values,
            max_length=max_length,
            num_beams=4,
        )
    text = processor.batch_decode(sequences, skip_special_tokens=True)[0].strip()
    elapsed_ms = round((time.perf_counter() - started_at) * 1000, 2)
    details = {
        "base_model": BASE_MODEL,
        "adapter": ADAPTER_MODEL,
        "device": device,
        "preprocessing": preprocessing,
        "inference_ms": elapsed_ms,
    }
    return text, processed, details
# Custom styles passed to launch(css=...): caps the app container width and
# enlarges the text in the element with elem_id="result_text".
CSS = """
.gradio-container { max-width: 1120px !important; }
#result_text textarea { font-size: 1.35rem; line-height: 1.8; }
"""
# Gradio UI: a row with two columns (controls, then results); the button
# wires the controls to `recognize`.
with gr.Blocks(title="DevGen Devanagari OCR") as demo:
    gr.Markdown(
        """
# DevGen Devanagari OCR
Upload a Devanagari word or short line image. The demo runs a TrOCR base model with the DevGen LoRA adapter hosted on Hugging Face.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Image")
            # These choices must match the mode names handled by `preprocess`.
            preprocessing_input = gr.Radio(
                ["Foreground crop", "Original", "Square pad", "Crop + square pad"],
                value="Foreground crop",
                label="Preprocessing",
            )
            # Forwarded to `recognize` as max_length for generation.
            max_length_input = gr.Slider(16, 128, value=64, step=1, label="Max output length")
            submit = gr.Button("Recognize", variant="primary")
        with gr.Column(scale=1):
            # elem_id matches the `#result_text` rule in CSS (larger font / line height).
            text_output = gr.Textbox(label="Recognized text", lines=4, elem_id="result_text")
            processed_output = gr.Image(type="pil", label="Processed image")
            details_output = gr.JSON(label="Run details")
    # The input order matches recognize's parameters, and the output order
    # matches its returned (text, processed_image, details) tuple.
    submit.click(
        fn=recognize,
        inputs=[image_input, preprocessing_input, max_length_input],
        outputs=[text_output, processed_output, details_output],
    )
if __name__ == "__main__":
    # Limit the request queue to 8 pending events, then serve the app with
    # the custom CSS and with SSR mode disabled.
    demo.queue(max_size=8)
    demo.launch(css=CSS, ssr_mode=False)