Receipt_OCR / tokenizer.py
RickyGM15's picture
Upload folder using huggingface_hub
e141a7d verified
import operator
import unicodedata
from typing import List, Union
class OCRTokenizer:
    """Character-level tokenizer that maps text to integer ids and back.

    Each character of the (stripped, NFC-normalized) input is looked up
    individually in ``vocab``. Characters missing from the vocabulary map to
    the ``<UNK>`` id. Encoded sequences are wrapped in ``<SOS>`` ... ``<EOS>``.
    """

    def __init__(self, vocab: dict):
        """Build the tokenizer from a ``{token: id}`` mapping.

        Args:
            vocab: Mapping from token string to integer id. Must contain the
                special tokens ``<PAD>``, ``<SOS>``, ``<EOS>`` and ``<UNK>``.

        Raises:
            ValueError: If any required special token is missing.
        """
        required_tokens = ['<PAD>', '<SOS>', '<EOS>', '<UNK>']
        for token in required_tokens:
            if token not in vocab:
                raise ValueError(f"Vocab must contain {token}")
        self.vocab = vocab
        # Reverse lookup used by decode(). If two tokens share an id, the one
        # iterated last wins.
        self.idx2char = {idx: char for char, idx in vocab.items()}
        self.unk_token_id = vocab['<UNK>']
        # <UNK> is not included here, so decode() never skips it; it is
        # rendered as '?' instead.
        self.special_tokens = {
            'pad': vocab['<PAD>'],
            'sos': vocab['<SOS>'],
            'eos': vocab['<EOS>'],
        }

    def encode(self, text: str, max_length: Union[int, None] = 100) -> List[int]:
        """Encode ``text`` as ``[<SOS>, c_1, ..., c_k, <EOS>]`` ids.

        Args:
            text: Input string. Leading/trailing whitespace is stripped and the
                result is NFC-normalized, so composed and decomposed forms of
                the same character (e.g. U+00E9 vs 'e' + U+0301) map to the
                same id.
            max_length: Maximum number of character ids kept *before*
                ``<SOS>``/``<EOS>`` are added, so the result holds at most
                ``max_length + 2`` ids. ``None`` disables truncation.

        Returns:
            List of token ids. Unknown characters map to the ``<UNK>`` id.
            No padding is applied.

        Raises:
            ValueError: If ``text`` is not a string or ``max_length`` is
                negative.
        """
        if not isinstance(text, str):
            raise ValueError("Input must be a string")
        # A negative limit would silently drop characters from the *end* of
        # the text through negative slicing; reject it explicitly.
        if max_length is not None and max_length < 0:
            raise ValueError("max_length must be non-negative")
        text = unicodedata.normalize('NFC', text.strip())
        # One id per character, so truncating the text first is equivalent to
        # truncating the id list and avoids looking up discarded characters.
        tokens = [self.vocab.get(char, self.unk_token_id) for char in text[:max_length]]
        return [self.special_tokens['sos']] + tokens + [self.special_tokens['eos']]

    def batch_encode(self, texts: List[str], max_length: Union[int, None] = 100) -> List[List[int]]:
        """Encode each string in ``texts`` with :meth:`encode` (no padding)."""
        return [self.encode(text, max_length) for text in texts]

    @staticmethod
    def _to_index(token_id):
        """Return ``token_id`` as a plain ``int`` if it is integer-like, else unchanged."""
        try:
            return operator.index(token_id)
        except TypeError:
            # Not integer-like (e.g. a float): keep the raw value so the
            # lookups below behave exactly as they would without conversion.
            return token_id

    def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
        """Convert a sequence of ids back to text.

        Args:
            token_ids: Iterable of ids. Integer-like scalars implementing
                ``__index__`` (such as NumPy integer scalars) are converted to
                ``int`` before lookup.
            skip_special_tokens: If true, ``<PAD>``, ``<SOS>`` and ``<EOS>`` ids
                are dropped. Otherwise they are rendered as their token strings.

        Returns:
            The decoded string. ``<UNK>`` ids and ids absent from the
            vocabulary are rendered as ``'?'``.
        """
        special_ids = set(self.special_tokens.values())
        pieces = []
        for raw_id in token_ids:
            # Normalizing to int matters for scalars that hash by identity
            # (e.g. elements obtained by iterating a tensor); without it the
            # set/dict lookups below would always miss.
            token_id = self._to_index(raw_id)
            if skip_special_tokens and token_id in special_ids:
                continue
            if token_id == self.unk_token_id:
                # <UNK> is in idx2char as '<UNK>', but is shown as '?'.
                pieces.append('?')
            else:
                # Ids not present in the vocabulary are also shown as '?'.
                pieces.append(self.idx2char.get(token_id, '?'))
        return ''.join(pieces)

    def batch_decode(self, batch_token_ids: List[List[int]], skip_special_tokens: bool = True) -> List[str]:
        """Decode each id sequence in ``batch_token_ids`` with :meth:`decode`."""
        return [self.decode(token_ids, skip_special_tokens) for token_ids in batch_token_ids]

    def __len__(self) -> int:
        """Return the number of entries in the vocabulary."""
        return len(self.vocab)

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary (same as ``len(self)``)."""
        return len(self)