| |
|
|
| import os |
| import json |
| from typing import List, Union |
| from collections import Counter |
|
|
| import torch |
| from transformers import AutoTokenizer |
|
|
| from MixTokenizer import sample_integer_points, NewLangTokenizer |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
def get_mix_tokenizer(tokenizer_cls):
    """
    Class factory: build and return a ``MixTokenizer`` subclass of ``tokenizer_cls``.

    The returned class runs a second, "new language" tokenizer alongside the
    base Hugging Face tokenizer and represents each new-language token id as a
    fixed-length tuple ("composite") of zero-frequency base-vocabulary ids, so
    mixed text can be encoded/decoded through the base id space alone.

    Args:
        tokenizer_cls: the Hugging Face tokenizer class to subclass
            (e.g. ``Qwen2Tokenizer``); captured by closure in the methods below.

    Returns:
        The dynamically created ``MixTokenizer`` class (not an instance).
    """

    class MixTokenizer(tokenizer_cls):
        """
        Combines Qwen2Tokenizer with an additional tokenizer for a custom language.
        Allows mapping of new language tokens to composite representations.
        """

        @classmethod
        def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
            """
            Load the base tokenizer, then attach the new-language tokenizer and
            the composite id mapping found under ``<path>/MixTokenizer/``.

            Raises:
                ValueError: if ``extra_config.json`` provides neither a prebuilt
                    mapping/used_ids pair nor ``frequency_id_files`` to build one.
            """
            instance = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
            # Remembered so save_to_json() can write back next to the model files.
            instance.pretrained_model_name_or_path = pretrained_model_name_or_path

            # Side-car config directory shipped alongside the pretrained model.
            script_dir = os.path.join(pretrained_model_name_or_path, "MixTokenizer")
            json_path = os.path.join(script_dir, "extra_config.json")
            with open(json_path, "r", encoding="utf-8") as f:
                extra_config = json.load(f)

            # Prefer a proper HF tokenizer for the new language; fall back to the
            # project's word-level tokenizer built from a bare vocab.json.
            new_path = os.path.join(script_dir, "new_tokenizer")
            try:
                print("Try to load AutoTokenizer from HF")
                new_lang_tokenizer = AutoTokenizer.from_pretrained(new_path)
            except Exception:
                print("Fallback: use default WordLevel Tokenizer")
                vocab_path = os.path.join(new_path, "vocab.json")
                new_lang_tokenizer = NewLangTokenizer(vocab_file=vocab_path)

            instance.new_lang_tokenizer = new_lang_tokenizer
            level = extra_config.get("level", None)
            frequency_id_files = extra_config.get("frequency_id_files", None)

            if extra_config.get("mapping") and extra_config.get("used_ids"):
                # A mapping was already computed and persisted: reuse it verbatim.
                print(f"Mapping file and used ids are loaded, level ignored: {level}")
                instance.mapping = extra_config["mapping"]
                # level is re-derived from the mapping width, not from the config value.
                instance.level = len(instance.mapping[0])
                instance.zero_ids = extra_config["used_ids"]
                # zero_dict: base-token-id -> position in zero_ids; used as a
                # membership set by is_new_id().
                instance.zero_dict = {zid: idx for idx, zid in enumerate(instance.zero_ids)}
                # reverse_mapping: composite tuple -> new-language token id.
                instance.reverse_mapping = {tuple(point): idx for idx, point in enumerate(instance.mapping)}
            elif frequency_id_files:
                # First run: derive the mapping from token-frequency dumps, then
                # persist it so subsequent loads take the branch above.
                print(f"Mapping file and used ids do not all exist, using frequency_id_files, level={level}")
                instance._prepare_mapping(frequency_id_files, level)
                # NOTE(review): this path never assigns instance.level, yet
                # _decode() reads self.level — confirm first-run decoding works
                # (a reload after save_to_json() would set it via the branch above).
                instance.save_to_json()
            else:
                raise ValueError(
                    "Ensure mapping and used_ids exist in config, or frequency_id_files and level exist in config."
                )
            return instance

        def save_to_json(self):
            """Persist mapping/used_ids into extra_config.json, only if absent there."""
            script_dir = os.path.join(self.pretrained_model_name_or_path, "MixTokenizer")
            json_path = os.path.join(script_dir, "extra_config.json")
            cfg = {}
            if os.path.exists(json_path):
                with open(json_path, "r", encoding="utf-8") as f:
                    cfg = json.load(f)
            # Never overwrite an existing, complete mapping in the config.
            if not cfg.get("mapping") or not cfg.get("used_ids"):
                cfg["mapping"] = self.mapping
                cfg["used_ids"] = self.zero_ids
            with open(json_path, "w", encoding="utf-8") as f:
                json.dump(cfg, f, indent=2)

        def _prepare_mapping(self, frequency_id_files: List[str], level: int | str) -> None:
            """
            Build mapping between new language tokens and low-frequency Qwen tokens.

            Args:
                frequency_id_files: paths to torch-saved frequency data; each is
                    fed to ``Counter.update`` (presumably an id -> count mapping
                    or an iterable of ids — TODO confirm against the dump format).
                level: tuple length K for the composite mapping; the literal
                    string "mixed" is explicitly rejected.

            Raises:
                NotImplementedError: if level == "mixed".
                ValueError: if no zero-frequency base tokens exist, or if
                    len(zero_ids) ** level cannot cover the new-language vocab.
            """
            if level == "mixed":
                raise NotImplementedError("Decoding in 'mixed' level mode is not supported.")

            # Aggregate observed token frequencies across all dump files.
            frequency_counter = Counter()
            for file in frequency_id_files:
                freq_data = torch.load(file)
                frequency_counter.update(freq_data)

            # Ensure every base-vocabulary id has an entry (adding 0 leaves
            # existing counts unchanged) so unseen ids surface with count 0.
            for i in range(len(self)):
                frequency_counter.update({i: 0})

            # Ascending by frequency; only the zero-count ids are used below.
            frequency_list = sorted(frequency_counter.items(), key=lambda x: x[1])

            # Ids never observed in the corpus: safe to repurpose for the new language.
            zero_ids = [tid for tid, freq in frequency_list if freq == 0]
            if not zero_ids:
                raise ValueError("No zero-frequency tokens available for mapping.")

            print(f"\033[91m[WORK]\033[0mFound {len(zero_ids)} zero-frequency tokens for mapping.")
            print(f"\033[91m[WORK]\033[0mNew language vocab size: {len(self.new_lang_tokenizer)}, use level={level}.")

            # Capacity check: number of distinct level-tuples over zero_ids must
            # cover the entire new-language vocabulary.
            max_lim = int(pow(len(zero_ids), level))
            if max_lim < len(self.new_lang_tokenizer):
                raise ValueError(f"Increase 'level', max_lim = {max_lim}")

            # Sample one distinct K-tuple of indices in [0, L) per new-language token.
            points = sample_integer_points(L=len(zero_ids), K=level, N=len(self.new_lang_tokenizer))

            # mapping[new_id] -> list of `level` base-vocabulary ids.
            self.mapping: List[list[int]] = [
                [zero_ids[x] for x in point] for point in points
            ]
            self.zero_ids = zero_ids
            self.zero_dict = {zid: idx for idx, zid in enumerate(zero_ids)}
            self.reverse_mapping = {point: idx for idx, point in enumerate(map(tuple, self.mapping))}

        def _same_char_type(self, ch1: str, ch2: str) -> bool:
            # True when both chars belong to the same script (new language vs base).
            return self.new_lang_tokenizer.is_new_char(ch1) == self.new_lang_tokenizer.is_new_char(ch2)

        def _same_id_type(self, id1: int, id2: int) -> bool:
            # True when both ids are on the same side of the new/base split.
            return self.is_new_id(id1) == self.is_new_id(id2)

        def is_new_id(self, token_id: int) -> bool:
            # Membership in zero_dict marks ids repurposed for the new language.
            return token_id in self.zero_dict

        def _convert_ids_to_new_lang_ids(self, token_ids: List[int]) -> int | None:
            # Composite tuple -> new-language id; returns None for an unknown
            # tuple (dict.get default) — callers pass the result to decode(),
            # so a miss would propagate None there.
            return self.reverse_mapping.get(tuple(token_ids))

        def tokenize(self, text: str, **kwargs) -> List[str]:
            """
            Two-stage tokenization:
            1. Group characters by type (new_lang vs Qwen).
            2. Tokenize each segment using the appropriate tokenizer.
            """
            # sub_texts: list of [segment_string, is_new_language] pairs built by
            # merging consecutive characters of the same type.
            sub_texts = []
            for ch in text:
                if sub_texts and self._same_char_type(ch, sub_texts[-1][0][0]):
                    sub_texts[-1][0] += ch
                else:
                    sub_texts.append([ch, self.new_lang_tokenizer.is_new_char(ch)])

            tokens = []
            for sub_text, is_new in sub_texts:
                if not is_new:
                    # Explicit base-class call: keeps the override from recursing.
                    tokens.extend(tokenizer_cls.tokenize(self, sub_text, **kwargs))
                else:
                    tokens.extend(self.new_lang_tokenizer.tokenize(sub_text))

            return tokens

        def _convert_one_token_to_id(self, token: str) -> Union[int, List[int]]:
            """
            Convert a token to one or more IDs depending on its type.

            New-language tokens expand to their composite list of `level`
            base ids; base tokens map to a single id.
            """
            if self.new_lang_tokenizer.is_new_char(token):
                token_id = self.new_lang_tokenizer.tokenizer.token_to_id(token)
                return self.mapping[token_id]
            else:
                return self._convert_token_to_id_with_added_voc(token)

        def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
            # None passes through, matching the base HF tokenizer contract.
            if tokens is None:
                return None
            if isinstance(tokens, str):
                return self._convert_one_token_to_id(tokens)

            # Flatten: composite (list) results are spliced into the id stream.
            ids = []
            for token in tokens:
                mapped = self._convert_one_token_to_id(token)
                ids.extend(mapped if isinstance(mapped, list) else [mapped])
            return ids

        def _decode(self, token_ids: Union[int, List[int]], **kwargs) -> str:
            """
            Decode a sequence of token IDs by segment type.

            NOTE(review): despite the Union[int, ...] annotation, a bare int
            would fail the `for tid in token_ids` iteration below — confirm
            callers always pass a sequence.
            """
            def _decode_sub(ids: List[int]) -> str:
                # Decode one homogeneous run of ids (all-new or all-base).
                if not ids:
                    return ""
                if all(self.is_new_id(x) for x in ids):
                    # Regroup the flat stream into composite tuples of size
                    # self.level, then map each back to a new-language id.
                    assert len(ids) % self.level == 0, "Invalid new language token IDs length."
                    ids = [self._convert_ids_to_new_lang_ids(ids[i:i + self.level]) for i in range(0, len(ids), self.level)]
                    return self.new_lang_tokenizer.decode(ids)
                # Explicit base-class call: keeps the override from recursing.
                return tokenizer_cls._decode(self, ids, **kwargs)

            # Buffer consecutive ids of the same type, flushing on transitions.
            sub_text, buffer = "", []
            for tid in token_ids:
                if buffer and self._same_id_type(tid, buffer[-1]):
                    buffer.append(tid)
                else:
                    sub_text += _decode_sub(buffer)
                    buffer = [tid]
            sub_text += _decode_sub(buffer)  # flush the trailing run
            return sub_text

    return MixTokenizer
|
|
|
|
| |
# Bind the concrete base tokenizer and publish the mixed class at module level.
from transformers import Qwen2Tokenizer

tokenizer_cls = Qwen2Tokenizer

# Plain module-level assignment binds the same name in the module namespace as
# the previous `globals()["MixTokenizer"] = ...` form, and is the idiomatic way
# to do it.
MixTokenizer = get_mix_tokenizer(tokenizer_cls=tokenizer_cls)
|
|