| """ |
| 可被AutoTokenizer自动加载的CIF Tokenizer实现(简化版) |
| |
| 通过在tokenize前预处理、decode后后处理的方式,避免复杂的方法签名问题 |
| """ |
| import os |
| import re |
| import json |
| from typing import List, Optional |
| from transformers import PreTrainedTokenizerFast |
|
|
|
|
| |
# Load the space-group symbol list shipped next to this file; fall back to an
# empty list so the tokenizer still works without the optional data file.
SPACEGROUPS_PATH = os.path.join(os.path.dirname(__file__), "spacegroups.txt")
if os.path.exists(SPACEGROUPS_PATH):
    # Explicit encoding: the platform default can break on Windows for
    # non-ASCII symbol names.
    with open(SPACEGROUPS_PATH, "rt", encoding="utf-8") as f:
        SPACE_GROUPS = [line.strip() for line in f]
else:
    SPACE_GROUPS = []
|
|
|
|
class CIFTokenizerFast(PreTrainedTokenizerFast):
    """
    CIF-specific tokenizer that can be auto-loaded by AutoTokenizer.

    Inherits from PreTrainedTokenizerFast and is fully compatible with the
    Hugging Face ecosystem.

    Usage notes:
        - Loadable by AutoTokenizer through the ``auto_map`` configuration.
        - ``cif_tokenizer_fast.py`` must be placed inside the model directory.
        - Loading requires ``trust_remote_code=True``.
    """

    def __init__(self, *args, space_groups: Optional[List[str]] = None, **kwargs):
        """Initialize the tokenizer and build the space-group regex pattern.

        Args:
            space_groups: Optional list of Hermann-Mauguin space-group
                symbols; defaults to the module-level SPACE_GROUPS list.
        """
        super().__init__(*args, **kwargs)
        self.space_groups = space_groups if space_groups is not None else SPACE_GROUPS
        self._build_spacegroup_pattern()

    def _build_spacegroup_pattern(self):
        """Build the regex alternation used to locate space-group symbols.

        Symbols are sorted longest-first so the alternation prefers e.g.
        'P 1 21 1' over its prefix 'P 1'.
        """
        if self.space_groups:
            sorted_sgs = sorted(self.space_groups, key=len, reverse=True)
            self.spacegroup_pattern = "|".join(re.escape(sg) for sg in sorted_sgs)
        else:
            self.spacegroup_pattern = None

    def preprocess_cif(self, text: str, single_spaces: bool = True) -> str:
        """
        Preprocess CIF text before tokenization.

        Tags each space-group symbol following ``_symmetry_space_group_name_H-M``
        with a ``_sg`` suffix so it tokenizes as a distinct unit.

        Args:
            text: raw CIF text
            single_spaces: whether to collapse runs of spaces/tabs

        Returns:
            The preprocessed text (non-str input is returned unchanged).
        """
        if not isinstance(text, str):
            return text

        if single_spaces:
            text = re.sub(r'[ \t]+', ' ', text)

        if self.spacegroup_pattern:
            pattern = fr'(_symmetry_space_group_name_H-M\s*[\'"]?\s*({self.spacegroup_pattern})\s*[\'"]?)'

            def add_sg_suffix(match):
                full_match = match.group(0)
                # Idempotence: skip if this symbol is already tagged in the
                # surrounding text.
                if match.string[match.end(2):match.end(2) + 3] == '_sg':
                    return full_match
                # BUGFIX: insert the suffix exactly after the captured symbol
                # (group 2).  The previous code scanned self.space_groups in
                # arbitrary order and replaced the first list entry found as a
                # substring, so a shorter symbol embedded in the matched longer
                # one (e.g. 'P 1' inside 'P 1 21 1') corrupted the text.
                insert_at = match.end(2) - match.start(0)
                return full_match[:insert_at] + '_sg' + full_match[insert_at:]

            text = re.sub(pattern, add_sg_suffix, text)

        return text

    def postprocess_cif(self, text: str) -> str:
        """Post-process decoded text: strip the ``_sg`` suffix from symbols."""
        if isinstance(text, str) and self.space_groups:
            # Longest-first so a shorter symbol never clips a longer one.
            for sg in sorted(self.space_groups, key=len, reverse=True):
                text = text.replace(sg + '_sg', sg)
        return text

    def __call__(self, *args, **kwargs):
        """
        Override __call__ to drop token_type_ids, which Qwen models do not use.
        """
        result = super().__call__(*args, **kwargs)

        # BUGFIX: super().__call__ returns a BatchEncoding, which subclasses
        # collections.UserDict rather than dict, so the previous
        # `isinstance(result, dict)` guard was always False and the key was
        # never actually removed.
        try:
            result.pop('token_type_ids', None)
        except (TypeError, AttributeError):
            # Defensive: non-mapping results (should not occur) pass through.
            pass

        return result

    def decode(self, *args, **kwargs):
        """Decode, then post-process the result."""
        result = super().decode(*args, **kwargs)
        return self.postprocess_cif(result)

    def batch_decode(self, *args, **kwargs):
        """Batch-decode, then post-process each result."""
        results = super().batch_decode(*args, **kwargs)
        return [self.postprocess_cif(r) for r in results]

    def save_pretrained(self, save_directory, *args, **kwargs):
        """Save the tokenizer plus the CIF-specific configuration.

        Writes ``cif_config.json`` (the space-group list), patches
        ``tokenizer_config.json`` with the ``auto_map`` entry, and copies this
        source file into the save directory so remote loading works.
        """
        result = super().save_pretrained(save_directory, *args, **kwargs)

        # Persist the space-group list for from_pretrained to restore.
        cif_config_path = os.path.join(save_directory, "cif_config.json")
        with open(cif_config_path, "w", encoding="utf-8") as f:
            json.dump({"space_groups": self.space_groups}, f, ensure_ascii=False, indent=2)

        # Register this class in auto_map so AutoTokenizer can find it.
        tokenizer_config_path = os.path.join(save_directory, "tokenizer_config.json")
        if os.path.exists(tokenizer_config_path):
            with open(tokenizer_config_path, "r", encoding="utf-8") as f:
                config = json.load(f)

            config["tokenizer_class"] = "CIFTokenizerFast"
            config["auto_map"] = {
                "AutoTokenizer": ["cif_tokenizer_fast.CIFTokenizerFast", None]
            }

            with open(tokenizer_config_path, "w", encoding="utf-8") as f:
                json.dump(config, f, ensure_ascii=False, indent=2)

        # Best-effort copy of this module so trust_remote_code loading works;
        # a failure here must not abort the save.
        try:
            import shutil
            src = os.path.abspath(__file__)
            dst = os.path.join(save_directory, "cif_tokenizer_fast.py")
            if src != dst:
                shutil.copy2(src, dst)
            print(f"✓ 已复制 cif_tokenizer_fast.py")
        except Exception as e:
            print(f"⚠ 复制失败: {e}")

        return result

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load the tokenizer, restoring space groups from cif_config.json.

        Priority for the space-group list: explicit ``space_groups`` kwarg >
        saved ``cif_config.json`` > module-level default.
        """
        space_groups = None
        if isinstance(pretrained_model_name_or_path, str):
            config_path = os.path.join(pretrained_model_name_or_path, "cif_config.json")
            if os.path.exists(config_path):
                # Explicit encoding to match how the file is written.
                with open(config_path, "r", encoding="utf-8") as f:
                    config = json.load(f)
                space_groups = config.get("space_groups")

        # An explicit keyword argument overrides the saved configuration.
        space_groups = kwargs.pop("space_groups", space_groups)

        tokenizer = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        tokenizer.space_groups = space_groups if space_groups is not None else SPACE_GROUPS
        tokenizer._build_spacegroup_pattern()

        return tokenizer
|
|
|
|
def register_cif_tokenizer():
    """Compatibility shim; explicit registration is not required."""
    return None
|
|