Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -14,13 +14,15 @@ import matplotlib.pyplot as plt
|
|
| 14 |
import math
|
| 15 |
from datetime import datetime
|
| 16 |
import re
|
| 17 |
-
|
|
|
|
|
|
|
| 18 |
# --------- Globals --------- #
|
| 19 |
CHARS = string.ascii_letters + string.digits + string.punctuation
|
| 20 |
-
CHAR2IDX = {c: i + 1 for i, c in enumerate(CHARS)}
|
| 21 |
-
CHAR2IDX["<BLANK>"] = 0
|
| 22 |
-
BLANK_IDX = 0
|
| 23 |
IDX2CHAR = {v: k for k, v in CHAR2IDX.items()}
|
|
|
|
| 24 |
IMAGE_HEIGHT = 32
|
| 25 |
IMAGE_WIDTH = 128
|
| 26 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
@@ -87,7 +89,10 @@ class OCRModel(nn.Module):
|
|
| 87 |
x, _ = self.rnn(x)
|
| 88 |
x = self.fc(x)
|
| 89 |
return x
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
def sanitize_filename(name):
|
| 93 |
return re.sub(r'[^a-zA-Z0-9_-]', '_', name)
|
|
@@ -239,31 +244,88 @@ def preprocess_image(image: Image.Image):
|
|
| 239 |
return to_pil_image(padded)
|
| 240 |
|
| 241 |
|
| 242 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
if ocr_model is None:
|
| 244 |
return "Please load or train a model first."
|
| 245 |
|
| 246 |
processed = preprocess_image(image)
|
| 247 |
-
|
| 248 |
transform = transforms.Compose([
|
| 249 |
transforms.ToTensor(),
|
| 250 |
transforms.Normalize((0.5,), (0.5,))
|
| 251 |
])
|
| 252 |
-
img_tensor = transform(processed).unsqueeze(0).to(device)
|
| 253 |
|
|
|
|
| 254 |
with torch.no_grad():
|
| 255 |
output = ocr_model(img_tensor) # (B, T, C)
|
| 256 |
-
log_probs = output.log_softmax(2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
|
| 258 |
-
|
|
|
|
|
|
|
| 259 |
|
| 260 |
-
|
| 261 |
-
max_probs = probs.max(2)[0].squeeze(1) # (T,)
|
| 262 |
-
avg_conf = max_probs.mean().item()
|
| 263 |
|
| 264 |
-
|
|
|
|
|
|
|
|
|
|
| 265 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
|
|
|
|
| 267 |
|
| 268 |
|
| 269 |
# New helper function: generate label images grid
|
|
@@ -359,7 +421,7 @@ with gr.Blocks(css=custom_css) as demo:
|
|
| 359 |
|
| 360 |
image_input = gr.Image(type="pil", label="Upload word strip")
|
| 361 |
predict_btn = gr.Button("Predict")
|
| 362 |
-
output_text = gr.
|
| 363 |
model_status = gr.Textbox(label="Model Load Status")
|
| 364 |
|
| 365 |
# Refresh dropdown choices
|
|
|
|
| 14 |
import math
|
| 15 |
from datetime import datetime
|
| 16 |
import re
|
| 17 |
+
from difflib import SequenceMatcher
|
| 18 |
+
from termcolor import colored
|
| 19 |
+
from ctcdecode import CTCBeamDecoder
|
| 20 |
# --------- Globals --------- #
|
| 21 |
CHARS = string.ascii_letters + string.digits + string.punctuation
|
| 22 |
+
CHAR2IDX = {c: i + 1 for i, c in enumerate(CHARS)} # Start from 1
|
| 23 |
+
CHAR2IDX["<BLANK>"] = 0 # CTC blank
|
|
|
|
| 24 |
IDX2CHAR = {v: k for k, v in CHAR2IDX.items()}
|
| 25 |
+
BLANK_IDX = 0
|
| 26 |
IMAGE_HEIGHT = 32
|
| 27 |
IMAGE_WIDTH = 128
|
| 28 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
| 89 |
x, _ = self.rnn(x)
|
| 90 |
x = self.fc(x)
|
| 91 |
return x
|
| 92 |
+
def color_char(c, conf):
|
| 93 |
+
color_levels = ['\033[31m', '\033[33m', '\033[32m', '\033[36m', '\033[34m', '\033[35m', '\033[0m']
|
| 94 |
+
idx = min(int(conf * (len(color_levels) - 1)), len(color_levels) - 1)
|
| 95 |
+
return f"{color_levels[idx]}{c}\033[0m"
|
| 96 |
|
| 97 |
def sanitize_filename(name):
|
| 98 |
return re.sub(r'[^a-zA-Z0-9_-]', '_', name)
|
|
|
|
| 244 |
return to_pil_image(padded)
|
| 245 |
|
| 246 |
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
# ROYGBIV color ramp (low → high confidence)
|
| 250 |
+
CONFIDENCE_COLORS = [
|
| 251 |
+
"#FF0000", # Red
|
| 252 |
+
"#FF7F00", # Orange
|
| 253 |
+
"#FFFF00", # Yellow
|
| 254 |
+
"#00FF00", # Green
|
| 255 |
+
"#00BFFF", # Sky Blue
|
| 256 |
+
"#0000FF", # Blue
|
| 257 |
+
"#8B00FF", # Violet
|
| 258 |
+
]
|
| 259 |
+
|
| 260 |
+
def confidence_to_color(conf):
|
| 261 |
+
"""
|
| 262 |
+
Map confidence (0.0–1.0) to a ROYGBIV-style hex color.
|
| 263 |
+
"""
|
| 264 |
+
index = min(int(conf * (len(CONFIDENCE_COLORS) - 1)), len(CONFIDENCE_COLORS) - 1)
|
| 265 |
+
return CONFIDENCE_COLORS[index]
|
| 266 |
+
|
| 267 |
+
def color_char(c, conf):
|
| 268 |
+
"""
|
| 269 |
+
Wrap character `c` in a span tag with color mapped from `conf`.
|
| 270 |
+
"""
|
| 271 |
+
color = confidence_to_color(conf)
|
| 272 |
+
return f'<span style="color:{color}; font-size:32pt; font-weight:bold;">{c}</span>'
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
# Build decoder once (outside predict_text)
|
| 277 |
+
decoder = CTCBeamDecoder(
|
| 278 |
+
labels=[IDX2CHAR[i] for i in range(len(IDX2CHAR))],
|
| 279 |
+
blank_id=BLANK_IDX,
|
| 280 |
+
beam_width=10, # try 10–20 for best results
|
| 281 |
+
num_processes=4,
|
| 282 |
+
log_probs_input=True
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
def predict_text(image: Image.Image, ground_truth: str = None, debug: bool = False):
|
| 286 |
if ocr_model is None:
|
| 287 |
return "Please load or train a model first."
|
| 288 |
|
| 289 |
processed = preprocess_image(image)
|
|
|
|
| 290 |
transform = transforms.Compose([
|
| 291 |
transforms.ToTensor(),
|
| 292 |
transforms.Normalize((0.5,), (0.5,))
|
| 293 |
])
|
| 294 |
+
img_tensor = transform(processed).unsqueeze(0).to(device)
|
| 295 |
|
| 296 |
+
ocr_model.eval()
|
| 297 |
with torch.no_grad():
|
| 298 |
output = ocr_model(img_tensor) # (B, T, C)
|
| 299 |
+
log_probs = output.log_softmax(2) # (B, T, C)
|
| 300 |
+
output_lengths = torch.full((1,), log_probs.size(1), dtype=torch.int32)
|
| 301 |
+
|
| 302 |
+
# Beam decode
|
| 303 |
+
beam_results, beam_scores, timesteps, out_lens = decoder.decode(log_probs, output_lengths)
|
| 304 |
+
|
| 305 |
+
pred_indices = beam_results[0][0][:out_lens[0][0]].cpu().numpy().tolist()
|
| 306 |
+
pred_chars = [IDX2CHAR.get(idx, "?") for idx in pred_indices]
|
| 307 |
+
|
| 308 |
+
# Confidence estimation: mean probability of chosen path
|
| 309 |
+
avg_conf = torch.exp(beam_scores[0][0] / out_lens[0][0]).item()
|
| 310 |
|
| 311 |
+
# Colorize with fake uniform confidence (we don’t get per-char conf from decoder)
|
| 312 |
+
colorized_chars = [color_char(c, avg_conf) for c in pred_chars]
|
| 313 |
+
pretty_output = ''.join(colorized_chars)
|
| 314 |
|
| 315 |
+
pred_text = ''.join(pred_chars)
|
|
|
|
|
|
|
| 316 |
|
| 317 |
+
sim_score = ""
|
| 318 |
+
if ground_truth:
|
| 319 |
+
similarity = SequenceMatcher(None, ground_truth, pred_text).ratio()
|
| 320 |
+
sim_score = f"<br><strong>Levenshtein Similarity:</strong> {similarity:.2%}"
|
| 321 |
|
| 322 |
+
if debug:
|
| 323 |
+
print("Decoded Beam:", pred_text)
|
| 324 |
+
print("Beam Confidence:", avg_conf)
|
| 325 |
+
if ground_truth:
|
| 326 |
+
print("Ground Truth:", ground_truth)
|
| 327 |
|
| 328 |
+
return f"<strong>Prediction:</strong> {pretty_output}<br><strong>Confidence:</strong> {avg_conf:.2%}{sim_score}"
|
| 329 |
|
| 330 |
|
| 331 |
# New helper function: generate label images grid
|
|
|
|
| 421 |
|
| 422 |
image_input = gr.Image(type="pil", label="Upload word strip")
|
| 423 |
predict_btn = gr.Button("Predict")
|
| 424 |
+
output_text = gr.HTML(label="Recognized Text")
|
| 425 |
model_status = gr.Textbox(label="Model Load Status")
|
| 426 |
|
| 427 |
# Refresh dropdown choices
|