# app/utils.py
import string

import torch

# ----------------------------
# Vocabulary
# ----------------------------
DIGITS = string.digits
LOWER = string.ascii_lowercase
UPPER = string.ascii_uppercase
BLANK_CHAR = "-"

# Class index i of the model output corresponds to CHARS[i]; the CTC blank
# symbol is the last entry.
CHARS = DIGITS + LOWER + UPPER + BLANK_CHAR

char2idx = {c: i for i, c in enumerate(CHARS)}
idx2char = {i: c for c, i in char2idx.items()}

# Per-frame class probabilities below this value are skipped during decoding.
_MIN_PROB = 1e-4


# ----------------------------
# CTC Beam Search Decoder
# ----------------------------
def ctc_beam_search(logits, beam_width=5):
    """Decode the first batch element of ``logits`` with CTC prefix beam search.

    Consecutive repeats of a character are collapsed unless a blank separates
    them, and the probabilities of all surviving alignments that produce the
    same text are summed.

    Args:
        logits: Tensor of shape (B, T, C) holding unnormalised class scores.
            Only batch element 0 is decoded. Class index ``c`` is mapped to a
            character through ``idx2char``.
        beam_width: Number of prefixes kept after each time step (>= 1).

    Returns:
        A list of ``(text, score)`` tuples sorted by descending score, at most
        ``beam_width`` long. ``score`` is the accumulated probability of
        ``text`` over the alignments that survived pruning. For ``T == 0`` the
        result is ``[("", 1.0)]``.

    Raises:
        ValueError: If ``beam_width`` is smaller than 1.
        KeyError: If a class index with non-negligible probability has no
            entry in ``idx2char``.
    """
    if beam_width < 1:
        raise ValueError("beam_width must be >= 1")

    # Convert once to nested Python lists instead of calling .item() for every
    # (frame, beam, class) triple.
    frames = logits.softmax(2)[0].detach().tolist()  # T rows of C floats

    # Each beam is (prefix, [p_blank, p_non_blank]): the probability of the
    # prefix over alignments ending in a blank / in a non-blank frame. The two
    # must be tracked separately to decide whether a repeated character
    # extends the prefix or merges into its last character.
    beams = [("", [1.0, 0.0])]

    for frame in frames:
        candidates = [
            (idx2char[c], p) for c, p in enumerate(frame) if p >= _MIN_PROB
        ]
        next_beams = {}

        for prefix, (p_blank, p_non_blank) in beams:
            p_total = p_blank + p_non_blank
            for char, p in candidates:
                if char == BLANK_CHAR:
                    # A blank keeps the prefix and marks it as blank-terminated.
                    next_beams.setdefault(prefix, [0.0, 0.0])[0] += p_total * p
                elif prefix and prefix[-1] == char:
                    # Same character as the prefix tail: it only starts a new
                    # character when a blank came in between ...
                    if p_blank > 0.0:
                        extended = next_beams.setdefault(
                            prefix + char, [0.0, 0.0]
                        )
                        extended[1] += p_blank * p
                    # ... otherwise it is a repeat that collapses into the tail.
                    if p_non_blank > 0.0:
                        merged = next_beams.setdefault(prefix, [0.0, 0.0])
                        merged[1] += p_non_blank * p
                else:
                    extended = next_beams.setdefault(prefix + char, [0.0, 0.0])
                    extended[1] += p_total * p

        beams = sorted(
            next_beams.items(),
            key=lambda item: item[1][0] + item[1][1],
            reverse=True,
        )[:beam_width]

    return [
        (prefix, p_blank + p_non_blank)
        for prefix, (p_blank, p_non_blank) in beams
    ]


# ----------------------------
# Decode + Confidence
# ----------------------------
def decode_with_confidence(logits):
    """Return the best beam-search text and a heuristic confidence for it.

    Args:
        logits: Tensor of shape (B, T, C); only batch element 0 is decoded.

    Returns:
        ``(text, confidence)`` where ``confidence`` is the best beam score
        multiplied by 10, capped at 1.0 and rounded to 3 decimals. If the
        decoder yields no beam at all, ``("", 0.0)`` is returned.
    """
    beams = ctc_beam_search(logits, beam_width=5)
    if not beams:
        # Every class was pruned in some frame, so no hypothesis survived.
        return "", 0.0

    best_text, best_score = beams[0]

    # normalize confidence (simple & stable): heuristic scale-and-clip, not a
    # calibrated probability.
    confidence = min(1.0, best_score * 10)

    return best_text, round(confidence, 3)