File size: 2,170 Bytes
e141a7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import unicodedata
from typing import List, Union

class OCRTokenizer:
    """Character-level tokenizer that maps text to integer IDs and back.

    The vocabulary maps tokens (single characters plus the special tokens
    ``<PAD>``, ``<SOS>``, ``<EOS>`` and ``<UNK>``) to integer IDs.
    """

    def __init__(self, vocab: dict):
        """Build the tokenizer from a ``token -> id`` mapping.

        Args:
            vocab: Mapping from token string to integer ID. It must contain
                ``<PAD>``, ``<SOS>``, ``<EOS>`` and ``<UNK>``. A shallow copy
                is stored, so later changes to the caller's dict cannot make
                the forward (``vocab``) and reverse (``idx2char``) lookups
                disagree.

        Raises:
            ValueError: If a required special token is missing from ``vocab``.
        """
        required_tokens = ['<PAD>', '<SOS>', '<EOS>', '<UNK>']
        for token in required_tokens:
            if token not in vocab:
                raise ValueError(f"Vocab must contain {token}")

        # Copy so external mutation of ``vocab`` cannot desync the two maps.
        self.vocab = dict(vocab)
        # Reverse lookup; if several tokens share an ID, the last one wins.
        self.idx2char = {idx: char for char, idx in self.vocab.items()}
        self.unk_token_id = self.vocab['<UNK>']
        # <UNK> is deliberately not listed here: decode() renders it as '?'
        # rather than skipping it.
        self.special_tokens = {
            'pad': self.vocab['<PAD>'],
            'sos': self.vocab['<SOS>'],
            'eos': self.vocab['<EOS>'],
        }

    def encode(self, text: str, max_length: Union[int, None] = 100) -> List[int]:
        """Encode ``text`` as ``[<SOS>, *char_ids, <EOS>]``.

        The text is stripped of surrounding whitespace and NFC-normalized, so
        precomposed and decomposed spellings of the same character look up the
        same vocab entry. Characters missing from the vocab map to ``<UNK>``.

        Args:
            text: Input string.
            max_length: Maximum number of character tokens to keep. Because of
                the SOS/EOS markers, the result holds up to ``max_length + 2``
                IDs. ``None`` disables truncation.

        Returns:
            The list of token IDs.

        Raises:
            ValueError: If ``text`` is not a string or ``max_length`` is
                negative.
        """
        if not isinstance(text, str):
            raise ValueError("Input must be a string")
        # A negative slice bound would silently drop characters from the end.
        if max_length is not None and max_length < 0:
            raise ValueError("max_length must be non-negative or None")

        text = unicodedata.normalize('NFC', text.strip())
        # Each character yields exactly one ID, so truncating the text first is
        # equivalent to truncating the IDs and skips needless lookups.
        tokens = [self.vocab.get(char, self.unk_token_id) for char in text[:max_length]]
        return [self.special_tokens['sos'], *tokens, self.special_tokens['eos']]

    def batch_encode(self, texts: List[str], max_length: Union[int, None] = 100) -> List[List[int]]:
        """Encode each string in ``texts`` with :meth:`encode` (no padding)."""
        return [self.encode(text, max_length) for text in texts]

    def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
        """Convert token IDs back to text.

        Args:
            token_ids: Sequence of token IDs.
            skip_special_tokens: If true, drop the PAD/SOS/EOS IDs. Otherwise
                they are emitted as their vocab strings (e.g. ``"<EOS>"``).

        Returns:
            The decoded string. ``<UNK>`` IDs and IDs absent from the vocab
            are rendered as ``'?'``.
        """
        # Build the skip set once per call instead of scanning dict values
        # for every token.
        skip_ids = set(self.special_tokens.values()) if skip_special_tokens else set()
        decoded = []
        for token_id in token_ids:
            if token_id in skip_ids:
                continue
            # Both <UNK> and unknown IDs become '?'.
            if token_id == self.unk_token_id:
                decoded.append('?')
            else:
                decoded.append(self.idx2char.get(token_id, '?'))
        return ''.join(decoded)

    def batch_decode(self, batch_token_ids: List[List[int]], skip_special_tokens: bool = True) -> List[str]:
        """Decode each ID sequence in ``batch_token_ids`` with :meth:`decode`."""
        return [self.decode(token_ids, skip_special_tokens) for token_ids in batch_token_ids]

    def __len__(self) -> int:
        """Return the number of entries in the vocabulary."""
        return len(self.vocab)

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary (same as ``len(self)``)."""
        return len(self)