import unicodedata
from typing import List, Union


class OCRTokenizer:
    """Character-level tokenizer that maps text to vocab ids and back.

    Each character of the (NFC-normalized) input is looked up in ``vocab``;
    characters missing from the vocab are encoded as the unknown-token id.
    """

    def __init__(self, vocab: dict):
        """Store the vocab and derive the reverse mapping and special ids.

        Args:
            vocab: Mapping from token string to integer id. The mapping is
                stored by reference, not copied.

        Raises:
            ValueError: If any required special token is missing from
                ``vocab``.
        """
        # NOTE(review): all four required-token literals are the empty string
        # as written, so pad/sos/eos/unk all resolve to the same id. This
        # looks like angle-bracket token names were lost upstream -- confirm
        # the intended keys before relying on distinct special ids.
        required_tokens = ['', '', '', '']
        for token in required_tokens:
            if token not in vocab:
                raise ValueError(f"Vocab must contain {token}")

        self.vocab = vocab
        # Reverse lookup used by decode(); if two tokens share an id, the
        # one iterated last wins.
        self.idx2char = {idx: char for char, idx in vocab.items()}
        self.unk_token_id = vocab['']
        self.special_tokens = {
            'pad': vocab[''],
            'sos': vocab[''],
            'eos': vocab[''],
        }

    def encode(self, text: str, max_length: int = 100) -> List[int]:
        """Convert ``text`` into ``[sos] + char ids + [eos]``.

        The text is stripped of surrounding whitespace and NFC-normalized
        before lookup. At most ``max_length`` character ids are kept; the
        sos/eos ids are added on top of that limit, so the result holds up
        to ``max_length + 2`` ids.

        Args:
            text: String to encode.
            max_length: Maximum number of character ids to keep. ``None``
                disables truncation.

        Returns:
            The list of token ids, wrapped in the sos and eos ids.

        Raises:
            ValueError: If ``text`` is not a string or ``max_length`` is
                negative.
        """
        if not isinstance(text, str):
            raise ValueError("Input must be a string")
        # A negative bound would make the slice below drop characters from
        # the END of the text instead of limiting its length, so reject it.
        if max_length is not None and max_length < 0:
            raise ValueError("max_length must be non-negative")

        text = unicodedata.normalize('NFC', text.strip())
        tokens = [self.vocab.get(char, self.unk_token_id) for char in text]
        return (
            [self.special_tokens['sos']]
            + tokens[:max_length]
            + [self.special_tokens['eos']]
        )

    def batch_encode(self, texts: List[str], max_length: int = 100) -> List[List[int]]:
        """Encode every string in ``texts`` with :meth:`encode`."""
        return [self.encode(text, max_length) for text in texts]

    def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
        """Convert a sequence of token ids back into a string.

        Args:
            token_ids: Ids to decode.
            skip_special_tokens: When true, ids equal to the pad/sos/eos ids
                are omitted from the output.

        Returns:
            The decoded string. The unknown-token id and any id that is not
            in the vocab are rendered as ``'?'``.
        """
        # Materialize once instead of rebuilding the values view per token;
        # a tuple keeps the original equality-based membership test.
        special_ids = tuple(self.special_tokens.values())

        decoded = []
        for token_id in token_ids:
            if skip_special_tokens and token_id in special_ids:
                continue
            # The unknown token and ids absent from the vocab both fall back
            # to '?'.
            if token_id == self.unk_token_id or token_id not in self.idx2char:
                decoded.append('?')
            else:
                decoded.append(self.idx2char[token_id])
        return ''.join(decoded)

    def batch_decode(self, batch_token_ids: List[List[int]], skip_special_tokens: bool = True) -> List[str]:
        """Decode every id sequence in ``batch_token_ids`` with :meth:`decode`."""
        return [self.decode(token_ids, skip_special_tokens) for token_ids in batch_token_ids]

    def __len__(self) -> int:
        """Return the number of entries in the vocab."""
        return len(self.vocab)

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocab (same as ``len(self)``)."""
        return len(self)