taellinglin committed on
Commit
f9b53d5
·
verified ·
1 Parent(s): 216ee30

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -27
app.py CHANGED
@@ -14,9 +14,11 @@ import matplotlib.pyplot as plt
14
  import math
15
  from datetime import datetime
16
  import re
17
- from difflib import SequenceMatcher
18
  from termcolor import colored
19
- from ctcdecode import CTCBeamDecoder
 
 
 
20
  # --------- Globals --------- #
21
  CHARS = string.ascii_letters + string.digits + string.punctuation
22
  CHAR2IDX = {c: i + 1 for i, c in enumerate(CHARS)} # Start from 1
@@ -28,7 +30,14 @@ IMAGE_WIDTH = 128
28
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
  font_path = None
30
  ocr_model = None
 
 
31
 
 
 
 
 
 
32
 
33
  # --------- Dataset --------- #
34
  class OCRDataset(Dataset):
@@ -273,14 +282,7 @@ def color_char(c, conf):
273
 
274
 
275
 
276
- # Build decoder once (outside predict_text)
277
- decoder = CTCBeamDecoder(
278
- labels=[IDX2CHAR[i] for i in range(len(IDX2CHAR))],
279
- blank_id=BLANK_IDX,
280
- beam_width=10, # try 10–20 for best results
281
- num_processes=4,
282
- log_probs_input=True
283
- )
284
 
285
  def predict_text(image: Image.Image, ground_truth: str = None, debug: bool = False):
286
  if ocr_model is None:
@@ -291,37 +293,32 @@ def predict_text(image: Image.Image, ground_truth: str = None, debug: bool = Fal
291
  transforms.ToTensor(),
292
  transforms.Normalize((0.5,), (0.5,))
293
  ])
294
- img_tensor = transform(processed).unsqueeze(0).to(device)
295
 
296
  ocr_model.eval()
297
  with torch.no_grad():
298
- output = ocr_model(img_tensor) # (B, T, C)
299
- log_probs = output.log_softmax(2) # (B, T, C)
300
- output_lengths = torch.full((1,), log_probs.size(1), dtype=torch.int32)
301
 
302
- # Beam decode
303
- beam_results, beam_scores, timesteps, out_lens = decoder.decode(log_probs, output_lengths)
304
 
305
- pred_indices = beam_results[0][0][:out_lens[0][0]].cpu().numpy().tolist()
306
- pred_chars = [IDX2CHAR.get(idx, "?") for idx in pred_indices]
 
 
307
 
308
- # Confidence estimation: mean probability of chosen path
309
- avg_conf = torch.exp(beam_scores[0][0] / out_lens[0][0]).item()
310
-
311
- # Colorize with fake uniform confidence (we don’t get per-char conf from decoder)
312
- colorized_chars = [color_char(c, avg_conf) for c in pred_chars]
313
  pretty_output = ''.join(colorized_chars)
314
 
315
- pred_text = ''.join(pred_chars)
316
-
317
  sim_score = ""
318
  if ground_truth:
319
  similarity = SequenceMatcher(None, ground_truth, pred_text).ratio()
320
  sim_score = f"<br><strong>Levenshtein Similarity:</strong> {similarity:.2%}"
321
 
322
  if debug:
323
- print("Decoded Beam:", pred_text)
324
- print("Beam Confidence:", avg_conf)
325
  if ground_truth:
326
  print("Ground Truth:", ground_truth)
327
 
 
14
  import math
15
  from datetime import datetime
16
  import re
 
17
  from termcolor import colored
18
+ from pyctcdecode import BeamSearchDecoderCTC, Alphabet
19
+ from difflib import SequenceMatcher
20
+
21
+
22
  # --------- Globals --------- #
23
  CHARS = string.ascii_letters + string.digits + string.punctuation
24
  CHAR2IDX = {c: i + 1 for i, c in enumerate(CHARS)} # Start from 1
 
30
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
  font_path = None
32
  ocr_model = None
33
+ # Create vocabulary list (ensure order matches your model’s output indices!)
34
+ labels = [IDX2CHAR.get(i, "") for i in range(len(IDX2CHAR))]
35
 
36
+ # Wrap in Alphabet
37
+ alphabet = Alphabet.build_alphabet(labels)
38
+
39
+ # Now initialize decoder correctly
40
+ decoder = BeamSearchDecoderCTC(alphabet)
41
 
42
  # --------- Dataset --------- #
43
  class OCRDataset(Dataset):
 
282
 
283
 
284
 
285
+
 
 
 
 
 
 
 
286
 
287
  def predict_text(image: Image.Image, ground_truth: str = None, debug: bool = False):
288
  if ocr_model is None:
 
293
  transforms.ToTensor(),
294
  transforms.Normalize((0.5,), (0.5,))
295
  ])
296
+ img_tensor = transform(processed).unsqueeze(0).to(device) # (1, C, H, W)
297
 
298
  ocr_model.eval()
299
  with torch.no_grad():
300
+ output = ocr_model(img_tensor) # (1, T, C)
301
+ log_probs = output.log_softmax(2)[0] # (T, C)
 
302
 
303
+ pred_text = decoder.decode(log_probs.cpu().numpy()) # Best beam path
 
304
 
305
+ # Confidence: mean max prob per timestep
306
+ probs = log_probs.exp()
307
+ max_probs = probs.max(dim=1)[0]
308
+ avg_conf = max_probs.mean().item()
309
 
310
+ # Color each character (uniform confidence for now)
311
+ colorized_chars = [color_char(c, avg_conf) for c in pred_text]
 
 
 
312
  pretty_output = ''.join(colorized_chars)
313
 
 
 
314
  sim_score = ""
315
  if ground_truth:
316
  similarity = SequenceMatcher(None, ground_truth, pred_text).ratio()
317
  sim_score = f"<br><strong>Levenshtein Similarity:</strong> {similarity:.2%}"
318
 
319
  if debug:
320
+ print("Decoded Text:", pred_text)
321
+ print("Average Confidence:", avg_conf)
322
  if ground_truth:
323
  print("Ground Truth:", ground_truth)
324