prasanthmj commited on
Commit
99857c5
·
verified ·
1 Parent(s): 44614a0

Initial app: full braille OCR pipeline (YOLOv8 + ByT5)

Browse files
Files changed (3) hide show
  1. README.md +16 -5
  2. app.py +187 -0
  3. requirements.txt +7 -0
README.md CHANGED
@@ -1,12 +1,23 @@
1
  ---
2
  title: Braille Reader
3
- emoji: 🐠
4
- colorFrom: purple
5
- colorTo: gray
6
  sdk: gradio
7
- sdk_version: 6.9.0
8
  app_file: app.py
9
  pinned: false
 
 
 
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
1
  ---
2
  title: Braille Reader
3
+ emoji: 👁️
4
+ colorFrom: blue
5
+ colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: "5.23.0"
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
+ models:
12
+ - prasanthmj/braille-byt5-v3
13
+ - prasanthmj/yolov8-braille
14
  ---
15
 
16
+ # Braille Reader
17
+
18
+ Upload a scanned braille document and get its English translation.
19
+
20
+ - **Stage 1:** YOLOv8 detects braille cells in the image
21
+ - **Stage 2:** ByT5 translates detected braille (Grade 2 contracted) to English
22
+
23
+ Supports Grade 2 (contracted) braille — the form used in 90-95% of real braille documents.
app.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Braille Reader — Upload a braille image, get English text."""
2
+
3
+ import json
4
+ import tempfile
5
+ from pathlib import Path
6
+
7
+ import cv2
8
+ import gradio as gr
9
+ import numpy as np
10
+ import torch
11
+ from huggingface_hub import hf_hub_download
12
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
13
+ from ultralytics import YOLO
14
+
15
# --- Model loading (cached at startup) ---

# Hugging Face Hub repo IDs for the two pipeline stages.
YOLO_REPO = "prasanthmj/yolov8-braille"   # Stage 1: YOLOv8 braille-cell detector
BYT5_REPO = "prasanthmj/braille-byt5-v3"  # Stage 2: ByT5 Grade-2 braille -> English
19
+
20
def load_models():
    """Download weights from the Hub and initialise both pipeline stages.

    Returns:
        Tuple of (yolo_model, dot_to_unicode, tokenizer, byt5_model, device)
        where `dot_to_unicode` maps YOLO class names (dot patterns) to
        braille Unicode characters and `device` is "cuda" or "cpu".
    """
    # Stage 1: YOLOv8 braille-cell detector plus its dots -> Unicode table.
    weights = hf_hub_download(YOLO_REPO, "yolov8_braille.pt")
    mapping = hf_hub_download(YOLO_REPO, "braille_map.json")
    detector = YOLO(weights)
    dot_map = json.loads(Path(mapping).read_text())

    # Stage 2: ByT5 Grade-2 braille interpreter, moved to GPU when available.
    run_device = "cuda" if torch.cuda.is_available() else "cpu"
    tok = AutoTokenizer.from_pretrained(BYT5_REPO)
    translator = AutoModelForSeq2SeqLM.from_pretrained(BYT5_REPO).to(run_device)
    translator.eval()

    return detector, dot_map, tok, translator, run_device
37
+
38
+
39
# Load models once at import time so every Gradio request reuses warm models.
print("Loading models...")
yolo_model, dot_to_unicode, tokenizer, byt5_model, device = load_models()
print(f"Models loaded. Device: {device}")
42
+
43
# --- CLAHE Preprocessing ---

def preprocess_clahe(image_path: str) -> str:
    """Apply CLAHE contrast enhancement for better detection on low-contrast scans.

    Args:
        image_path: Path to the input image on disk.

    Returns:
        Path to an enhanced temporary JPEG, or the original `image_path`
        unchanged when OpenCV cannot read the file.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread returns None for unreadable/corrupt files; fall back to
        # the original path so the pipeline can still attempt detection.
        return image_path
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(gray)
    enhanced_bgr = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR)

    # Close the temp-file handle immediately: the original left it open,
    # leaking a file descriptor per request and blocking cv2.imwrite on
    # Windows (which cannot reopen a file another handle holds open).
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        out_path = tmp.name
    cv2.imwrite(out_path, enhanced_bgr)
    return out_path
58
+
59
# --- Stage 1: YOLOv8 Detection ---

def detect_braille(image_path: str, confidence: float = 0.15) -> list[list[dict]]:
    """Detect braille cells in an image and group them into reading-order lines.

    Args:
        image_path: Path to the (preprocessed) image file.
        confidence: Minimum YOLO confidence for a detection to be kept.

    Returns:
        A list of lines, top to bottom; each line is a left-to-right list of
        cell dicts with keys "dots" (class name), "unicode" (braille char,
        "?" when the pattern is missing from the map), and "confidence"
        (plain float). Empty list when nothing is detected.
    """
    results = yolo_model.predict(image_path, conf=confidence, verbose=False)
    boxes = results[0].boxes
    if len(boxes) == 0:
        return []

    # Assemble [cx, cy, w, h, conf, cls] per detection. The original copied
    # individual columns with four separate .cpu().numpy() round trips;
    # transfer each tensor from the (possibly GPU) device once instead.
    data = np.column_stack((
        boxes.xywh.cpu().numpy(),
        boxes.conf.cpu().numpy(),
        boxes.cls.cpu().numpy(),
    ))

    # Sort by vertical centre so line splitting sees top-to-bottom order.
    data = data[data[:, 1].argsort()]

    # A vertical gap larger than half the average cell height starts a new line.
    avg_height = np.mean(data[:, 3])
    y_threshold = avg_height / 2
    y_diffs = np.diff(data[:, 1])
    break_indices = np.where(y_diffs > y_threshold)[0]
    raw_lines = np.split(data, break_indices + 1)

    lines = []
    for raw_line in raw_lines:
        # Left-to-right order within a line.
        raw_line = raw_line[raw_line[:, 0].argsort()]
        cells = []
        for row in raw_line:
            dots = yolo_model.names[int(row[5])]
            cells.append({
                "dots": dots,
                "unicode": dot_to_unicode.get(dots, "?"),
                # Cast so callers get a plain float, not a numpy scalar.
                "confidence": float(row[4]),
            })
        lines.append(cells)

    return lines
104
+
105
# --- Stage 2: ByT5 Interpretation ---

def interpret_braille(braille_lines: list[str]) -> list[str]:
    """Translate braille Unicode lines to English with the ByT5 model.

    Blank/whitespace-only lines pass through as empty strings, so the
    output always has one entry per input line.
    """
    translations = []
    for braille in braille_lines:
        if not braille.strip():
            translations.append("")
            continue

        # Task-prefixed prompt, tokenised and moved to the model's device.
        prompt = f"translate Braille to English: {braille}"
        encoded = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

        # Inference only — no gradients needed.
        with torch.no_grad():
            generated = byt5_model.generate(**encoded, max_length=512)

        translations.append(tokenizer.decode(generated[0], skip_special_tokens=True))

    return translations
126
+
127
# --- Main pipeline ---

def transcribe(image) -> str:
    """Full pipeline: image -> detection -> interpretation -> English text.

    Args:
        image: RGB numpy array from the Gradio image widget (None when
            nothing was uploaded).

    Returns:
        Human-readable report: English translation, detection details,
        and the raw braille Unicode; or a short message on empty/failed input.
    """
    if image is None:
        return "Please upload an image."

    # Persist the upload so OpenCV/YOLO can read it from disk. Close the
    # handle before cv2.imwrite opens its own — the original left it open,
    # leaking a descriptor per request (and breaking the write on Windows).
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        image_path = tmp.name
    if isinstance(image, np.ndarray):
        # Gradio delivers RGB; OpenCV expects BGR.
        cv2.imwrite(image_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
    else:
        cv2.imwrite(image_path, image)

    processed_path = image_path
    try:
        # CLAHE preprocessing improves detection on low-contrast scans.
        processed_path = preprocess_clahe(image_path)

        # Stage 1: Detect braille cells grouped into lines.
        lines = detect_braille(processed_path)
        if not lines:
            return "No braille cells detected. Try a clearer image."

        # One Unicode braille string per detected line.
        braille_lines = ["".join(cell["unicode"] for cell in line) for line in lines]

        # Detection stats for the report.
        total_cells = sum(len(line) for line in lines)
        avg_conf = np.mean([cell["confidence"] for line in lines for cell in line])

        # Stage 2: Interpret with ByT5.
        english_lines = interpret_braille(braille_lines)

        braille_text = "\n".join(braille_lines)
        english_text = "\n".join(english_lines)

        return (
            f"{english_text}\n\n"
            "--- Details ---\n"
            f"Cells detected: {total_cells}\n"
            f"Lines: {len(lines)}\n"
            f"Avg confidence: {avg_conf:.1%}\n"
            f"\nBraille Unicode:\n{braille_text}"
        )
    finally:
        # The original never removed its temp files, leaking disk space on
        # every request. Best-effort cleanup of both (set dedups when
        # preprocess_clahe fell back to the original path).
        for stale in {image_path, processed_path}:
            try:
                Path(stale).unlink(missing_ok=True)
            except OSError:
                pass
173
+
174
# --- Gradio UI ---

# Single image in, plain text out; flagging disabled.
_image_input = gr.Image(type="numpy", label="Upload Braille Image")
_text_output = gr.Textbox(label="English Translation", lines=15)

demo = gr.Interface(
    fn=transcribe,
    inputs=_image_input,
    outputs=_text_output,
    title="Braille Reader",
    description="Upload a scanned braille document to get its English translation. Supports Grade 2 (contracted) braille.",
    examples=[],
    flagging_mode="never",
)

if __name__ == "__main__":
    # Start the local Gradio server when run as a script.
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ ultralytics>=8.0.0
2
+ opencv-python-headless>=4.0.0
3
+ torch>=2.0.0
4
+ transformers>=4.30.0
5
+ sentencepiece>=0.1.99
6
+ huggingface_hub>=0.20.0
7
+ numpy>=1.23.0