prasanthmj committed on
Commit
694bc64
·
verified ·
1 Parent(s): 6e08ed4

Save each result (image + braille + english) to braille-reader-results dataset

Browse files
Files changed (1) hide show
  1. app.py +51 -2
app.py CHANGED
@@ -1,14 +1,18 @@
1
  """Braille Reader — Upload a braille image, get English text."""
2
 
3
  import json
 
4
  import tempfile
 
 
 
5
 
6
  import cv2
7
  import gradio as gr
8
  import numpy as np
9
  import spaces
10
  import torch
11
- from huggingface_hub import hf_hub_download
12
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
13
  from ultralytics import YOLO
14
 
@@ -16,6 +20,7 @@ from ultralytics import YOLO
16
 
17
  YOLO_REPO = "prasanthmj/yolov8-braille"
18
  BYT5_REPO = "prasanthmj/braille-byt5-v3"
 
19
 
20
  print("Loading models...")
21
 
@@ -33,6 +38,47 @@ byt5_model.eval()
33
 
34
  print("Models loaded (CPU). GPU allocated per-request via ZeroGPU.")
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  # --- CLAHE Preprocessing ---
37
 
38
  def preprocess_clahe(image_path: str) -> str:
@@ -125,7 +171,7 @@ def transcribe(image) -> str:
125
 
126
  # Stats
127
  total_cells = sum(len(line) for line in lines)
128
- avg_conf = np.mean([cell["confidence"] for line in lines for cell in line])
129
 
130
  # Stage 2: Interpret each line with ByT5 on GPU
131
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -151,6 +197,9 @@ def transcribe(image) -> str:
151
  braille_text = "\n".join(braille_lines)
152
  english_text = "\n".join(english_lines)
153
 
 
 
 
154
  output = f"{english_text}\n\n"
155
  output += f"--- Details ---\n"
156
  output += f"Cells detected: {total_cells}\n"
 
1
  """Braille Reader — Upload a braille image, get English text."""
2
 
3
  import json
4
+ import os
5
  import tempfile
6
+ import uuid
7
+ from datetime import datetime
8
+ from pathlib import Path
9
 
10
  import cv2
11
  import gradio as gr
12
  import numpy as np
13
  import spaces
14
  import torch
15
+ from huggingface_hub import CommitScheduler, hf_hub_download
16
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
17
  from ultralytics import YOLO
18
 
 
20
 
21
  YOLO_REPO = "prasanthmj/yolov8-braille"
22
  BYT5_REPO = "prasanthmj/braille-byt5-v3"
23
+ DATASET_REPO = "prasanthmj/braille-reader-results"
24
 
25
  print("Loading models...")
26
 
 
38
 
39
  print("Models loaded (CPU). GPU allocated per-request via ZeroGPU.")
40
 
41
# --- Result saving via CommitScheduler ---

# Local folder mirrored to the results dataset; images go in a subfolder.
RESULTS_DIR = Path("./results")
for _results_subdir in (RESULTS_DIR, RESULTS_DIR / "images"):
    _results_subdir.mkdir(exist_ok=True)

# Background scheduler that commits RESULTS_DIR to the dataset repo.
scheduler = CommitScheduler(
    repo_id=DATASET_REPO,
    repo_type="dataset",
    folder_path=RESULTS_DIR,
    every=5,  # minutes between pushes
    token=os.getenv("HF_TOKEN"),
)
54
+
55
+
56
def save_result(image: np.ndarray, braille_text: str, english_text: str,
                total_cells: int, num_lines: int, avg_conf: float):
    """Save an input image and its transcription into the results folder.

    The image is written as ``images/<id>.jpg`` under ``RESULTS_DIR`` and one
    JSON line describing the result is appended to
    ``RESULTS_DIR / "results.jsonl"``. Both writes happen while holding
    ``scheduler.lock`` so the image and its record are written as one unit
    with respect to anything else that takes the same lock.

    Args:
        image: Image array. It is passed through ``cv2.COLOR_RGB2BGR`` before
            writing, i.e. it is treated as RGB channel order.
        braille_text: Braille text stored under ``"braille_unicode"``.
        english_text: English text stored under ``"english"``.
        total_cells: Cell count stored under ``"cells"``.
        num_lines: Line count stored under ``"lines"``.
        avg_conf: Average confidence; stored rounded to 4 decimal places.

    Returns:
        None.
    """
    # Local import: the file's import block only brings in `datetime`.
    from datetime import timezone

    # Read the clock once so the id prefix and the "timestamp" field always
    # agree. The value is kept naive-UTC so the stored format is the same as
    # what `datetime.utcnow()` (deprecated since Python 3.12) produced.
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    entry_id = now.strftime("%Y%m%d_%H%M%S") + "_" + uuid.uuid4().hex[:6]

    image_filename = f"images/{entry_id}.jpg"
    # Convert before taking the lock to keep the locked section short.
    image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    with scheduler.lock:
        # cv2.imwrite reports failure through its boolean return value.
        image_saved = cv2.imwrite(str(RESULTS_DIR / image_filename), image_bgr)
        if not image_saved:
            print(f"Warning: could not write {image_filename}; "
                  "saving the record without an image reference.")

        record = {
            "id": entry_id,
            # Never point at a file that was not actually written.
            "image": image_filename if image_saved else None,
            "braille_unicode": braille_text,
            "english": english_text,
            "cells": total_cells,
            "lines": num_lines,
            "avg_confidence": round(avg_conf, 4),
            "timestamp": now.isoformat(),
        }

        # Append to JSONL
        with open(RESULTS_DIR / "results.jsonl", "a", encoding="utf-8") as f:
            f.write(json.dumps(record) + "\n")
81
+
82
  # --- CLAHE Preprocessing ---
83
 
84
  def preprocess_clahe(image_path: str) -> str:
 
171
 
172
  # Stats
173
  total_cells = sum(len(line) for line in lines)
174
+ avg_conf = float(np.mean([cell["confidence"] for line in lines for cell in line]))
175
 
176
  # Stage 2: Interpret each line with ByT5 on GPU
177
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
197
  braille_text = "\n".join(braille_lines)
198
  english_text = "\n".join(english_lines)
199
 
200
+ # Save to dataset
201
+ save_result(image, braille_text, english_text, total_cells, len(lines), avg_conf)
202
+
203
  output = f"{english_text}\n\n"
204
  output += f"--- Details ---\n"
205
  output += f"Cells detected: {total_cells}\n"