Spaces:
Build error
Build error
| import unicodedata | |
| from typing import List, Union | |
class OCRTokenizer:
    """Character-level tokenizer: one vocabulary id per character.

    Encoded sequences are wrapped as ``[<SOS>, *chars, <EOS>]``. Characters
    missing from the vocabulary map to the ``<UNK>`` id.
    """

    def __init__(self, vocab: dict):
        """Build the tokenizer from a ``{character: id}`` mapping.

        Args:
            vocab: Mapping from token string to integer id. It must contain
                ``<PAD>``, ``<SOS>``, ``<EOS>`` and ``<UNK>``. The mapping is
                stored by reference, not copied.

        Raises:
            ValueError: If a required special token is missing.
        """
        required_tokens = ['<PAD>', '<SOS>', '<EOS>', '<UNK>']
        for token in required_tokens:
            if token not in vocab:
                raise ValueError(f"Vocab must contain {token}")
        self.vocab = vocab
        # Reverse lookup for decoding. If two tokens share an id, the one
        # iterated last wins.
        self.idx2char = {idx: char for char, idx in vocab.items()}
        self.unk_token_id = vocab['<UNK>']
        self.special_tokens = {
            'pad': vocab['<PAD>'],
            'sos': vocab['<SOS>'],
            'eos': vocab['<EOS>'],
        }

    def encode(self, text: str, max_length: Union[int, None] = 100) -> List[int]:
        """Encode ``text`` into ``[<SOS>, *char_ids, <EOS>]``.

        Leading and trailing whitespace is stripped, and the text is
        NFC-normalized before lookup.

        Args:
            text: The string to encode.
            max_length: Maximum number of *character* tokens kept. SOS/EOS are
                added on top, so the result has at most ``max_length + 2``
                ids. ``None`` disables truncation.

        Raises:
            ValueError: If ``text`` is not a string or ``max_length`` is
                negative. A negative slice bound would silently drop
                characters from the end.
        """
        if not isinstance(text, str):
            raise ValueError("Input must be a string")
        if max_length is not None and max_length < 0:
            raise ValueError("max_length must be non-negative")
        text = unicodedata.normalize('NFC', text.strip())
        # One token per character, so truncating the text first is
        # equivalent to truncating the token list afterwards.
        tokens = [self.vocab.get(char, self.unk_token_id) for char in text[:max_length]]
        return [self.special_tokens['sos'], *tokens, self.special_tokens['eos']]

    def batch_encode(self, texts: List[str], max_length: Union[int, None] = 100) -> List[List[int]]:
        """Encode each string in ``texts`` with :meth:`encode`.

        Raises:
            ValueError: If ``texts`` is a single string. Iterating a string
                would otherwise encode each character as its own sequence.
        """
        if isinstance(texts, str):
            raise ValueError("batch_encode expects a list of strings, not a single string")
        return [self.encode(text, max_length) for text in texts]

    def decode(self, token_ids: List[int], skip_special_tokens: bool = True,
               *, stop_at_eos: bool = False) -> str:
        """Convert token ids back into a string.

        Args:
            token_ids: Sequence of integer ids.
            skip_special_tokens: Omit PAD/SOS/EOS ids from the output.
                ``<UNK>`` is not in this set; it always decodes to ``'?'``.
            stop_at_eos: If True, stop at the first EOS id. The EOS itself
                and everything after it are not emitted.

        Returns:
            The decoded string. ``<UNK>`` ids and ids absent from the
            vocabulary are rendered as ``'?'``.
        """
        # Hoisted out of the loop: constant for the duration of this call.
        special_ids = set(self.special_tokens.values())
        eos_id = self.special_tokens['eos']
        decoded = []
        for token_id in token_ids:
            if stop_at_eos and token_id == eos_id:
                break
            if skip_special_tokens and token_id in special_ids:
                continue
            # <UNK> and unknown ids both render as a visible placeholder.
            if token_id == self.unk_token_id:
                decoded.append('?')
            else:
                decoded.append(self.idx2char.get(token_id, '?'))
        return ''.join(decoded)

    def batch_decode(self, batch_token_ids: List[List[int]], skip_special_tokens: bool = True,
                     *, stop_at_eos: bool = False) -> List[str]:
        """Decode each id sequence in ``batch_token_ids`` with :meth:`decode`."""
        return [
            self.decode(token_ids, skip_special_tokens, stop_at_eos=stop_at_eos)
            for token_ids in batch_token_ids
        ]

    def __len__(self) -> int:
        """Return the number of entries in the vocabulary."""
        return len(self.vocab)

    def vocab_size(self) -> int:
        """Return the vocabulary size (same as ``len(self)``)."""
        return len(self)