File size: 914 Bytes
ada63c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from typing import List
from src.config import cfg

# Index -> symbol table. Slot 0 holds the CTC blank marker; the characters
# of cfg.chars follow in their original order starting at index 1.
itos = ["<blank>", *cfg.chars]
# Symbol -> index table, shifted by one so no character collides with the
# blank at index 0. If cfg.chars repeats a character, the last one wins.
stoi = {ch: idx for idx, ch in enumerate(cfg.chars, start=1)}

def encode_text(text: str) -> List[int]:
    """
    Convert ``text`` into a list of vocabulary indices using ``stoi``.

    Every index is >= 1, since index 0 is the blank entry of ``itos`` and is
    never assigned to a character. A character missing from ``stoi`` raises
    ``KeyError``.
    """
    lookup = stoi.__getitem__
    return list(map(lookup, text))

def decode_indices(indices: List[int]) -> str:
    """
    Map a sequence of vocabulary indices back to text.

    Index 0 (the "<blank>" entry of ``itos``) is dropped, and every other
    index is looked up in ``itos``. Repeated indices are NOT collapsed here;
    each non-blank index yields exactly one character.

    Raises:
        IndexError: if an index is negative or >= ``len(itos)``. Negative
            indices are rejected explicitly. Plain list indexing would
            otherwise wrap them silently to characters at the end of the
            vocabulary, so a ``-1`` padding value would decode as the last
            character instead of failing.
    """
    chars = []
    for i in indices:
        if i == 0:
            # Blank carries no character.
            continue
        if i < 0:
            raise IndexError(f"negative vocabulary index {i} in decode_indices")
        chars.append(itos[i])
    return "".join(chars)

def ctc_greedy_decode(logits) -> List[str]:
    """
    Greedy (best-path) CTC decode for a batch.

    At each time step the highest-scoring class is chosen. Consecutive
    repeats are then collapsed and blanks (index 0) are removed. For
    example, the per-step path ``a a <blank> a b b`` decodes to ``"aab"``.

    Args:
        logits: torch.Tensor of shape [T, B, V]. Raw scores, softmax or
            log_softmax all work, because argmax is unchanged by those
            monotonic transforms.

    Returns:
        List of B decoded strings, in batch order.
    """
    # Take the argmax over V, then move the whole [T, B] prediction into
    # Python in a single tolist() call rather than one call per batch item.
    # Transposing first yields one list of T class ids per sample.
    paths = logits.argmax(dim=-1).transpose(0, 1).tolist()
    decoded = []
    for path in paths:
        prev = -1
        chars = []
        for t in path:
            # Emit a class only when it is non-blank and differs from the
            # previous step. A blank between two equal classes resets
            # ``prev``, so that repeat is kept as two characters.
            if t != 0 and t != prev:
                chars.append(itos[t])
            prev = t
        decoded.append("".join(chars))
    return decoded

def vocab_size() -> int:
    """Return the number of output classes, i.e. ``len(itos)``, blank included."""
    n_classes = len(itos)
    return n_classes