import torch
import re

from collections import Counter
from transformers import PretrainedConfig, PreTrainedModel
|
class SentenceTokenizerConfig(PretrainedConfig):
    """Configuration for SentenceTokenizer: chunk length bounds and overlap settings."""

    model_type = "sentence_tokenizer"

    def __init__(
        self,
        min_length=32,
        max_length=64,
        n_overlap=3,
        roll=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.min_length = min_length
        self.max_length = max_length
        self.n_overlap = n_overlap
        self.roll = roll
|
|
class SentenceTokenizer(PreTrainedModel):
    """Rule-based text chunker with optional overlapping spans, wrapped as a PreTrainedModel."""

    config_class = SentenceTokenizerConfig

    def __init__(self, config):
        super().__init__(config)
        # Placeholder parameter so the wrapper carries at least one weight tensor.
        self.temp_module = torch.nn.Parameter(torch.ones(1))
        self.min_length = config.min_length
        self.max_length = config.max_length
        self.n_overlap = config.n_overlap
        self.roll = config.roll
|
    def split_text_into_sentences(self, text):
        # Split after any non-Hangul character followed by a space (e.g. ". ", "! ");
        # the capturing group keeps the delimiter in the split result.
        split_text = re.split(r'([^가-힣] )', text)
        # Glue each captured delimiter back onto the preceding segment.
        sentences = [split_text[i] + split_text[i + 1] for i in range(0, len(split_text) - 1, 2)]
        if len(split_text) % 2 != 0:
            sentences.append(split_text[-1])
        return sentences
|
|
    def merge_chunks(self, chunks):
        # Greedily concatenate consecutive chunks; once the buffer grows past
        # min_length, emit it as one merged chunk and start a new buffer.
        merged_chunks = []
        buffer = ""

        for chunk in chunks:
            buffer += chunk
            if len(buffer) > self.min_length:
                merged_chunks.append(buffer)
                buffer = ""

        # Whatever is left over becomes the final (possibly short) chunk.
        if buffer:
            merged_chunks.append(buffer)

        return merged_chunks
|
|
    def merge_chunks_reverse(self, chunks):
        # Same greedy merge as merge_chunks, but applied right-to-left: reverse
        # the chunk order and each chunk's characters, merge, then undo the
        # reversal so a short trailing leftover gets folded into the chunk
        # before it instead of being left on its own.
        chunks_reverse = [chunk[::-1] for chunk in chunks[::-1]]

        merged_chunks = []
        buffer = ""
        for chunk in chunks_reverse:
            buffer += chunk
            if len(buffer) > self.min_length:
                merged_chunks.append(buffer)
                buffer = ""
        if buffer:
            merged_chunks.append(buffer)

        # Restore the original chunk order and character order.
        return [chunk[::-1] for chunk in merged_chunks[::-1]]

    def split_text(self, text):
        # Break an over-long chunk on whitespace so each piece stays within
        # max_length (a single word longer than max_length is kept whole).
        words = self.split_space(text)

        splitted_chunks = []
        buffer = []
        for word in words:
            buffer.append(word)
            merged_text = ''.join(buffer)
            if len(merged_text) > self.max_length:
                # Emit what was accumulated so far and restart the buffer with
                # the word that pushed it over the limit.
                buffer.pop()
                if buffer:
                    splitted_chunks.append(''.join(buffer))
                buffer = [word]

        if buffer:
            splitted_chunks.append(''.join(buffer))

        return splitted_chunks
|
|
    def tokenize(self, text):
        # Split into sentences, break any sentence that is already too long on
        # whitespace, then merge forward and backward so short pieces get
        # absorbed into their neighbours.
        splitted_chunks = []
        sentences = self.split_text_into_sentences(text)
        for chunk in sentences:
            if len(chunk) >= self.max_length:
                splitted_chunks.extend(self.split_text(chunk))
            else:
                splitted_chunks.append(chunk)
        merged_chunks = self.merge_chunks(splitted_chunks)
        merged_chunks = self.merge_chunks_reverse(merged_chunks)
        return merged_chunks
|
|
    def split_space(self, text):
        # Split into words, keeping each word's trailing whitespace attached so
        # that joining the pieces reproduces the input exactly.
        split_text = re.split(r'(\s+)', text)
        return [s + sp for s, sp in zip(split_text[::2], split_text[1::2] + [''])]

    def overlap(self, chunks):
        # Turn consecutive chunks into overlapping spans. Each span is a
        # (start, end, text) tuple whose offsets index into the concatenation
        # of the chunks.
        if not chunks:
            return []
        if self.roll:
            # Wrap around so the first and last chunks also get overlap context.
            chunks = [chunks[-1]] + chunks + [chunks[0]]
        res = []
        total_idx = 0
        for chunk_idx in range(len(chunks) - 1):
            chunk_a, chunk_b = chunks[chunk_idx], chunks[chunk_idx + 1]
            chunk_a_words = self.split_space(chunk_a)
            chunk_b_words = self.split_space(chunk_b)
            chunk_a_overlap_length = len(chunk_a_words) // self.n_overlap
            chunk_b_overlap_length = len(chunk_b_words) // self.n_overlap
            for overlap_idx in range(self.n_overlap):
                # Slide the window: drop overlap_idx word-blocks from the front
                # of chunk_a and take the same number from the front of chunk_b.
                chunk_a_past = ''.join(chunk_a_words[:chunk_a_overlap_length * overlap_idx])
                chunk_a_overlap = ''.join(chunk_a_words[chunk_a_overlap_length * overlap_idx:])
                chunk_b_overlap = ''.join(chunk_b_words[:chunk_b_overlap_length * overlap_idx])
                overlap = chunk_a_overlap + chunk_b_overlap
                start = total_idx + len(chunk_a_past)
                end = start + len(overlap)
                res.append((start, end, overlap))
            total_idx += len(chunk_a)
        # Append the final chunk unmodified so every character is covered.
        res.append((total_idx, total_idx + len(chunks[-1]), chunks[-1]))
        return res
|
|
    def decode_overlap(self, chunks):
        # Rebuild the text from overlapping (start, end, text) spans by taking
        # a majority vote over the characters proposed for each position.
        if not chunks:
            return ""

        max_length = max(end for _, end, _ in chunks)
        index_char_map = {i: [] for i in range(max_length)}

        for start, end, chunk in chunks:
            for i, char in enumerate(chunk):
                index = start + i
                if index < max_length:
                    index_char_map[index].append(char)

        reconstructed_text = []
        for i in range(max_length):
            most_common_char, _ = Counter(index_char_map[i]).most_common(1)[0]
            reconstructed_text.append(most_common_char)
        res = "".join(reconstructed_text)

        if self.roll:
            # Strip the wrapped-around copies that overlap() added at each end.
            res = res[len(chunks[0][2]):-len(chunks[-1][2])]

        return res
|
|
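# Minimal usage sketch; the small length limits and the Korean sample sentence
# below are illustrative values only, chosen so a short text still splits into
# more than one chunk.
if __name__ == "__main__":
    config = SentenceTokenizerConfig(min_length=10, max_length=20, n_overlap=2)
    tokenizer = SentenceTokenizer(config)

    sample = "안녕하세요. 오늘 날씨가 정말 좋네요. 같이 산책하러 나갈까요?"
    chunks = tokenizer.tokenize(sample)
    print(chunks)

    # Overlapping (start, end, text) spans; the character-level majority vote
    # in decode_overlap should reproduce the original string.
    spans = tokenizer.overlap(chunks)
    print(tokenizer.decode_overlap(spans) == sample)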