import torch import re import unicodedata import py_vncorenlp from transformers import AutoTokenizer class MorphemeAwareTokenizer(AutoTokenizer): def __init__(self, pretrained_model_name="vinai/phobert-base", vncorenlp_dir='/content/vncorenlp', **kwargs): # Khởi tạo tokenizer HF gốc self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name, **kwargs) # Khởi tạo VnCoreNLP cho word segmentation self.rdrsegmenter = py_vncorenlp.VnCoreNLP( annotators=["wseg"], save_dir=vncorenlp_dir ) def __len__(self): # Trả về vocab size return self.vocab_size @classmethod def from_pretrained(cls, pretrained_model_name_or_path, vncorenlp_dir='/content/vncorenlp', **kwargs): """ Load tokenizer từ Hugging Face và giữ logic custom """ return cls(pretrained_model_name_or_path, vncorenlp_dir=vncorenlp_dir, **kwargs) @property def mask_token(self): return self.tokenizer.mask_token @property def pad_token(self): return self.tokenizer.pad_token @property def cls_token(self): return self.tokenizer.cls_token @property def sep_token(self): return self.tokenizer.sep_token @property def unk_token(self): return self.tokenizer.unk_token # ============================= # ✅ Bổ sung để tương thích với DataCollator # ============================= @property def mask_token_id(self): return self.tokenizer.mask_token_id @property def pad_token_id(self): return self.tokenizer.pad_token_id @property def cls_token_id(self): return self.tokenizer.cls_token_id @property def sep_token_id(self): return self.tokenizer.sep_token_id @property def vocab_size(self): return self.tokenizer.vocab_size def pad(self, encoded_inputs, padding=True, max_length=None, return_tensors=None, **kwargs): """ Cho phép DataCollatorForLanguageModeling sử dụng pad() như tokenizer Hugging Face. """ return self.tokenizer.pad( encoded_inputs, padding=padding, max_length=max_length, return_tensors=return_tensors, **kwargs ) def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False): """ Trả về mask cho special tokens (1 = special token, 0 = normal token). Cần thiết cho DataCollatorForLanguageModeling. """ return self.tokenizer.get_special_tokens_mask( token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=already_has_special_tokens ) def convert_tokens_to_ids(self, tokens): """ Chuyển tokens thành IDs. Cần cho một số collator. """ return self.tokenizer.convert_tokens_to_ids(tokens) def convert_ids_to_tokens(self, ids, skip_special_tokens=False): """ Chuyển IDs thành tokens. """ return self.tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens) def to_bmes(self, text): """ Tạo danh sách (syllable, BMES-tag) từ text hoặc list[text]. Nếu là list -> trả về list[list[(syllable, tag)]] Nếu là str -> trả về list[(syllable, tag)] """ if isinstance(text, list): return [self.to_bmes(t) for t in text] if not isinstance(text, str): text = str(text) segmented = self.rdrsegmenter.word_segment(text) # Trường hợp output là list nhiều câu → gộp lại theo từng câu riêng if isinstance(segmented, list): sentences = segmented else: sentences = [segmented] bmes_list = [] for sent in sentences: words = sent.split() for word in words: sylls = word.split("_") n = len(sylls) if n == 1: bmes_list.append((sylls[0], 'S')) else: bmes_list.append((sylls[0], 'B')) for mid in sylls[1:-1]: bmes_list.append((mid, 'M')) bmes_list.append((sylls[-1], 'E')) return bmes_list def normalize_text(self, text): text = text.replace("@@", "").replace("▁", "").strip() text = unicodedata.normalize('NFD', text) text = ''.join([c for c in text if not unicodedata.combining(c)]) text = re.sub(r'[^\w\s]', '', text) return text.lower() def is_punctuation(self, text): normalized = re.sub(r'[^\w\s]', '', text).strip() return normalized == "" def align_bmes_to_subwords(self, bmes_list, subwords_list): """ Align BMES tags với subwords, xử lý các trường hợp: - Dấu câu dính với chữ (vd: 'c.', '3.') - Ký tự đặc biệt, tokens - Subword splitting phức tạp 🔧 FIX: Xử lý token bằng cách skip nó và tiếp tục alignment """ aligned_tags = [] syll_idx = 0 buffer_raw = "" subword_positions = [] i = 0 while i < len(subwords_list): sub = subwords_list[i] # Special tokens - luôn tag là 'S' if sub in ["", "", "", ""]: aligned_tags.append("S") i += 1 continue # 🔧 XỬ LÝ TOKEN if sub == "": # token là biểu diễn của 1 ký tự không được vocab nhận diện # Gán tag 'S' cho nó và bỏ qua 1 syllable trong bmes_list nếu có aligned_tags.append("S") # Nếu còn syllable, skip nó vì đã được thay thế bằng if syll_idx < len(bmes_list): syll_idx += 1 # Reset buffer để tránh cascade errors buffer_raw = "" subword_positions = [] i += 1 continue # Hết syllables - tag còn lại là 'S' if syll_idx >= len(bmes_list): aligned_tags.append("S") i += 1 continue # Lấy syllable hiện tại syll, tag = bmes_list[syll_idx] clean_sub = sub.replace("▁", "").replace("@@", "") # Normalize để so sánh normalized_syll = self.normalize_text(syll) # Case 1: Syllable là dấu câu thuần túy if self.is_punctuation(syll): # Kiểm tra xem subword có chứa dấu câu này không if clean_sub == syll or syll in clean_sub: aligned_tags.append("S") syll_idx += 1 i += 1 # Reset buffer nếu đang xử lý buffer_raw = "" subword_positions = [] continue # Case 2: Subword có dấu câu dính (vd: 'c.', 'i.') # Tách phần chữ và dấu câu word_part = "" punct_part = "" # Pattern để tách: chữ cái/số ở đầu, dấu câu ở cuối match = re.match(r'^([a-zA-ZÀ-ỹ0-9]+)([^\w]+)$', clean_sub, re.UNICODE) if match: word_part = match.group(1) punct_part = match.group(2) else: word_part = clean_sub punct_part = "" # Xử lý phần chữ if word_part: buffer_raw += word_part subword_positions.append(len(aligned_tags)) aligned_tags.append(tag) # Tag tạm thời normalized_buffer = self.normalize_text(buffer_raw) # Kiểm tra buffer có khớp với syllable chưa if normalized_buffer == normalized_syll: # Gán lại tags đúng cho tất cả subwords trong buffer n = len(subword_positions) if n > 1: if tag == 'B': aligned_tags[subword_positions[0]] = 'B' for pos in subword_positions[1:]: aligned_tags[pos] = 'M' elif tag == 'E': for pos in subword_positions[:-1]: aligned_tags[pos] = 'M' aligned_tags[subword_positions[-1]] = 'E' elif tag == 'M': for pos in subword_positions: aligned_tags[pos] = 'M' elif tag == 'S': for pos in subword_positions: aligned_tags[pos] = 'S' else: aligned_tags[subword_positions[0]] = tag # Reset buffer và tăng syllable index buffer_raw = "" subword_positions = [] syll_idx += 1 # Xử lý phần dấu câu nếu có if punct_part: # Kiểm tra syllable tiếp theo có phải dấu câu không if syll_idx < len(bmes_list): next_syll, next_tag = bmes_list[syll_idx] if self.is_punctuation(next_syll) or next_syll == punct_part: # Dấu câu này thuộc syllable tiếp theo, không thêm tag syll_idx += 1 i += 1 return aligned_tags def __call__(self, text, **kwargs): # Nếu là list → xử lý batch if isinstance(text, list): # 1. Tokenize cả batch bằng tokenizer gốc encoded = self.tokenizer( text, add_special_tokens=True, padding=True, truncation=True, return_tensors=kwargs.get("return_tensors", None), ) # 2. Tạo BMES tags cho từng câu BMES_MAP = {"B": 0, "M": 1, "E": 2, "S": 3} bmes_tags_list = [] for i, t in enumerate(text): bmes_list = self.to_bmes(t) subwords = self.tokenizer.convert_ids_to_tokens(encoded["input_ids"][i].tolist()) bmes_tags = self.align_bmes_to_subwords(bmes_list, subwords) # Chuyển sang tensor nếu cần if kwargs.get("return_tensors") == "pt": bmes_tags = torch.tensor([BMES_MAP[tag] for tag in bmes_tags]) bmes_tags_list.append(bmes_tags) # Padding BMES tags giống input_ids if kwargs.get("return_tensors") == "pt": max_len = encoded["input_ids"].shape[1] padded_bmes = [] for tags in bmes_tags_list: pad_len = max_len - tags.shape[0] if pad_len > 0: tags = torch.cat([tags, torch.full((pad_len,), BMES_MAP["S"])]) padded_bmes.append(tags) encoded["bmes_tags"] = torch.stack(padded_bmes) else: encoded["bmes_tags"] = bmes_tags_list return encoded # Nếu là string đơn → xử lý như cũ bmes_list = self.to_bmes(text) encoded = self.tokenizer(text, add_special_tokens=True, **kwargs) input_ids = encoded["input_ids"] if isinstance(input_ids, torch.Tensor): input_ids = input_ids.squeeze(0).tolist() elif isinstance(input_ids[0], list): input_ids = input_ids[0] subwords = self.tokenizer.convert_ids_to_tokens(input_ids) bmes_tags = self.align_bmes_to_subwords(bmes_list, subwords) if kwargs.get("return_tensors") == "pt": BMES_MAP = {"B": 0, "M": 1, "E": 2, "S": 3} bmes_tags = torch.tensor([BMES_MAP[t] for t in bmes_tags]).unsqueeze(0) encoded['bmes_tags'] = bmes_tags return encoded def save_pretrained(self, save_directory, **kwargs): return self.tokenizer.save_pretrained(save_directory, **kwargs)