# Hugging Face Space — Running on Zero (ZeroGPU)
| """Braille Reader — Upload a braille image, get English text.""" | |
import json
import os
import tempfile
import uuid
from datetime import datetime, timezone
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from huggingface_hub import CommitScheduler, hf_hub_download
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from ultralytics import YOLO
# --- Model loading (on CPU at startup, GPU allocated per-request) ---
YOLO_REPO = "prasanthmj/yolov8-braille"
BYT5_REPO = "prasanthmj/braille-byt5-v3"
DATASET_REPO = "prasanthmj/braille-reader-results"

print("Loading models...")

# YOLOv8 braille detector: weights plus the JSON map used to turn a detected
# class name into the character stored under each cell's "unicode" key.
weights_path = hf_hub_download(YOLO_REPO, "yolov8_braille.pt")
braille_map_path = hf_hub_download(YOLO_REPO, "braille_map.json")
yolo_model = YOLO(weights_path)
# Read as UTF-8 explicitly: the map is expected to hold non-ASCII braille
# characters, and the platform's default text encoding is not guaranteed to
# be UTF-8.
with open(braille_map_path, encoding="utf-8") as f:
    dot_to_unicode = json.load(f)

# ByT5 Grade 2 interpreter (load on CPU, moved to GPU per-request)
tokenizer = AutoTokenizer.from_pretrained(BYT5_REPO)
byt5_model = AutoModelForSeq2SeqLM.from_pretrained(BYT5_REPO)
byt5_model.eval()  # inference only

print("Models loaded (CPU). GPU allocated per-request via ZeroGPU.")
# --- Result saving via CommitScheduler ---
RESULTS_DIR = Path("./results")

# Make sure the results folder and its images/ subfolder exist before the
# scheduler starts watching them.
for _directory in (RESULTS_DIR, RESULTS_DIR / "images"):
    _directory.mkdir(exist_ok=True)

# Periodically commits the contents of RESULTS_DIR to the dataset repo.
scheduler = CommitScheduler(
    folder_path=RESULTS_DIR,
    repo_id=DATASET_REPO,
    repo_type="dataset",
    token=os.environ.get("HF_TOKEN"),
    every=5,  # push every 5 minutes
)
def save_result(image: np.ndarray, braille_text: str, english_text: str,
                total_cells: int, num_lines: int, avg_conf: float):
    """Save the image and its transcription record for the dataset.

    The image is written to ``RESULTS_DIR/images/<id>.jpg`` and one JSON line
    describing the result is appended to ``RESULTS_DIR/results.jsonl``. Both
    files are later pushed in a batch by the module-level ``scheduler``.

    Args:
        image: Image array in RGB channel order (converted to BGR for OpenCV).
        braille_text: Detected braille, one line of Unicode cells per text line.
        english_text: Translated English text.
        total_cells: Number of braille cells detected.
        num_lines: Number of detected text lines.
        avg_conf: Mean detection confidence; stored rounded to 4 decimals.

    If the image cannot be written, a warning is printed and no record is
    appended, so ``results.jsonl`` never references a missing image.
    """
    # Read the clock once so the id and the timestamp describe the same
    # instant. tzinfo is dropped to keep the naive-UTC ISO format of records
    # that were written with the (now deprecated) datetime.utcnow().
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    entry_id = now.strftime("%Y%m%d_%H%M%S") + "_" + uuid.uuid4().hex[:6]

    image_filename = f"images/{entry_id}.jpg"
    image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    record = {
        "id": entry_id,
        "image": image_filename,
        "braille_unicode": braille_text,
        "english": english_text,
        "cells": total_cells,
        "lines": num_lines,
        "avg_confidence": round(avg_conf, 4),
        "timestamp": now.isoformat(),
    }

    # Hold the scheduler lock for BOTH writes so a scheduled commit cannot
    # pick up a half-written image or an image without its record.
    with scheduler.lock:
        if not cv2.imwrite(str(RESULTS_DIR / image_filename), image_bgr):
            print(f"Warning: could not write {image_filename}; result not recorded.")
            return
        with open(RESULTS_DIR / "results.jsonl", "a", encoding="utf-8") as f:
            f.write(json.dumps(record) + "\n")
# --- CLAHE Preprocessing ---
def preprocess_clahe(image_path: str) -> str:
    """Apply CLAHE preprocessing for better detection on low-contrast images.

    Args:
        image_path: Path of the image to enhance.

    Returns:
        Path of a new temporary JPEG holding the contrast-enhanced image
        (grayscale content stored as 3-channel BGR). The caller is
        responsible for deleting that file. If the input cannot be read or
        the enhanced copy cannot be written, ``image_path`` is returned
        unchanged.
    """
    img = cv2.imread(image_path)
    if img is None:
        return image_path

    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced_bgr = cv2.cvtColor(clahe.apply(gray), cv2.COLOR_GRAY2BGR)

    # mkstemp + close reserves a unique file name without leaving an open
    # file descriptor behind (NamedTemporaryFile(delete=False) was never
    # closed, leaking one descriptor per call).
    fd, enhanced_path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)

    if not cv2.imwrite(enhanced_path, enhanced_bgr):
        # Do not hand back an empty file; fall back to the original image.
        os.remove(enhanced_path)
        return image_path
    return enhanced_path
# --- Stage 1: YOLOv8 Detection ---
def detect_braille(image_path: str, confidence: float = 0.15) -> list[list[dict]]:
    """Run the YOLO detector and return the detected cells grouped by line.

    Args:
        image_path: Path of the image to run detection on.
        confidence: Confidence threshold passed to the detector as ``conf``.

    Returns:
        A list of lines ordered by ascending Y; each line is a list of cell
        dicts ordered by ascending X with keys ``"dots"`` (detector class
        name), ``"unicode"`` (mapped character, ``"?"`` when the class name
        is not in the map) and ``"confidence"``. Empty list when nothing is
        detected.
    """
    prediction = yolo_model.predict(image_path, conf=confidence, verbose=False)
    boxes = prediction[0].boxes
    count = len(boxes)
    if count == 0:
        return []

    # One row per box; columns 0-3 come from boxes.xywh (x, y, width,
    # height), column 4 is the confidence and column 5 the class index.
    table = np.zeros((count, 6))
    xywh = boxes.xywh
    for col in range(4):
        table[:, col] = xywh[:, col].cpu().numpy()
    table[:, 4] = boxes.conf.cpu().numpy()
    table[:, 5] = boxes.cls.cpu().numpy()

    # Order boxes by Y, then start a new line wherever the Y step between
    # neighbouring boxes exceeds half of the average box height.
    table = table[table[:, 1].argsort()]
    half_height = np.mean(table[:, 3]) / 2
    y_steps = np.diff(table[:, 1])
    split_points = np.where(y_steps > half_height)[0] + 1

    def as_cell(row) -> dict:
        label = yolo_model.names[int(row[5])]
        return {
            "dots": label,
            "unicode": dot_to_unicode.get(label, "?"),
            "confidence": row[4],
        }

    # Within each line, order the cells by X before converting them.
    return [
        [as_cell(row) for row in group[group[:, 0].argsort()]]
        for group in np.split(table, split_points)
    ]
# --- Main pipeline (GPU allocated here) ---
def _remove_quietly(path: str) -> None:
    """Delete a temporary file, ignoring the case where it is already gone."""
    try:
        os.remove(path)
    except OSError:
        pass  # deliberate best-effort cleanup; nothing useful to do on failure


def _interpret_lines(braille_lines: list[str]) -> list[str]:
    """Translate each Unicode braille line to English with the ByT5 model."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    byt5_model.to(device)

    english_lines = []
    for line in braille_lines:
        if not line.strip():
            # Keep blank lines so the output stays aligned with the input.
            english_lines.append("")
            continue
        input_text = f"translate Braille to English: {line}"
        inputs = tokenizer(input_text, return_tensors="pt", max_length=1024, truncation=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = byt5_model.generate(**inputs, max_length=512)
        english_lines.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
    return english_lines


# ZeroGPU attaches a GPU only while a @spaces.GPU-decorated function runs;
# the module imports `spaces` for exactly this entry point.
@spaces.GPU
def transcribe(image) -> str:
    """Full pipeline: image -> detection -> interpretation -> English text.

    Args:
        image: Uploaded image in RGB channel order (a NumPy array from the
            Gradio input; other array-likes are converted with
            ``np.asarray``), or ``None`` when nothing was uploaded.

    Returns:
        The English translation followed by a details section (cell count,
        line count, average confidence, Unicode braille), or a short message
        when no image was given or no braille cells were detected.
    """
    if image is None:
        return "Please upload an image."

    # cv2 needs an ndarray; normalising once also covers save_result below.
    image = np.asarray(image)

    # The detector reads from disk, so stage the upload in a temp file.
    # mkstemp + close avoids leaving an open descriptor behind.
    fd, image_path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    processed_path = image_path
    try:
        cv2.imwrite(image_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

        # CLAHE preprocessing (returns image_path itself when it cannot run)
        processed_path = preprocess_clahe(image_path)

        # Stage 1: Detect braille cells
        lines = detect_braille(processed_path)
    finally:
        # Temp files are only needed for detection; remove them on every
        # path (success, early return, or exception). The set collapses the
        # two paths when preprocessing fell back to the original file.
        for path in {image_path, processed_path}:
            _remove_quietly(path)

    if not lines:
        return "No braille cells detected. Try a clearer image."

    # Extract Unicode braille per line
    braille_lines = ["".join(cell["unicode"] for cell in line) for line in lines]

    # Stats
    total_cells = sum(len(line) for line in lines)
    avg_conf = float(np.mean([cell["confidence"] for line in lines for cell in line]))

    # Stage 2: Interpret each line with ByT5 on GPU
    english_lines = _interpret_lines(braille_lines)

    # Format output
    braille_text = "\n".join(braille_lines)
    english_text = "\n".join(english_lines)

    # Save to dataset
    save_result(image, braille_text, english_text, total_cells, len(lines), avg_conf)

    return (
        f"{english_text}\n\n"
        "--- Details ---\n"
        f"Cells detected: {total_cells}\n"
        f"Lines: {len(lines)}\n"
        f"Avg confidence: {avg_conf:.1%}\n"
        f"\nBraille Unicode:\n{braille_text}"
    )
# --- Gradio UI ---
# Input and output components are built first, then wired to the pipeline.
_image_input = gr.Image(type="numpy", label="Upload Braille Image")
_text_output = gr.Textbox(label="English Translation", lines=15)

demo = gr.Interface(
    fn=transcribe,
    inputs=_image_input,
    outputs=_text_output,
    title="Braille Reader",
    description=(
        "Upload a scanned braille document to get its English translation. "
        "Supports Grade 2 (contracted) braille."
    ),
    examples=[],
    flagging_mode="never",
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()