"""Braille Reader — Upload a braille image, get English text.""" import json import os import tempfile import uuid from datetime import datetime from pathlib import Path import cv2 import gradio as gr import numpy as np import spaces import torch from huggingface_hub import CommitScheduler, hf_hub_download from transformers import AutoModelForSeq2SeqLM, AutoTokenizer from ultralytics import YOLO # --- Model loading (on CPU at startup, GPU allocated per-request) --- YOLO_REPO = "prasanthmj/yolov8-braille" BYT5_REPO = "prasanthmj/braille-byt5-v3" DATASET_REPO = "prasanthmj/braille-reader-results" print("Loading models...") # YOLOv8 braille detector weights_path = hf_hub_download(YOLO_REPO, "yolov8_braille.pt") braille_map_path = hf_hub_download(YOLO_REPO, "braille_map.json") yolo_model = YOLO(weights_path) with open(braille_map_path) as f: dot_to_unicode = json.load(f) # ByT5 Grade 2 interpreter (load on CPU, moved to GPU per-request) tokenizer = AutoTokenizer.from_pretrained(BYT5_REPO) byt5_model = AutoModelForSeq2SeqLM.from_pretrained(BYT5_REPO) byt5_model.eval() print("Models loaded (CPU). GPU allocated per-request via ZeroGPU.") # --- Result saving via CommitScheduler --- RESULTS_DIR = Path("./results") RESULTS_DIR.mkdir(exist_ok=True) (RESULTS_DIR / "images").mkdir(exist_ok=True) scheduler = CommitScheduler( repo_id=DATASET_REPO, repo_type="dataset", folder_path=RESULTS_DIR, every=5, # push every 5 minutes token=os.environ.get("HF_TOKEN"), ) def save_result(image: np.ndarray, braille_text: str, english_text: str, total_cells: int, num_lines: int, avg_conf: float): """Save image and result to the dataset (batched by CommitScheduler).""" entry_id = datetime.utcnow().strftime("%Y%m%d_%H%M%S") + "_" + uuid.uuid4().hex[:6] # Save image image_filename = f"images/{entry_id}.jpg" image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) cv2.imwrite(str(RESULTS_DIR / image_filename), image_bgr) # Append to JSONL record = { "id": entry_id, "image": image_filename, "braille_unicode": braille_text, "english": english_text, "cells": total_cells, "lines": num_lines, "avg_confidence": round(avg_conf, 4), "timestamp": datetime.utcnow().isoformat(), } with scheduler.lock: with open(RESULTS_DIR / "results.jsonl", "a") as f: f.write(json.dumps(record) + "\n") # --- CLAHE Preprocessing --- def preprocess_clahe(image_path: str) -> str: """Apply CLAHE preprocessing for better detection on low-contrast images.""" img = cv2.imread(image_path) if img is None: return image_path gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) enhanced = clahe.apply(gray) enhanced_bgr = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR) tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) cv2.imwrite(tmp.name, enhanced_bgr) return tmp.name # --- Stage 1: YOLOv8 Detection --- def detect_braille(image_path: str, confidence: float = 0.15) -> list[list[dict]]: """Detect braille cells and group into lines.""" results = yolo_model.predict(image_path, conf=confidence, verbose=False) boxes = results[0].boxes if len(boxes) == 0: return [] n = len(boxes) data = np.zeros((n, 6)) data[:, 0] = boxes.xywh[:, 0].cpu().numpy() data[:, 1] = boxes.xywh[:, 1].cpu().numpy() data[:, 2] = boxes.xywh[:, 2].cpu().numpy() data[:, 3] = boxes.xywh[:, 3].cpu().numpy() data[:, 4] = boxes.conf.cpu().numpy() data[:, 5] = boxes.cls.cpu().numpy() # Sort by Y data = data[data[:, 1].argsort()] # Split into lines by Y gaps avg_height = np.mean(data[:, 3]) y_threshold = avg_height / 2 y_diffs = np.diff(data[:, 1]) break_indices = np.where(y_diffs > y_threshold)[0] raw_lines = np.split(data, break_indices + 1) lines = [] for raw_line in raw_lines: raw_line = raw_line[raw_line[:, 0].argsort()] cells = [] for row in raw_line: class_idx = int(row[5]) dots = yolo_model.names[class_idx] unicode_char = dot_to_unicode.get(dots, "?") cells.append({ "dots": dots, "unicode": unicode_char, "confidence": row[4], }) lines.append(cells) return lines # --- Main pipeline (GPU allocated here) --- @spaces.GPU def transcribe(image) -> str: """Full pipeline: image -> detection -> interpretation -> English text.""" if image is None: return "Please upload an image." # Save uploaded image to temp file tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) if isinstance(image, np.ndarray): cv2.imwrite(tmp.name, cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) else: cv2.imwrite(tmp.name, image) image_path = tmp.name # CLAHE preprocessing processed_path = preprocess_clahe(image_path) # Stage 1: Detect braille cells lines = detect_braille(processed_path) if not lines: return "No braille cells detected. Try a clearer image." # Extract Unicode braille per line braille_lines = ["".join(cell["unicode"] for cell in line) for line in lines] # Stats total_cells = sum(len(line) for line in lines) avg_conf = float(np.mean([cell["confidence"] for line in lines for cell in line])) # Stage 2: Interpret each line with ByT5 on GPU device = "cuda" if torch.cuda.is_available() else "cpu" byt5_model.to(device) english_lines = [] for line in braille_lines: if not line.strip(): english_lines.append("") continue input_text = f"translate Braille to English: {line}" inputs = tokenizer(input_text, return_tensors="pt", max_length=1024, truncation=True) inputs = {k: v.to(device) for k, v in inputs.items()} with torch.no_grad(): outputs = byt5_model.generate(**inputs, max_length=512) decoded = tokenizer.decode(outputs[0], skip_special_tokens=True) english_lines.append(decoded) # Format output braille_text = "\n".join(braille_lines) english_text = "\n".join(english_lines) # Save to dataset save_result(image, braille_text, english_text, total_cells, len(lines), avg_conf) output = f"{english_text}\n\n" output += f"--- Details ---\n" output += f"Cells detected: {total_cells}\n" output += f"Lines: {len(lines)}\n" output += f"Avg confidence: {avg_conf:.1%}\n" output += f"\nBraille Unicode:\n{braille_text}" return output # --- Gradio UI --- demo = gr.Interface( fn=transcribe, inputs=gr.Image(type="numpy", label="Upload Braille Image"), outputs=gr.Textbox(label="English Translation", lines=15), title="Braille Reader", description="Upload a scanned braille document to get its English translation. Supports Grade 2 (contracted) braille.", examples=[], flagging_mode="never", ) if __name__ == "__main__": demo.launch()