import torch

from collections import deque

from jamotools import split_syllables, join_jamos
from transformers import PretrainedConfig, PreTrainedModel, AutoTokenizer
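# Wraps a base Hugging Face tokenizer so that precomposed Hangul syllables can be encoded
# either as fixed-width triples of base-tokenizer ids ("char" encoding) or as
# (choseong, jungseong, jongseong) jamo ids ("jamo" encoding); all other text is passed
# through the base tokenizer unchanged.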
class HangulTokenizerConfig(PretrainedConfig):
    model_type = "hangul_tokenizer"

    def __init__(
        self,
        base_tokenizer_name='unsloth/gemma-2-2b',
        **kwargs
    ):
        super().__init__(**kwargs)
        self.base_tokenizer_name = base_tokenizer_name


class HangulTokenizer(PreTrainedModel):
    config_class = HangulTokenizerConfig

    def __init__(self, config):
        super().__init__(config)
        # Dummy scalar parameter (the wrapper itself learns nothing).
        self.temp_module = torch.nn.Parameter(torch.ones(1))
        self.base_tokenizer = AutoTokenizer.from_pretrained(config.base_tokenizer_name)
        # Repurpose token id 128 of the base vocabulary as the padding token.
        self.base_tokenizer.pad_token_id = 128
        self.base_tokenizer.pad_token = self.base_tokenizer.decode([self.base_tokenizer.pad_token_id])
        self.space_token_id = self.base_tokenizer.encode(' ', add_special_tokens=False)[-1]
        # All precomposed Hangul syllables (U+AC00..U+D7A3).
        char_start, char_end = 0xAC00, 0xD7A3
        self.kor_chars = list(set([chr(code) for code in range(char_start, char_end + 1)]))
        self.char_3ids = []
        self.char_1ids = []
        # Split syllables into those the base tokenizer needs 3 ids for and those it
        # covers with fewer; pad the latter to a fixed width of 3 with the pad token id.
        for kor_char in self.kor_chars:
            ids = self.base_tokenizer.encode(kor_char, add_special_tokens=False)
            if len(ids) == 3:
                self.char_3ids.append(ids)
            else:
                ids = ids + (3 - len(ids)) * [self.base_tokenizer.pad_token_id]
                self.char_1ids.append(ids)
        self.chos = ['ㄱ', 'ㄲ', 'ㄴ', 'ㄷ', 'ㄸ', 'ㄹ', 'ㅁ', 'ㅂ', 'ㅃ', 'ㅅ', 'ㅆ', 'ㅇ', 'ㅈ', 'ㅉ', 'ㅊ', 'ㅋ', 'ㅌ', 'ㅍ', 'ㅎ']
        self.joongs = ['ㅏ', 'ㅐ', 'ㅑ', 'ㅒ', 'ㅓ', 'ㅔ', 'ㅕ', 'ㅖ', 'ㅗ', 'ㅘ', 'ㅙ', 'ㅚ', 'ㅛ', 'ㅜ', 'ㅝ', 'ㅞ', 'ㅟ', 'ㅠ', 'ㅡ', 'ㅢ', 'ㅣ']
        self.jongs = [self.base_tokenizer.pad_token, 'ㄱ', 'ㄲ', 'ㄳ', 'ㄴ', 'ㄵ', 'ㄶ', 'ㄷ', 'ㄹ', 'ㄺ', 'ㄻ', 'ㄼ', 'ㄽ', 'ㄾ', 'ㄿ', 'ㅀ', 'ㅁ', 'ㅂ', 'ㅄ', 'ㅅ', 'ㅆ', 'ㅇ', 'ㅈ', 'ㅊ', 'ㅋ', 'ㅌ', 'ㅍ', 'ㅎ']
        jamos = list(set(self.chos) | set(self.joongs) | set(self.jongs))
        jamo_ids = self.base_tokenizer(jamos, add_special_tokens=False)['input_ids']
        # Map each jamo to the last id of its encoding (and back), so every jamo is one id.
        self.jamo_to_id = {jamo: jamo_id[-1] for jamo, jamo_id in zip(jamos, jamo_ids)}
        self.cho_ids = [self.jamo_to_id[cho] for cho in self.chos]
        self.joong_ids = [self.jamo_to_id[joong] for joong in self.joongs]
        self.jong_ids = [self.jamo_to_id[jong] for jong in self.jongs]
        self.id_to_jamo = {jamo_id: jamo for jamo, jamo_id in self.jamo_to_id.items()}
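
    # Token-type conventions used below (as read off the encode/decode methods):
    #   0       -> ids produced directly by the base tokenizer (non-Hangul spans and padding)
    #   1, 2, 3 -> cho / joong / jong positions of a jamo-encoded syllable
    #   4       -> positions of a char-encoded syllable (always a triple of 3 ids)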

    def encode_jamo(self, sentence):
        # Encode a sentence, replacing each Hangul syllable with its (cho, joong, jong)
        # jamo ids; everything else is buffered and encoded with the base tokenizer.
        encoded_ids = []
        token_type_ids = []
        past_chars = ''
        for char in sentence:
            if char in self.kor_chars:
                # Flush any pending non-Hangul text first.
                if past_chars:
                    past_chars_encoded = self.base_tokenizer.encode(past_chars, add_special_tokens=False)
                    encoded_ids.extend(past_chars_encoded)
                    token_type_ids.extend([0] * len(past_chars_encoded))
                    past_chars = ''
                # Syllables without a final consonant are padded with the pad token.
                char_splitted = list(split_syllables(char))[:3]
                char_splitted = char_splitted + (3 - len(char_splitted)) * [self.base_tokenizer.pad_token]
                cho, joong, jong = char_splitted
                encoded_ids.extend([self.jamo_to_id[cho], self.jamo_to_id[joong], self.jamo_to_id[jong]])
                token_type_ids.extend([1, 2, 3])
            else:
                past_chars = past_chars + char
        if past_chars:
            past_chars_encoded = self.base_tokenizer.encode(past_chars, add_special_tokens=False)
            encoded_ids.extend(past_chars_encoded)
            token_type_ids.extend([0] * len(past_chars_encoded))
        return encoded_ids, token_type_ids

    def decode_jamo(self, encoded_ids, token_type_ids):
        # Inverse of encode_jamo: jamo triples are re-joined into syllables,
        # runs of type-0 ids are decoded with the base tokenizer.
        encoded_ids = deque(encoded_ids)
        token_type_ids = deque(token_type_ids)
        decoded = []
        past_ids = []
        while len(encoded_ids):
            encoded_id = encoded_ids.popleft()
            token_type_id = token_type_ids.popleft()
            if token_type_id == 0:
                past_ids.append(encoded_id)
            else:
                # Flush the buffered base-tokenizer ids, then consume the full triple.
                decoded.append(self.base_tokenizer.decode(past_ids))
                past_ids = []
                cho_id = encoded_id
                joong_id = encoded_ids.popleft()
                jong_id = encoded_ids.popleft()
                token_type_ids.popleft()
                token_type_ids.popleft()
                # join_jamos yields the syllable (plus pad-token text when there is no
                # final consonant), so keep only the first character.
                char = join_jamos([self.id_to_jamo[cho_id], self.id_to_jamo[joong_id], self.id_to_jamo[jong_id]])[:1]
                decoded.append(char)
        decoded.append(self.base_tokenizer.decode(past_ids))
        return ''.join(decoded)

    def encode_char(self, sentence):
        # Like encode_jamo, but each Hangul syllable becomes a fixed-width triple of
        # base-tokenizer ids, padded with the pad token id when fewer ids are needed.
        encoded_ids = []
        token_type_ids = []
        past_chars = ''
        for char in sentence:
            if char in self.kor_chars:
                if past_chars:
                    past_chars_encoded = self.base_tokenizer.encode(past_chars, add_special_tokens=False)
                    encoded_ids.extend(past_chars_encoded)
                    token_type_ids.extend([0] * len(past_chars_encoded))
                    past_chars = ''
                encoded_id = self.base_tokenizer.encode(char, add_special_tokens=False)
                encoded_id = encoded_id + (3 - len(encoded_id)) * [self.base_tokenizer.pad_token_id]
                encoded_ids.extend(encoded_id)
                token_type_ids.extend([4, 4, 4])
            else:
                past_chars = past_chars + char
        if past_chars:
            past_chars_encoded = self.base_tokenizer.encode(past_chars, add_special_tokens=False)
            encoded_ids.extend(past_chars_encoded)
            token_type_ids.extend([0] * len(past_chars_encoded))
        return encoded_ids, token_type_ids

    def decode_char(self, encoded_ids, token_type_ids):
        # Inverse of encode_char: each type-4 triple is decoded back to its syllable
        # (dropping trailing pad-token text), type-0 runs go through the base tokenizer.
        encoded_ids = deque(encoded_ids)
        token_type_ids = deque(token_type_ids)
        decoded = []
        past_ids = []
        while len(encoded_ids):
            encoded_id = encoded_ids.popleft()
            token_type_id = token_type_ids.popleft()
            if token_type_id == 0:
                past_ids.append(encoded_id)
            else:
                decoded.append(self.base_tokenizer.decode(past_ids))
                past_ids = []
                id1 = encoded_id
                id2 = encoded_ids.popleft()
                id3 = encoded_ids.popleft()
                token_type_ids.popleft()
                token_type_ids.popleft()
                char = self.base_tokenizer.decode([id1, id2, id3])[:1]
                decoded.append(char)
        decoded.append(self.base_tokenizer.decode(past_ids))
        return ''.join(decoded)

    def encode_jamo_from_char_encoded(self, encoded_ids, token_type_ids):
        # Convert a char-level encoding (type-4 triples) into the jamo-level encoding
        # (types 1/2/3) without re-reading the original text. Type-0 ids pass through,
        # so the sequence length is unchanged.
        encoded_ids = deque(encoded_ids)
        token_type_ids = deque(token_type_ids)
        encoded_ids_new = []
        token_type_ids_new = []
        while len(encoded_ids):
            encoded_id = encoded_ids.popleft()
            token_type_id = token_type_ids.popleft()
            if token_type_id == 0:
                encoded_ids_new.append(encoded_id)
                token_type_ids_new.append(token_type_id)
            else:
                encoded_id2 = encoded_ids.popleft()
                encoded_id3 = encoded_ids.popleft()
                token_type_ids.popleft()
                token_type_ids.popleft()
                # Recover the syllable from its char triple, then re-encode it as jamos.
                char = self.base_tokenizer.decode([encoded_id, encoded_id2, encoded_id3])[0]
                char_splitted = list(split_syllables(char))[:3]
                char_splitted = char_splitted + (3 - len(char_splitted)) * [self.base_tokenizer.pad_token]
                cho, joong, jong = char_splitted
                encoded_ids_new.extend([self.jamo_to_id[cho], self.jamo_to_id[joong], self.jamo_to_id[jong]])
                token_type_ids_new.extend([1, 2, 3])
        return encoded_ids_new, token_type_ids_new

    def batch_encode_char(self, sentences):
        # Char-level encode a batch and right-pad every row to the longest length with
        # the EOS token id; padded positions get attention 0 and token type 0.
        input_ids = []
        attention_mask = []
        token_type_ids = []
        for sentence in sentences:
            input_ids_row, token_type_id = self.encode_char(sentence)
            input_ids.append(input_ids_row)
            token_type_ids.append(token_type_id)
        max_length = max(list(map(len, input_ids)))
        for i in range(len(sentences)):
            input_ids[i] = input_ids[i] + (max_length - len(input_ids[i])) * [self.base_tokenizer.eos_token_id]
            attention_mask.append([1 if input_id != self.base_tokenizer.eos_token_id else 0 for input_id in input_ids[i]])
            token_type_ids[i] = token_type_ids[i] + (max_length - len(token_type_ids[i])) * [0]
        return (
            torch.LongTensor(input_ids),
            torch.LongTensor(attention_mask),
            torch.LongTensor(token_type_ids)
        )

    def batch_encode_jamo_from_char_encoded(self, batch_encoded_ids, batch_token_type_ids):
        # Convert a batch of char-level encodings into jamo-level encodings
        # (e.g. rows of batch_encode_char output converted with .tolist()).
        input_ids = []
        attention_mask = []
        token_type_ids_new = []
        for encoded_ids, token_type_ids in zip(batch_encoded_ids, batch_token_type_ids):
            encoded_ids_row, token_type_ids_row = self.encode_jamo_from_char_encoded(encoded_ids, token_type_ids)
            attention_mask.append([1 if encoded_id != self.base_tokenizer.eos_token_id else 0 for encoded_id in encoded_ids_row])
            input_ids.append(encoded_ids_row)
            token_type_ids_new.append(token_type_ids_row)

        return (
            torch.LongTensor(input_ids),
            torch.LongTensor(attention_mask),
            torch.LongTensor(token_type_ids_new),
        )

    def batch_encode_jamo(self, sentences):
        # Jamo-level counterpart of batch_encode_char: encode, then right-pad with EOS.
        input_ids = []
        attention_mask = []
        token_type_ids = []
        for sentence in sentences:
            input_ids_row, token_type_id = self.encode_jamo(sentence)
            input_ids.append(input_ids_row)
            token_type_ids.append(token_type_id)
        max_length = max(list(map(len, input_ids)))
        for i in range(len(sentences)):
            input_ids[i] = input_ids[i] + (max_length - len(input_ids[i])) * [self.base_tokenizer.eos_token_id]
            attention_mask.append([1 if input_id != self.base_tokenizer.eos_token_id else 0 for input_id in input_ids[i]])
            token_type_ids[i] = token_type_ids[i] + (max_length - len(token_type_ids[i])) * [0]
        return (
            torch.LongTensor(input_ids),
            torch.LongTensor(attention_mask),
            torch.LongTensor(token_type_ids),
        )
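
# Minimal usage sketch (illustrative only, not part of the module). It assumes the
# 'unsloth/gemma-2-2b' tokenizer is available locally or via the Hugging Face hub,
# and the sample sentence is an arbitrary example.
if __name__ == "__main__":
    tokenizer = HangulTokenizer(HangulTokenizerConfig())

    sentences = ["한국어 토큰 분리 테스트입니다."]

    # Char-level encoding: each syllable is a fixed triple of base-tokenizer ids.
    input_ids, attention_mask, token_type_ids = tokenizer.batch_encode_char(sentences)
    print(tokenizer.decode_char(input_ids[0].tolist(), token_type_ids[0].tolist()))

    # Jamo-level encoding: each syllable becomes (cho, joong, jong) ids.
    input_ids, attention_mask, token_type_ids = tokenizer.batch_encode_jamo(sentences)
    print(tokenizer.decode_jamo(input_ids[0].tolist(), token_type_ids[0].tolist()))

    # The char-level batch can also be converted to jamo-level without the raw text.
    char_ids, _, char_types = tokenizer.batch_encode_char(sentences)
    jamo_ids, _, jamo_types = tokenizer.batch_encode_jamo_from_char_encoded(
        char_ids.tolist(), char_types.tolist()
    )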