| import torch |
| import re |
| import unicodedata |
| import py_vncorenlp |
| from transformers import AutoTokenizer |
|
|
class MorphemeAwareTokenizer(AutoTokenizer):
    """Morpheme-aware wrapper around a Hugging Face tokenizer (default PhoBERT).

    In addition to the normal encoding, ``__call__`` attaches ``bmes_tags``:
    one B/M/E/S word-boundary tag per subword, derived from VnCoreNLP
    ("wseg") word segmentation. B = begin of a multi-syllable word,
    M = middle, E = end, S = single-syllable word (also used for special
    tokens, padding, and unalignable subwords).

    NOTE(review): ``AutoTokenizer`` is normally not subclassed (its own
    ``__init__`` refuses direct instantiation). This class works because it
    overrides ``__init__`` without calling ``super().__init__`` and only
    delegates to the wrapped ``self.tokenizer`` instance — confirm no caller
    relies on real ``PreTrainedTokenizer`` internals.
    """

    # Single source of truth for the BMES label -> id mapping used in __call__
    # (previously duplicated as two local dicts).
    BMES_MAP = {"B": 0, "M": 1, "E": 2, "S": 3}

    def __init__(self, pretrained_model_name="vinai/phobert-base", vncorenlp_dir='/content/vncorenlp', **kwargs):
        # Underlying subword tokenizer; every HF-facing method delegates here.
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name, **kwargs)

        # VnCoreNLP word segmenter (word-segmentation annotator only).
        # Assumes the VnCoreNLP model files already exist under `vncorenlp_dir`.
        self.rdrsegmenter = py_vncorenlp.VnCoreNLP(
            annotators=["wseg"],
            save_dir=vncorenlp_dir
        )

    def __len__(self):
        # FIX: follow the HF convention — len(tokenizer) counts the full
        # vocabulary *including* added tokens (what resize_token_embeddings
        # expects), whereas `vocab_size` excludes them. Identical to the
        # previous value when no tokens have been added.
        return len(self.tokenizer)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, vncorenlp_dir='/content/vncorenlp', **kwargs):
        """Load the tokenizer from Hugging Face while keeping the custom
        morpheme-aware logic of this wrapper."""
        return cls(pretrained_model_name_or_path, vncorenlp_dir=vncorenlp_dir, **kwargs)

    # ---- special-token pass-throughs (required by HF data collators) ----

    @property
    def mask_token(self):
        return self.tokenizer.mask_token

    @property
    def pad_token(self):
        return self.tokenizer.pad_token

    @property
    def cls_token(self):
        return self.tokenizer.cls_token

    @property
    def sep_token(self):
        return self.tokenizer.sep_token

    @property
    def unk_token(self):
        return self.tokenizer.unk_token

    @property
    def mask_token_id(self):
        return self.tokenizer.mask_token_id

    @property
    def pad_token_id(self):
        return self.tokenizer.pad_token_id

    @property
    def cls_token_id(self):
        return self.tokenizer.cls_token_id

    @property
    def sep_token_id(self):
        return self.tokenizer.sep_token_id

    @property
    def vocab_size(self):
        return self.tokenizer.vocab_size

    def pad(self, encoded_inputs, padding=True, max_length=None, return_tensors=None, **kwargs):
        """Delegate padding so DataCollatorForLanguageModeling can use
        ``pad()`` exactly like a regular Hugging Face tokenizer."""
        return self.tokenizer.pad(
            encoded_inputs,
            padding=padding,
            max_length=max_length,
            return_tensors=return_tensors,
            **kwargs
        )

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        """Return the special-tokens mask (1 = special token, 0 = normal token).

        Required by DataCollatorForLanguageModeling.
        """
        return self.tokenizer.get_special_tokens_mask(
            token_ids_0=token_ids_0,
            token_ids_1=token_ids_1,
            already_has_special_tokens=already_has_special_tokens
        )

    def convert_tokens_to_ids(self, tokens):
        """Convert tokens to ids (needed by some collators)."""
        return self.tokenizer.convert_tokens_to_ids(tokens)

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        """Convert ids back to tokens."""
        return self.tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)

    def to_bmes(self, text):
        """Build a list of (syllable, BMES-tag) pairs from ``text``.

        If ``text`` is a list -> returns list[list[(syllable, tag)]].
        If ``text`` is a str  -> returns list[(syllable, tag)].

        VnCoreNLP joins the syllables of a multi-syllable word with "_";
        each syllable is tagged by its position within that word.
        """
        if isinstance(text, list):
            return [self.to_bmes(t) for t in text]

        # Coerce non-string input (e.g. numbers from a dataset column).
        if not isinstance(text, str):
            text = str(text)

        segmented = self.rdrsegmenter.word_segment(text)

        # word_segment may return a single string or a list of sentence strings.
        if isinstance(segmented, list):
            sentences = segmented
        else:
            sentences = [segmented]

        bmes_list = []

        for sent in sentences:
            words = sent.split()
            for word in words:
                sylls = word.split("_")
                n = len(sylls)
                if n == 1:
                    bmes_list.append((sylls[0], 'S'))
                else:
                    bmes_list.append((sylls[0], 'B'))
                    for mid in sylls[1:-1]:
                        bmes_list.append((mid, 'M'))
                    bmes_list.append((sylls[-1], 'E'))

        return bmes_list

    def normalize_text(self, text):
        """Lowercase, strip subword markers and diacritics, drop punctuation.

        Used to compare accumulated subword pieces against VnCoreNLP syllables.
        """
        text = text.replace("@@", "").replace("▁", "").strip()
        text = unicodedata.normalize('NFD', text)
        # Drop combining marks (Vietnamese diacritics) exposed by NFD.
        text = ''.join([c for c in text if not unicodedata.combining(c)])
        text = re.sub(r'[^\w\s]', '', text)
        return text.lower()

    def is_punctuation(self, text):
        """True when ``text`` contains no word characters at all."""
        normalized = re.sub(r'[^\w\s]', '', text).strip()
        return normalized == ""

    def align_bmes_to_subwords(self, bmes_list, subwords_list):
        """Align BMES tags with subword tokens, handling:
        - punctuation glued to a word (e.g. 'c.', '3.')
        - special characters and <unk> tokens
        - complex subword splitting

        <unk> tokens are tagged 'S', consume one syllable, and reset the
        accumulation buffer so alignment can resynchronise afterwards.

        Returns one tag (str) per entry of ``subwords_list``.
        """
        aligned_tags = []
        syll_idx = 0            # index into bmes_list
        buffer_raw = ""         # subword pieces accumulated for the current syllable
        subword_positions = []  # positions in aligned_tags belonging to the buffer

        i = 0
        while i < len(subwords_list):
            sub = subwords_list[i]

            # Structural special tokens always get 'S'.
            if sub in ["<s>", "</s>", "<pad>", "<mask>"]:
                aligned_tags.append("S")
                i += 1
                continue

            # <unk>: tag 'S', consume one syllable, and drop any partial buffer.
            if sub == "<unk>":
                aligned_tags.append("S")

                if syll_idx < len(bmes_list):
                    syll_idx += 1

                buffer_raw = ""
                subword_positions = []

                i += 1
                continue

            # Ran out of syllables: default the remaining subwords to 'S'.
            if syll_idx >= len(bmes_list):
                aligned_tags.append("S")
                i += 1
                continue

            syll, tag = bmes_list[syll_idx]
            clean_sub = sub.replace("▁", "").replace("@@", "")

            normalized_syll = self.normalize_text(syll)

            # Pure-punctuation syllable: try to match it against the subword
            # directly; on failure fall through to the word/punct split below.
            if self.is_punctuation(syll):
                if clean_sub == syll or syll in clean_sub:
                    aligned_tags.append("S")
                    syll_idx += 1
                    i += 1

                    buffer_raw = ""
                    subword_positions = []
                    continue

            # Split a subword like 'c.' or '3.' into word + trailing punctuation.
            word_part = ""
            punct_part = ""

            match = re.match(r'^([a-zA-ZÀ-ỹ0-9]+)([^\w]+)$', clean_sub, re.UNICODE)
            if match:
                word_part = match.group(1)
                punct_part = match.group(2)
            else:
                word_part = clean_sub
                punct_part = ""

            if word_part:
                # Provisionally tag with the current syllable's tag; the
                # positions are rewritten once the whole syllable is consumed.
                buffer_raw += word_part
                subword_positions.append(len(aligned_tags))
                aligned_tags.append(tag)

                normalized_buffer = self.normalize_text(buffer_raw)

                # Buffer now spells the syllable: finalise tags for its pieces.
                if normalized_buffer == normalized_syll:
                    n = len(subword_positions)
                    if n > 1:
                        # Spread the word-level tag across the subword pieces.
                        if tag == 'B':
                            aligned_tags[subword_positions[0]] = 'B'
                            for pos in subword_positions[1:]:
                                aligned_tags[pos] = 'M'
                        elif tag == 'E':
                            for pos in subword_positions[:-1]:
                                aligned_tags[pos] = 'M'
                            aligned_tags[subword_positions[-1]] = 'E'
                        elif tag == 'M':
                            for pos in subword_positions:
                                aligned_tags[pos] = 'M'
                        elif tag == 'S':
                            for pos in subword_positions:
                                aligned_tags[pos] = 'S'
                    else:
                        aligned_tags[subword_positions[0]] = tag

                    buffer_raw = ""
                    subword_positions = []
                    syll_idx += 1

            # Trailing punctuation inside this subword may also consume the
            # next (punctuation) syllable from bmes_list.
            if punct_part:
                if syll_idx < len(bmes_list):
                    next_syll, next_tag = bmes_list[syll_idx]
                    if self.is_punctuation(next_syll) or next_syll == punct_part:
                        syll_idx += 1

            i += 1

        return aligned_tags

    def __call__(self, text, **kwargs):
        """Encode ``text`` and attach per-subword ``bmes_tags``.

        With ``return_tensors="pt"`` the tags are returned as id tensors
        (see ``BMES_MAP``); otherwise as lists of tag strings.
        """
        if isinstance(text, list):
            # NOTE(review): batch mode hard-codes padding/truncation and only
            # honours `return_tensors` from kwargs; other kwargs (max_length,
            # etc.) are intentionally ignored here.
            encoded = self.tokenizer(
                text,
                add_special_tokens=True,
                padding=True,
                truncation=True,
                return_tensors=kwargs.get("return_tensors", None),
            )

            bmes_tags_list = []
            for i, t in enumerate(text):
                bmes_list = self.to_bmes(t)
                # FIX: rows are plain Python lists when return_tensors is None,
                # so calling .tolist() unconditionally raised AttributeError.
                row = encoded["input_ids"][i]
                if hasattr(row, "tolist"):
                    row = row.tolist()
                subwords = self.tokenizer.convert_ids_to_tokens(row)
                bmes_tags = self.align_bmes_to_subwords(bmes_list, subwords)

                if kwargs.get("return_tensors") == "pt":
                    bmes_tags = torch.tensor([self.BMES_MAP[tag] for tag in bmes_tags])
                bmes_tags_list.append(bmes_tags)

            if kwargs.get("return_tensors") == "pt":
                # Defensive right-padding with 'S' so the tags stack into one
                # (batch, max_len) tensor alongside input_ids.
                max_len = encoded["input_ids"].shape[1]
                padded_bmes = []
                for tags in bmes_tags_list:
                    pad_len = max_len - tags.shape[0]
                    if pad_len > 0:
                        tags = torch.cat([tags, torch.full((pad_len,), self.BMES_MAP["S"])])
                    padded_bmes.append(tags)
                encoded["bmes_tags"] = torch.stack(padded_bmes)
            else:
                encoded["bmes_tags"] = bmes_tags_list

            return encoded

        # ---- single-text path: forward kwargs to the tokenizer unchanged ----
        bmes_list = self.to_bmes(text)
        encoded = self.tokenizer(text, add_special_tokens=True, **kwargs)

        # Normalise input_ids to a flat Python list regardless of return type.
        input_ids = encoded["input_ids"]
        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.squeeze(0).tolist()
        elif isinstance(input_ids[0], list):
            input_ids = input_ids[0]

        subwords = self.tokenizer.convert_ids_to_tokens(input_ids)
        bmes_tags = self.align_bmes_to_subwords(bmes_list, subwords)

        if kwargs.get("return_tensors") == "pt":
            # Keep batch dim (1, seq_len) to match input_ids' shape.
            bmes_tags = torch.tensor([self.BMES_MAP[t] for t in bmes_tags]).unsqueeze(0)

        encoded['bmes_tags'] = bmes_tags
        return encoded