File size: 12,910 Bytes
0d085ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
import torch
import re
import unicodedata
import py_vncorenlp
from transformers import AutoTokenizer

class MorphemeAwareTokenizer(AutoTokenizer):
    def __init__(self, pretrained_model_name="vinai/phobert-base", vncorenlp_dir='/content/vncorenlp', **kwargs):
        # Khởi tạo tokenizer HF gốc
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name, **kwargs)

        # Khởi tạo VnCoreNLP cho word segmentation
        self.rdrsegmenter = py_vncorenlp.VnCoreNLP(
            annotators=["wseg"],
            save_dir=vncorenlp_dir
        )

    def __len__(self):
        # Trả về vocab size
        return self.vocab_size

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, vncorenlp_dir='/content/vncorenlp', **kwargs):
        """
        Load tokenizer từ Hugging Face và giữ logic custom
        """
        return cls(pretrained_model_name_or_path, vncorenlp_dir=vncorenlp_dir, **kwargs)
    

    @property
    def mask_token(self):
        return self.tokenizer.mask_token

    @property
    def pad_token(self):
        return self.tokenizer.pad_token

    @property
    def cls_token(self):
        return self.tokenizer.cls_token

    @property
    def sep_token(self):
        return self.tokenizer.sep_token

    @property
    def unk_token(self):
        return self.tokenizer.unk_token

    # =============================
    # ✅ Bổ sung để tương thích với DataCollator
    # =============================

    @property
    def mask_token_id(self):
        return self.tokenizer.mask_token_id

    @property
    def pad_token_id(self):
        return self.tokenizer.pad_token_id

    @property
    def cls_token_id(self):
        return self.tokenizer.cls_token_id

    @property
    def sep_token_id(self):
        return self.tokenizer.sep_token_id

    @property
    def vocab_size(self):
        return self.tokenizer.vocab_size

    def pad(self, encoded_inputs, padding=True, max_length=None, return_tensors=None, **kwargs):
        """
        Cho phép DataCollatorForLanguageModeling sử dụng pad() như tokenizer Hugging Face.
        """
        return self.tokenizer.pad(
            encoded_inputs,
            padding=padding,
            max_length=max_length,
            return_tensors=return_tensors,
            **kwargs
        )

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        """
        Trả về mask cho special tokens (1 = special token, 0 = normal token).
        Cần thiết cho DataCollatorForLanguageModeling.
        """
        return self.tokenizer.get_special_tokens_mask(
            token_ids_0=token_ids_0,
            token_ids_1=token_ids_1,
            already_has_special_tokens=already_has_special_tokens
        )

    def convert_tokens_to_ids(self, tokens):
        """
        Chuyển tokens thành IDs. Cần cho một số collator.
        """
        return self.tokenizer.convert_tokens_to_ids(tokens)

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        """
        Chuyển IDs thành tokens.
        """
        return self.tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)

    def to_bmes(self, text):
        """
        Tạo danh sách (syllable, BMES-tag) từ text hoặc list[text].
        Nếu là list -> trả về list[list[(syllable, tag)]]
        Nếu là str  -> trả về list[(syllable, tag)]
        """
        if isinstance(text, list):
            return [self.to_bmes(t) for t in text]

        if not isinstance(text, str):
            text = str(text)

        segmented = self.rdrsegmenter.word_segment(text)
        
        # Trường hợp output là list nhiều câu → gộp lại theo từng câu riêng
        if isinstance(segmented, list):
            sentences = segmented
        else:
            sentences = [segmented]

        bmes_list = []

        for sent in sentences:
            words = sent.split()
            for word in words:
                sylls = word.split("_")
                n = len(sylls)
                if n == 1:
                    bmes_list.append((sylls[0], 'S'))
                else:
                    bmes_list.append((sylls[0], 'B'))
                    for mid in sylls[1:-1]:
                        bmes_list.append((mid, 'M'))
                    bmes_list.append((sylls[-1], 'E'))
        
        return bmes_list


    def normalize_text(self, text):
        text = text.replace("@@", "").replace("▁", "").strip()
        text = unicodedata.normalize('NFD', text)
        text = ''.join([c for c in text if not unicodedata.combining(c)])
        text = re.sub(r'[^\w\s]', '', text)
        return text.lower()

    def is_punctuation(self, text):
        normalized = re.sub(r'[^\w\s]', '', text).strip()
        return normalized == ""

    def align_bmes_to_subwords(self, bmes_list, subwords_list):
        """
        Align BMES tags với subwords, xử lý các trường hợp:
        - Dấu câu dính với chữ (vd: 'c.', '3.')
        - Ký tự đặc biệt, <unk> tokens
        - Subword splitting phức tạp
        
        🔧 FIX: Xử lý <unk> token bằng cách skip nó và tiếp tục alignment
        """
        aligned_tags = []
        syll_idx = 0
        buffer_raw = ""
        subword_positions = []
        
        i = 0
        while i < len(subwords_list):
            sub = subwords_list[i]
            
            # Special tokens - luôn tag là 'S'
            if sub in ["<s>", "</s>", "<pad>", "<mask>"]:
                aligned_tags.append("S")
                i += 1
                continue
            
            # 🔧 XỬ LÝ <unk> TOKEN
            if sub == "<unk>":
                # <unk> token là biểu diễn của 1 ký tự không được vocab nhận diện
                # Gán tag 'S' cho nó và bỏ qua 1 syllable trong bmes_list nếu có
                aligned_tags.append("S")
                
                # Nếu còn syllable, skip nó vì đã được thay thế bằng <unk>
                if syll_idx < len(bmes_list):
                    syll_idx += 1
                
                # Reset buffer để tránh cascade errors
                buffer_raw = ""
                subword_positions = []
                
                i += 1
                continue
            
            # Hết syllables - tag còn lại là 'S'
            if syll_idx >= len(bmes_list):
                aligned_tags.append("S")
                i += 1
                continue
            
            # Lấy syllable hiện tại
            syll, tag = bmes_list[syll_idx]
            clean_sub = sub.replace("▁", "").replace("@@", "")
            
            # Normalize để so sánh
            normalized_syll = self.normalize_text(syll)
            
            # Case 1: Syllable là dấu câu thuần túy
            if self.is_punctuation(syll):
                # Kiểm tra xem subword có chứa dấu câu này không
                if clean_sub == syll or syll in clean_sub:
                    aligned_tags.append("S")
                    syll_idx += 1
                    i += 1
                    # Reset buffer nếu đang xử lý
                    buffer_raw = ""
                    subword_positions = []
                    continue
            
            # Case 2: Subword có dấu câu dính (vd: 'c.', 'i.')
            # Tách phần chữ và dấu câu
            word_part = ""
            punct_part = ""
            
            # Pattern để tách: chữ cái/số ở đầu, dấu câu ở cuối
            match = re.match(r'^([a-zA-ZÀ-ỹ0-9]+)([^\w]+)$', clean_sub, re.UNICODE)
            if match:
                word_part = match.group(1)
                punct_part = match.group(2)
            else:
                word_part = clean_sub
                punct_part = ""
            
            # Xử lý phần chữ
            if word_part:
                buffer_raw += word_part
                subword_positions.append(len(aligned_tags))
                aligned_tags.append(tag)  # Tag tạm thời
                
                normalized_buffer = self.normalize_text(buffer_raw)
                
                # Kiểm tra buffer có khớp với syllable chưa
                if normalized_buffer == normalized_syll:
                    # Gán lại tags đúng cho tất cả subwords trong buffer
                    n = len(subword_positions)
                    if n > 1:
                        if tag == 'B':
                            aligned_tags[subword_positions[0]] = 'B'
                            for pos in subword_positions[1:]:
                                aligned_tags[pos] = 'M'
                        elif tag == 'E':
                            for pos in subword_positions[:-1]:
                                aligned_tags[pos] = 'M'
                            aligned_tags[subword_positions[-1]] = 'E'
                        elif tag == 'M':
                            for pos in subword_positions:
                                aligned_tags[pos] = 'M'
                        elif tag == 'S':
                            for pos in subword_positions:
                                aligned_tags[pos] = 'S'
                    else:
                        aligned_tags[subword_positions[0]] = tag
                    
                    # Reset buffer và tăng syllable index
                    buffer_raw = ""
                    subword_positions = []
                    syll_idx += 1
                    
                    # Xử lý phần dấu câu nếu có
                    if punct_part:
                        # Kiểm tra syllable tiếp theo có phải dấu câu không
                        if syll_idx < len(bmes_list):
                            next_syll, next_tag = bmes_list[syll_idx]
                            if self.is_punctuation(next_syll) or next_syll == punct_part:
                                # Dấu câu này thuộc syllable tiếp theo, không thêm tag
                                syll_idx += 1
            
            i += 1
        
        return aligned_tags

    def __call__(self, text, **kwargs):
        # Nếu là list → xử lý batch
        if isinstance(text, list):
            # 1. Tokenize cả batch bằng tokenizer gốc
            encoded = self.tokenizer(
                text,
                add_special_tokens=True,
                padding=True,
                truncation=True,
                return_tensors=kwargs.get("return_tensors", None),
            )

            # 2. Tạo BMES tags cho từng câu
            BMES_MAP = {"B": 0, "M": 1, "E": 2, "S": 3}
            bmes_tags_list = []
            for i, t in enumerate(text):
                bmes_list = self.to_bmes(t)
                subwords = self.tokenizer.convert_ids_to_tokens(encoded["input_ids"][i].tolist())
                bmes_tags = self.align_bmes_to_subwords(bmes_list, subwords)
                
                # Chuyển sang tensor nếu cần
                if kwargs.get("return_tensors") == "pt":
                    bmes_tags = torch.tensor([BMES_MAP[tag] for tag in bmes_tags])
                bmes_tags_list.append(bmes_tags)

            # Padding BMES tags giống input_ids
            if kwargs.get("return_tensors") == "pt":
                max_len = encoded["input_ids"].shape[1]
                padded_bmes = []
                for tags in bmes_tags_list:
                    pad_len = max_len - tags.shape[0]
                    if pad_len > 0:
                        tags = torch.cat([tags, torch.full((pad_len,), BMES_MAP["S"])])
                    padded_bmes.append(tags)
                encoded["bmes_tags"] = torch.stack(padded_bmes)
            else:
                encoded["bmes_tags"] = bmes_tags_list

            return encoded

        # Nếu là string đơn → xử lý như cũ
        bmes_list = self.to_bmes(text)
        encoded = self.tokenizer(text, add_special_tokens=True, **kwargs)

        input_ids = encoded["input_ids"]
        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.squeeze(0).tolist()
        elif isinstance(input_ids[0], list):
            input_ids = input_ids[0]

        subwords = self.tokenizer.convert_ids_to_tokens(input_ids)
        bmes_tags = self.align_bmes_to_subwords(bmes_list, subwords)

        if kwargs.get("return_tensors") == "pt":
            BMES_MAP = {"B": 0, "M": 1, "E": 2, "S": 3}
            bmes_tags = torch.tensor([BMES_MAP[t] for t in bmes_tags]).unsqueeze(0)

        encoded['bmes_tags'] = bmes_tags
        return encoded
    
    def save_pretrained(self, save_directory, **kwargs):
        return self.tokenizer.save_pretrained(save_directory, **kwargs)