Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -14,9 +14,11 @@ import matplotlib.pyplot as plt
|
|
| 14 |
import math
|
| 15 |
from datetime import datetime
|
| 16 |
import re
|
| 17 |
-
from difflib import SequenceMatcher
|
| 18 |
from termcolor import colored
|
| 19 |
-
from
|
|
|
|
|
|
|
|
|
|
| 20 |
# --------- Globals --------- #
|
| 21 |
CHARS = string.ascii_letters + string.digits + string.punctuation
|
| 22 |
CHAR2IDX = {c: i + 1 for i, c in enumerate(CHARS)} # Start from 1
|
|
@@ -28,7 +30,14 @@ IMAGE_WIDTH = 128
|
|
| 28 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 29 |
font_path = None
|
| 30 |
ocr_model = None
|
|
|
|
|
|
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
# --------- Dataset --------- #
|
| 34 |
class OCRDataset(Dataset):
|
|
@@ -273,14 +282,7 @@ def color_char(c, conf):
|
|
| 273 |
|
| 274 |
|
| 275 |
|
| 276 |
-
|
| 277 |
-
decoder = CTCBeamDecoder(
|
| 278 |
-
labels=[IDX2CHAR[i] for i in range(len(IDX2CHAR))],
|
| 279 |
-
blank_id=BLANK_IDX,
|
| 280 |
-
beam_width=10, # try 10–20 for best results
|
| 281 |
-
num_processes=4,
|
| 282 |
-
log_probs_input=True
|
| 283 |
-
)
|
| 284 |
|
| 285 |
def predict_text(image: Image.Image, ground_truth: str = None, debug: bool = False):
|
| 286 |
if ocr_model is None:
|
|
@@ -291,37 +293,32 @@ def predict_text(image: Image.Image, ground_truth: str = None, debug: bool = Fal
|
|
| 291 |
transforms.ToTensor(),
|
| 292 |
transforms.Normalize((0.5,), (0.5,))
|
| 293 |
])
|
| 294 |
-
img_tensor = transform(processed).unsqueeze(0).to(device)
|
| 295 |
|
| 296 |
ocr_model.eval()
|
| 297 |
with torch.no_grad():
|
| 298 |
-
output = ocr_model(img_tensor)
|
| 299 |
-
log_probs = output.log_softmax(2)
|
| 300 |
-
output_lengths = torch.full((1,), log_probs.size(1), dtype=torch.int32)
|
| 301 |
|
| 302 |
-
|
| 303 |
-
beam_results, beam_scores, timesteps, out_lens = decoder.decode(log_probs, output_lengths)
|
| 304 |
|
| 305 |
-
|
| 306 |
-
|
|
|
|
|
|
|
| 307 |
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
# Colorize with fake uniform confidence (we don’t get per-char conf from decoder)
|
| 312 |
-
colorized_chars = [color_char(c, avg_conf) for c in pred_chars]
|
| 313 |
pretty_output = ''.join(colorized_chars)
|
| 314 |
|
| 315 |
-
pred_text = ''.join(pred_chars)
|
| 316 |
-
|
| 317 |
sim_score = ""
|
| 318 |
if ground_truth:
|
| 319 |
similarity = SequenceMatcher(None, ground_truth, pred_text).ratio()
|
| 320 |
sim_score = f"<br><strong>Levenshtein Similarity:</strong> {similarity:.2%}"
|
| 321 |
|
| 322 |
if debug:
|
| 323 |
-
print("Decoded
|
| 324 |
-
print("
|
| 325 |
if ground_truth:
|
| 326 |
print("Ground Truth:", ground_truth)
|
| 327 |
|
|
|
|
| 14 |
import math
|
| 15 |
from datetime import datetime
|
| 16 |
import re
|
|
|
|
| 17 |
from termcolor import colored
|
| 18 |
+
from pyctcdecode import BeamSearchDecoderCTC, Alphabet
|
| 19 |
+
from difflib import SequenceMatcher
|
| 20 |
+
|
| 21 |
+
|
| 22 |
# --------- Globals --------- #
# Character inventory the OCR model can emit: letters, digits, punctuation.
CHARS = string.ascii_letters + string.digits + string.punctuation
# Map every character to a 1-based class index; index 0 stays free for the CTC blank.
CHAR2IDX = {c: i for i, c in enumerate(CHARS, start=1)}
|
|
|
| 30 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
font_path = None   # configured elsewhere at startup; None until then
ocr_model = None   # loaded lazily; predict_text() guards against None

# Vocabulary for the CTC beam-search decoder. The label list must line up
# one-to-one with the model's output class indices, so size it from the
# largest index actually present in IDX2CHAR rather than len(IDX2CHAR):
# IDX2CHAR is built 1-based (see CHAR2IDX), so if the blank at index 0 is
# not stored in the dict, len(IDX2CHAR) undercounts by one and the last
# character would be silently dropped from the vocabulary.
# NOTE(review): IDX2CHAR is defined in lines not visible here — confirm it
# is the inverse of CHAR2IDX (1-based) with the blank at index 0.
# Indices absent from IDX2CHAR (e.g. the CTC blank) map to "".
num_classes = (max(IDX2CHAR) + 1) if IDX2CHAR else 0
labels = [IDX2CHAR.get(i, "") for i in range(num_classes)]

# Wrap the label list in a pyctcdecode Alphabet, then build the decoder.
alphabet = Alphabet.build_alphabet(labels)
decoder = BeamSearchDecoderCTC(alphabet)
| 41 |
|
| 42 |
# --------- Dataset --------- #
|
| 43 |
class OCRDataset(Dataset):
|
|
|
|
| 282 |
|
| 283 |
|
| 284 |
|
| 285 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
|
| 287 |
def predict_text(image: Image.Image, ground_truth: str = None, debug: bool = False):
|
| 288 |
if ocr_model is None:
|
|
|
|
| 293 |
transforms.ToTensor(),
|
| 294 |
transforms.Normalize((0.5,), (0.5,))
|
| 295 |
])
|
| 296 |
+
img_tensor = transform(processed).unsqueeze(0).to(device) # (1, C, H, W)
|
| 297 |
|
| 298 |
ocr_model.eval()
|
| 299 |
with torch.no_grad():
|
| 300 |
+
output = ocr_model(img_tensor) # (1, T, C)
|
| 301 |
+
log_probs = output.log_softmax(2)[0] # (T, C)
|
|
|
|
| 302 |
|
| 303 |
+
pred_text = decoder.decode(log_probs.cpu().numpy()) # Best beam path
|
|
|
|
| 304 |
|
| 305 |
+
# Confidence: mean max prob per timestep
|
| 306 |
+
probs = log_probs.exp()
|
| 307 |
+
max_probs = probs.max(dim=1)[0]
|
| 308 |
+
avg_conf = max_probs.mean().item()
|
| 309 |
|
| 310 |
+
# Color each character (uniform confidence for now)
|
| 311 |
+
colorized_chars = [color_char(c, avg_conf) for c in pred_text]
|
|
|
|
|
|
|
|
|
|
| 312 |
pretty_output = ''.join(colorized_chars)
|
| 313 |
|
|
|
|
|
|
|
| 314 |
sim_score = ""
|
| 315 |
if ground_truth:
|
| 316 |
similarity = SequenceMatcher(None, ground_truth, pred_text).ratio()
|
| 317 |
sim_score = f"<br><strong>Levenshtein Similarity:</strong> {similarity:.2%}"
|
| 318 |
|
| 319 |
if debug:
|
| 320 |
+
print("Decoded Text:", pred_text)
|
| 321 |
+
print("Average Confidence:", avg_conf)
|
| 322 |
if ground_truth:
|
| 323 |
print("Ground Truth:", ground_truth)
|
| 324 |
|