taellinglin committed on
Commit
644f8c6
·
verified ·
1 Parent(s): c4d9a6d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -15
app.py CHANGED
@@ -14,13 +14,15 @@ import matplotlib.pyplot as plt
14
  import math
15
  from datetime import datetime
16
  import re
17
-
 
 
18
  # --------- Globals --------- #
19
  CHARS = string.ascii_letters + string.digits + string.punctuation
20
- CHAR2IDX = {c: i + 1 for i, c in enumerate(CHARS)}
21
- CHAR2IDX["<BLANK>"] = 0
22
- BLANK_IDX = 0
23
  IDX2CHAR = {v: k for k, v in CHAR2IDX.items()}
 
24
  IMAGE_HEIGHT = 32
25
  IMAGE_WIDTH = 128
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -87,7 +89,10 @@ class OCRModel(nn.Module):
87
  x, _ = self.rnn(x)
88
  x = self.fc(x)
89
  return x
90
-
 
 
 
91
 
92
  def sanitize_filename(name):
93
  return re.sub(r'[^a-zA-Z0-9_-]', '_', name)
@@ -239,31 +244,88 @@ def preprocess_image(image: Image.Image):
239
  return to_pil_image(padded)
240
 
241
 
242
- def predict_text(image: Image.Image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  if ocr_model is None:
244
  return "Please load or train a model first."
245
 
246
  processed = preprocess_image(image)
247
-
248
  transform = transforms.Compose([
249
  transforms.ToTensor(),
250
  transforms.Normalize((0.5,), (0.5,))
251
  ])
252
- img_tensor = transform(processed).unsqueeze(0).to(device) # (1, C, H, W)
253
 
 
254
  with torch.no_grad():
255
  output = ocr_model(img_tensor) # (B, T, C)
256
- log_probs = output.log_softmax(2).permute(1, 0, 2) # (T, B, C)
 
 
 
 
 
 
 
 
 
 
257
 
258
- pred = greedy_decode(log_probs) # should be a string now
 
 
259
 
260
- probs = log_probs.exp()
261
- max_probs = probs.max(2)[0].squeeze(1) # (T,)
262
- avg_conf = max_probs.mean().item()
263
 
264
- return f"Prediction: {pred}\nConfidence: {avg_conf:.2%}"
 
 
 
265
 
 
 
 
 
 
266
 
 
267
 
268
 
269
  # New helper function: generate label images grid
@@ -359,7 +421,7 @@ with gr.Blocks(css=custom_css) as demo:
359
 
360
  image_input = gr.Image(type="pil", label="Upload word strip")
361
  predict_btn = gr.Button("Predict")
362
- output_text = gr.Textbox(label="Recognized Text")
363
  model_status = gr.Textbox(label="Model Load Status")
364
 
365
  # Refresh dropdown choices
 
14
  import math
15
  from datetime import datetime
16
  import re
17
+ from difflib import SequenceMatcher
18
+ from termcolor import colored
19
+ from ctcdecode import CTCBeamDecoder
20
# --------- Globals --------- #
# Character inventory the model can emit: ASCII letters, digits, punctuation.
CHARS = string.ascii_letters + string.digits + string.punctuation
CHAR2IDX = {c: i + 1 for i, c in enumerate(CHARS)}  # Start from 1
CHAR2IDX["<BLANK>"] = 0  # CTC blank
# Reverse lookup: class index -> character (index 0 maps to "<BLANK>").
IDX2CHAR = {v: k for k, v in CHAR2IDX.items()}
BLANK_IDX = 0  # index reserved for the CTC blank symbol
# Fixed input geometry the network is trained on.
IMAGE_HEIGHT = 32
IMAGE_WIDTH = 128
# Prefer GPU when available; everything downstream moves tensors to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
89
  x, _ = self.rnn(x)
90
  x = self.fc(x)
91
  return x
92
def color_char(c, conf):
    """Wrap character *c* in an ANSI color escape selected by *conf* (0.0-1.0).

    NOTE(review): a later definition of ``color_char`` (HTML spans) shadows
    this ANSI version at import time — presumably leftover terminal-debug
    code; confirm before removing.
    """
    palette = ('\033[31m', '\033[33m', '\033[32m', '\033[36m', '\033[34m', '\033[35m', '\033[0m')
    top = len(palette) - 1
    slot = int(conf * top)
    if slot > top:
        slot = top
    return f"{palette[slot]}{c}\033[0m"
96
 
97
  def sanitize_filename(name):
98
  return re.sub(r'[^a-zA-Z0-9_-]', '_', name)
 
244
  return to_pil_image(padded)
245
 
246
 
247
+
248
+
249
# ROYGBIV color ramp (low → high confidence)
CONFIDENCE_COLORS = [
    "#FF0000",  # Red
    "#FF7F00",  # Orange
    "#FFFF00",  # Yellow
    "#00FF00",  # Green
    "#00BFFF",  # Sky Blue
    "#0000FF",  # Blue
    "#8B00FF",  # Violet
]


def confidence_to_color(conf):
    """Return the ROYGBIV hex color for a confidence value in [0.0, 1.0]."""
    top = len(CONFIDENCE_COLORS) - 1
    slot = int(conf * top)
    return CONFIDENCE_COLORS[slot if slot < top else top]


def color_char(c, conf):
    """Render character *c* as a bold 32pt HTML span colored by *conf*."""
    return (
        f'<span style="color:{confidence_to_color(conf)}; '
        f'font-size:32pt; font-weight:bold;">{c}</span>'
    )
273
+
274
+
275
+
276
# Build decoder once (outside predict_text)
# NOTE(review): labels[0] resolves to the multi-character string "<BLANK>"
# from IDX2CHAR; CTCBeamDecoder conventionally expects one character per
# label, with the blank handled via blank_id — confirm this placeholder is
# never emitted into decoded output.
decoder = CTCBeamDecoder(
    labels=[IDX2CHAR[i] for i in range(len(IDX2CHAR))],
    blank_id=BLANK_IDX,
    beam_width=10,  # try 10–20 for best results
    num_processes=4,
    log_probs_input=True  # we feed log_softmax output, not raw probs
)
284
+
285
def predict_text(image: Image.Image, ground_truth: str = None, debug: bool = False):
    """Run OCR on *image* and return an HTML report of the prediction.

    Args:
        image: Input PIL image (a word strip) to recognize.
        ground_truth: Optional reference text; when given, a similarity
            score against the prediction is appended to the output.
        debug: When True, also print the decoded text and confidence
            to stdout.

    Returns:
        HTML markup containing the colorized prediction, a confidence
        estimate, and (optionally) the similarity score — or a plain
        error message when no model is loaded.
    """
    if ocr_model is None:
        return "Please load or train a model first."

    processed = preprocess_image(image)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    img_tensor = transform(processed).unsqueeze(0).to(device)  # (1, C, H, W)

    ocr_model.eval()
    with torch.no_grad():
        output = ocr_model(img_tensor)  # (B, T, C)
        log_probs = output.log_softmax(2)  # (B, T, C)
        output_lengths = torch.full((1,), log_probs.size(1), dtype=torch.int32)

        # Beam decode (decoder is constructed once at module level).
        beam_results, beam_scores, timesteps, out_lens = decoder.decode(log_probs, output_lengths)

    best_len = int(out_lens[0][0])
    if best_len == 0:
        # Guard: an empty best beam would otherwise divide by zero below.
        return "<strong>Prediction:</strong> <em>(empty)</em><br><strong>Confidence:</strong> 0.00%"

    pred_indices = beam_results[0][0][:best_len].cpu().numpy().tolist()
    pred_chars = [IDX2CHAR.get(idx, "?") for idx in pred_indices]

    # Confidence estimate: per-timestep geometric mean of the beam score.
    # NOTE(review): ctcdecode's beam_scores are typically *negative* log
    # likelihoods; if confidences look inverted, negate before exp — confirm.
    avg_conf = torch.exp(beam_scores[0][0] / best_len).item()

    # Colorize with a uniform confidence (the decoder gives no per-char conf).
    colorized_chars = [color_char(c, avg_conf) for c in pred_chars]
    pretty_output = ''.join(colorized_chars)
    pred_text = ''.join(pred_chars)

    sim_score = ""
    if ground_truth:
        # SequenceMatcher.ratio() is Ratcliff/Obershelp similarity, not true
        # Levenshtein, so the label says just "Similarity".
        similarity = SequenceMatcher(None, ground_truth, pred_text).ratio()
        sim_score = f"<br><strong>Similarity:</strong> {similarity:.2%}"

    if debug:
        print("Decoded Beam:", pred_text)
        print("Beam Confidence:", avg_conf)
        if ground_truth:
            print("Ground Truth:", ground_truth)

    return f"<strong>Prediction:</strong> {pretty_output}<br><strong>Confidence:</strong> {avg_conf:.2%}{sim_score}"
329
 
330
 
331
  # New helper function: generate label images grid
 
421
 
422
  image_input = gr.Image(type="pil", label="Upload word strip")
423
  predict_btn = gr.Button("Predict")
424
+ output_text = gr.HTML(label="Recognized Text")
425
  model_status = gr.Textbox(label="Model Load Status")
426
 
427
  # Refresh dropdown choices