# Author: MedhaCodes
# Commit message: Update app/utils.py
# Commit: 065e49f (verified)
import string
import torch
# ----------------------------
# Vocabulary
# ----------------------------
# app/utils.py
import string
# Character classes that make up the recognizer alphabet.
DIGITS = string.digits
LOWER = string.ascii_lowercase
UPPER = string.ascii_uppercase

# Symbol standing for the CTC blank; it is appended last, so it owns the
# highest class index.
BLANK_CHAR = "-"

# Full alphabet, in class-index order: digits, lowercase, uppercase, blank.
CHARS = "".join((DIGITS, LOWER, UPPER, BLANK_CHAR))

# Lookup tables between characters and their class indices.
char2idx = dict(zip(CHARS, range(len(CHARS))))
idx2char = dict(enumerate(CHARS))
# ----------------------------
# CTC Beam Search Decoder
# ----------------------------
def _add_mass(table, prefix, blank_mass, label_mass):
    """Add blank-ending / label-ending probability mass for `prefix` to `table`."""
    if blank_mass == 0.0 and label_mass == 0.0:
        # Do not create hypotheses that no alignment can actually reach.
        return
    prev_blank, prev_label = table.get(prefix, (0.0, 0.0))
    table[prefix] = (prev_blank + blank_mass, prev_label + label_mass)


def ctc_beam_search(logits, beam_width=5):
    """
    Decode the first sequence of a batch with CTC prefix beam search.

    logits: (B, T, C) tensor of unnormalised class scores. Only batch
        element 0 is decoded. Class indices are mapped to characters via
        `idx2char`; the class whose character equals `BLANK_CHAR` is the
        CTC blank.
    beam_width: number of prefixes kept after every frame.

    Returns a list of `(text, score)` tuples sorted by descending score, at
    most `beam_width` long. `score` is the summed probability of all
    surviving alignments that collapse to `text` under the CTC rules
    (repeated labels merge unless a blank separates them; blanks are
    dropped). With zero frames the result is `[("", 1.0)]`.
    """
    # Softmax over the class axis, then a single bulk conversion to Python
    # floats instead of one `.item()` call per (frame, class) cell.
    frames = logits.softmax(2)[0].tolist()  # T rows of C probabilities

    # Each prefix carries two masses: alignments ending in a blank and
    # alignments ending in the prefix's last label. The split is what lets a
    # repeated label either merge ("aa" -> "a") or extend ("a-a" -> "aa").
    beams = [("", (1.0, 0.0))]
    for frame in frames:
        new_beams = {}
        for prefix, (p_blank, p_label) in beams:
            total = p_blank + p_label
            last_char = prefix[-1:]  # "" for the empty prefix
            for c, p in enumerate(frame):
                if p < 1e-4:
                    # Prune negligible classes to keep the search cheap.
                    continue
                char = idx2char[c]
                if char == BLANK_CHAR:
                    # A blank keeps the text unchanged.
                    _add_mass(new_beams, prefix, total * p, 0.0)
                elif char == last_char:
                    # Same label again: merges with a label-ending alignment,
                    # but starts a new character after a blank.
                    _add_mass(new_beams, prefix, 0.0, p_label * p)
                    _add_mass(new_beams, prefix + char, 0.0, p_blank * p)
                else:
                    _add_mass(new_beams, prefix + char, 0.0, total * p)
        if not new_beams:
            # Nothing survived this frame (everything pruned or underflowed);
            # keep the previous hypotheses instead of returning an empty list.
            continue
        beams = sorted(
            new_beams.items(),
            key=lambda item: item[1][0] + item[1][1],
            reverse=True,
        )[:beam_width]
    return [(prefix, p_blank + p_label) for prefix, (p_blank, p_label) in beams]
# ----------------------------
# Decode + Confidence
# ----------------------------
def decode_with_confidence(logits):
    """
    Decode `logits` and report the best hypothesis with a confidence value.

    Runs `ctc_beam_search` with a beam width of 5 and takes the first
    (highest-ranked) beam. The confidence is a heuristic: the beam score
    multiplied by 10, capped at 1.0 and rounded to three decimals.

    Returns a `(text, confidence)` tuple.
    """
    top_text, top_score = ctc_beam_search(logits, beam_width=5)[0]

    # Heuristic rescaling of the raw beam score; anything at or above 1.0
    # after scaling is reported as 1.0.
    scaled = top_score * 10
    confidence = scaled if scaled < 1.0 else 1.0

    return top_text, round(confidence, 3)